from langchain_openai import ChatOpenAI
from langchain_deepseek import ChatDeepSeek
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from pxy_neo4j.neo4j_connector import Neo4jDatabase  # Import Neo4j connector


class LangchainAIService:
    """
    A service to dynamically interact with AI models using LangChain.
    """

    def __init__(self, assistant):
        """
        Initialize the AI service with an AIAssistant instance.
        """
        self.assistant = assistant
        self.client = self.load_model()
        self.graph = None

        # Initialize Neo4j connection if graph is enabled
        if self.assistant.uses_graph and self.assistant.neo4j_profile:
            self.graph = Neo4jDatabase(profile_name=self.assistant.neo4j_profile.name)

    def load_model(self):
        """
        Dynamically load the AI model based on the provider.
        """
        if not self.assistant.provider:
            raise ValueError("Assistant provider is not set.")

        provider_name = self.assistant.provider.name.lower()

        if provider_name == "openai":
            return ChatOpenAI(
                openai_api_key=self.assistant.api_key,
                model_name=self.assistant.model_name
            )
        elif provider_name == "deepseek":
            return ChatDeepSeek(
                model=self.assistant.model_name,
                temperature=0.7,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                api_key=self.assistant.api_key
            )
        else:
            raise ValueError(f"Unsupported provider: {provider_name}")

    def generate_response(self, user_message):
        """
        Generate a response using the AI model and optionally query the graph.
        """
        try:
            system_role = self.assistant.description or "You are a helpful assistant."
            graph_data = None

            # 🔥 Step 1: Query the graph if enabled
            if self.graph:
                graph_data = self.graph.query_graph(user_message)

            # 🔥 Step 2: Prepare messages for AI model
            messages = [
                SystemMessage(content=system_role),
                HumanMessage(content=user_message),
            ]

            # 🔥 Step 3: Inject graph response if available
            if graph_data:
                messages.append(AIMessage(content=f"Graph response: {graph_data}"))

            # 🔥 Step 4: Generate AI response
            response = self.client.invoke(messages)

            if isinstance(response, AIMessage):
                return response.content
            return str(response)

        except Exception as e:
            return f"Error generating response: {e}"
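

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring up the service. The AIAssistant interface is
# inferred from the attribute accesses above (provider.name, api_key,
# model_name, description, uses_graph, neo4j_profile); the stand-in object,
# model name, and key below are placeholders, not a confirmed schema.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical stand-in for an AIAssistant instance (graph disabled).
    assistant = SimpleNamespace(
        provider=SimpleNamespace(name="openai"),
        api_key="sk-...",                  # placeholder; supply a real key
        model_name="gpt-4o-mini",          # any ChatOpenAI-supported model
        description="You are a helpful assistant.",
        uses_graph=False,
        neo4j_profile=None,
    )

    service = LangchainAIService(assistant)
    print(service.generate_response("Hello!"))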