121 lines
5.0 KiB
Python
121 lines
5.0 KiB
Python
import os
|
|
import logging
|
|
import traceback
|
|
from datetime import datetime
|
|
from django.conf import settings
|
|
from neo4j import GraphDatabase
|
|
from langchain_neo4j import Neo4jGraph # ✅ Corrected import
|
|
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
|
|
from neo4j_graphrag.llm import OpenAILLM
|
|
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
|
|
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
|
|
from .models import Neo4jProfile
|
|
import asyncio
|
|
|
|
# Module logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)

# Name of the Neo4jProfile row that Neo4jDatabase loads when the caller does
# not ask for a specific profile.
DEFAULT_PROFILE_NAME: str = "DefaultNeo4jProfile"
|
|
|
|
|
|
class Neo4jDatabase:
    """Neo4j access for chatbot knowledge graphs, configured from a ``Neo4jProfile``.

    For one profile an instance holds:

    * ``profile`` - the ``Neo4jProfile`` row (URI, credentials, OpenAI key, model name);
    * ``driver`` - a ``neo4j`` driver built from the profile's URI and credentials;
    * ``embedder`` / ``llm`` - neo4j_graphrag OpenAI clients using the profile's
      API key (and model name for the LLM);
    * ``kg_builder`` - a ``SimpleKGPipeline`` that extracts a graph from text.

    ``store_interaction`` runs chat exchanges through the pipeline and
    ``query_graph`` answers natural-language questions via an LLM-generated
    Cypher query.
    """

    def __init__(self, profile_name=None):
        """Load a Neo4jProfile and build the driver, LLM, embedder and KG pipeline.

        Args:
            profile_name: Name of the ``Neo4jProfile`` to load. When falsy, the
                profile named ``DEFAULT_PROFILE_NAME`` is used instead.

        Raises:
            Neo4jProfile.DoesNotExist: no profile with the resolved name exists.
            Exception: any other error while building the clients is logged
                and re-raised (an already-opened driver is closed first).
        """
        # Pre-set so the failure path can tell whether a driver was opened.
        self.driver = None
        # LangChain graph wrapper used by query_graph(); built on first use.
        self._qa_graph = None
        # The name actually looked up, so error messages never say 'None'.
        resolved_name = profile_name or DEFAULT_PROFILE_NAME

        try:
            if profile_name:
                self.profile = Neo4jProfile.objects.get(name=resolved_name)
                logger.info("✅ Loaded specified profile: %s", self.profile)
            else:
                logger.warning("⚠️ No profile specified. Using default Neo4j profile.")
                self.profile = Neo4jProfile.objects.get(name=resolved_name)

            logger.info(
                "🚀 Initializing Neo4jGraph with URI: %s and username: %s",
                self.profile.uri,
                self.profile.username,
            )

            self.driver = GraphDatabase.driver(
                self.profile.uri,
                auth=(self.profile.username, self.profile.password),
            )

            self.embedder = OpenAIEmbeddings(api_key=self.profile.openai_api_key)

            # temperature 0 for the most repeatable extraction output.
            self.llm = OpenAILLM(
                model_name=self.profile.model_name,
                api_key=self.profile.openai_api_key,
                model_params={"temperature": 0},
            )

            # Plain-text input (from_pdf=False); the pipeline is handed the
            # driver so it can write what it extracts.
            self.kg_builder = SimpleKGPipeline(
                llm=self.llm,
                driver=self.driver,
                embedder=self.embedder,
                from_pdf=False,
            )

            logger.info(
                "✅ Neo4jDatabase initialized successfully with profile: %s",
                self.profile.name,
            )

        except Neo4jProfile.DoesNotExist:
            logger.error("❌ Neo4j profile '%s' not found.", resolved_name)
            raise

        except Exception as e:
            logger.exception("❌ Failed to initialize Neo4jDatabase: %s", e)
            # A later step failed after the driver was opened: release its
            # connection pool rather than leak it with the half-built object.
            if self.driver is not None:
                try:
                    self.driver.close()
                except Exception:
                    logger.warning(
                        "⚠️ Could not close Neo4j driver after init failure.",
                        exc_info=True,
                    )
            raise

    @staticmethod
    def _run_coroutine_sync(coro):
        """Run *coro* to completion from synchronous code and return its result.

        ``asyncio.run`` raises ``RuntimeError`` if the calling thread already has
        a running event loop (e.g. inside an async Django view). In that case the
        coroutine runs on its own loop in a short-lived worker thread, and the
        caller blocks until it finishes.
        """
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the ordinary synchronous case.
            return asyncio.run(coro)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def store_interaction(self, user_id, bot_id, user_message, bot_response, platform):
        """Run knowledge-graph extraction on one chat exchange.

        The message and reply are combined into ``"User: ...\\nBot: ..."`` and
        passed to ``kg_builder.run_async``. Persistence is left to the pipeline,
        which was given ``self.driver``. This call blocks until the async
        pipeline finishes.

        Args:
            user_id: Currently unused.
            bot_id: Currently unused.
            user_message: Text sent by the user.
            bot_response: Text the bot replied with.
            platform: Currently unused.

        NOTE(review): user_id, bot_id and platform are not attached to the
        extracted graph. Confirm whether the graph should be linked to them.

        Raises:
            Exception: any pipeline error is logged and re-raised.
        """
        try:
            combined_text = f"User: {user_message}\nBot: {bot_response}"

            # Conversation text is user data; keep it out of INFO-level logs.
            logger.debug("📡 Processing interaction for KG extraction: %s", combined_text)

            kg_results = self._run_coroutine_sync(
                self.kg_builder.run_async(text=combined_text)
            )

            if not kg_results:
                logger.warning("⚠️ No knowledge graph extracted. Skipping storage.")
                return

            logger.info("📌 Extracted KG Queries: %s", kg_results)
            logger.info(
                "✅ Knowledge graph successfully extracted and stored (Profile: %s)",
                self.profile.name,
            )

        except Exception as e:
            logger.exception("❌ Failed to process and store interaction in Neo4j: %s", e)
            raise

    def _get_qa_graph(self):
        """Return this profile's ``Neo4jGraph`` wrapper, creating it on first use.

        On later calls the cached wrapper's schema is refreshed, so Cypher
        generation sees labels and relationships added since the last query.
        """
        graph = getattr(self, "_qa_graph", None)
        if graph is None:
            graph = Neo4jGraph(
                url=self.profile.uri,
                username=self.profile.username,
                password=self.profile.password,
            )
            self._qa_graph = graph
        else:
            graph.refresh_schema()
        return graph

    def query_graph(self, user_query):
        """Answer *user_query* with an LLM-generated Cypher query against the graph.

        Returns:
            The chain's ``"result"`` value, or ``None`` if it is missing or if any
            error occurs. Errors are logged, not raised.
        """
        logger.info("🔎 Querying Neo4j: %s", user_query)
        try:
            # The chain must be given a LangChain graph store, not the raw neo4j
            # driver. Use langchain_neo4j's chain so it pairs with langchain_neo4j's
            # Neo4jGraph (imported here to keep this method self-contained).
            from langchain_neo4j import GraphCypherQAChain as Neo4jGraphCypherQAChain

            qa_chain = Neo4jGraphCypherQAChain.from_llm(
                # NOTE(review): self.llm is a neo4j_graphrag OpenAILLM, not a
                # LangChain chat model. The chain likely needs a LangChain model
                # (e.g. langchain_openai.ChatOpenAI). Confirm and swap.
                llm=self.llm,
                graph=self._get_qa_graph(),
                verbose=True,
                # Generated Cypher runs with this profile's credentials and
                # user_query is untrusted input: prefer a read-only Neo4j user.
                allow_dangerous_requests=True,
            )

            result = qa_chain.invoke({"query": user_query})
            logger.info("✅ Query Result: %s", result)
            return result.get("result", None)

        except Exception as e:
            logger.exception("❌ Graph query failed: %s", e)
            return None