103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
import asyncio
import json
import logging

import redis
from langchain.memory import ConversationSummaryBufferMemory
from langchain.schema import AIMessage, HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession

from rag.settings import settings
|
|
|
|
|
|
class ChatMemoryManager:
    """Build LangChain conversation memory for a chat session, cached in Redis.

    Chat history is stored in Redis as a JSON list of message dicts under
    the key ``chat_memory:<session_id>``. Each cached dict carries at least
    ``content``, ``is_ai`` and an ISO-8601 ``created_at`` string.

    The database lookups are still placeholders (see the TODOs in
    ``load_chat_history``), so today only previously cached history is
    ever returned.
    """

    # Default lifetime of a cached history entry, in seconds.
    DEFAULT_CACHE_TTL_SECONDS = 3600

    _log = logging.getLogger(__name__)

    def __init__(
        self,
        llm,
        token_limit=3000,
        *,
        cache_ttl_seconds=DEFAULT_CACHE_TTL_SECONDS,
    ):
        """Create the manager and its Redis client.

        Args:
            llm: Language model handed to ``ConversationSummaryBufferMemory``.
            token_limit: ``max_token_limit`` for the summary buffer memory.
            cache_ttl_seconds: Expiry applied to every cache write.
        """
        # NOTE(review): the setting is named ``redis_cache_url`` but is passed
        # as ``host``. If it holds a full ``redis://`` URL this should be
        # ``redis.Redis.from_url(...)`` instead -- confirm against settings.
        self.redis = redis.Redis(
            host=settings.redis_cache_url,
            port=settings.redis_cache_port,
            db=settings.redis_cache_db,
        )
        self.llm = llm
        self.token_limit = token_limit
        self.cache_ttl_seconds = cache_ttl_seconds

    def _convert_to_langchain(
        self, messages: list[dict]
    ) -> list[HumanMessage | AIMessage]:
        """Turn annotated message dicts into LangChain message objects.

        Each dict must have ``content`` and ``is_ai``; a truthy ``is_ai``
        yields an ``AIMessage``, anything else a ``HumanMessage``.
        """
        return [
            AIMessage(content=msg["content"])
            if msg["is_ai"]
            else HumanMessage(content=msg["content"])
            for msg in messages
        ]

    def _annotate_messages(self, messages: list) -> list[dict]:
        """Return copies of ``messages`` with a boolean ``is_ai`` key added.

        A message counts as AI-authored when its ``user_type`` is ``"AI"``
        or its ``username`` is ``"SOMMELIER"``.
        """
        return [
            {
                **msg,
                "is_ai": msg.get("user_type") == "AI"
                or msg.get("username") == "SOMMELIER",
            }
            for msg in messages
        ]

    def _serialize_messages(self, messages: list[dict]) -> list[dict]:
        """Return copies of ``messages`` with ``created_at`` JSON-serialisable.

        Values exposing ``isoformat()`` (``datetime``/``date``) are converted
        to ISO-8601 strings. Anything else, such as an already serialised
        string, is kept as is so the method is safe to apply twice.

        Raises:
            KeyError: If a message has no ``created_at`` key.
        """
        serialized = []
        for msg in messages:
            created_at = msg["created_at"]
            if hasattr(created_at, "isoformat"):
                created_at = created_at.isoformat()
            serialized.append({**msg, "created_at": created_at})
        return serialized

    def _cache_key(self, session_id: int) -> str:
        """Return the Redis key under which a session's history is cached."""
        return f"chat_memory:{session_id}"

    def _decode_cached(self, serialized) -> list[dict]:
        """Decode a raw cache payload into a list of message dicts.

        A missing, empty, undecodable or non-list payload is treated as a
        cache miss and yields ``[]`` instead of raising, so one corrupt
        entry cannot break history loading for the session.
        """
        if not serialized:
            return []
        try:
            cached = json.loads(serialized)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError.
            self._log.warning("Ignoring undecodable chat memory cache entry")
            return []
        if not isinstance(cached, list):
            self._log.warning("Ignoring chat memory cache entry that is not a list")
            return []
        return cached

    async def _write_cache(self, cache_key: str, messages: list[dict]) -> None:
        """Store ``messages`` as JSON under ``cache_key`` with the configured TTL."""
        payload = json.dumps(messages)
        # The redis client is synchronous; run it in a worker thread so the
        # network round-trip does not block the event loop.
        await asyncio.to_thread(
            self.redis.setex, cache_key, self.cache_ttl_seconds, payload
        )

    async def load_chat_history(
        self, session_id: int, session: AsyncSession
    ) -> list[HumanMessage | AIMessage]:
        """Load a session's chat history as LangChain messages.

        The Redis cache is consulted first. On a hit, messages newer than
        the cached ones are meant to be appended from the database; on a
        miss the full history is meant to be read from the database and
        cached. Both database queries are currently placeholders that
        return nothing.

        Args:
            session_id: Identifier of the chat session.
            session: Database session for the (not yet implemented) queries.

        Returns:
            The history in stored order, or ``[]`` when nothing is available.
        """
        cache_key = self._cache_key(session_id)
        # Synchronous redis call, offloaded to keep the event loop responsive.
        serialized = await asyncio.to_thread(self.redis.get, cache_key)
        cached_messages = self._decode_cached(serialized)

        if cached_messages:
            # TODO: Replace with actual Message model query when available
            # (SQLModel/SQLAlchemy). It should fetch only messages created
            # after ``cached_messages[-1]["created_at"]``.
            new_messages = []  # Placeholder for actual DB query

            if not new_messages:
                return self._convert_to_langchain(cached_messages)

            all_messages = cached_messages + self._serialize_messages(
                self._annotate_messages(new_messages)
            )
            await self._write_cache(cache_key, all_messages)
            return self._convert_to_langchain(all_messages)

        # TODO: Replace with actual Message model query when available
        # This would need to be implemented with SQLModel/SQLAlchemy
        db_messages = []  # Placeholder for actual DB query

        if not db_messages:
            return []

        annotated_messages = self._annotate_messages(db_messages)
        await self._write_cache(
            cache_key, self._serialize_messages(annotated_messages)
        )
        return self._convert_to_langchain(annotated_messages)

    async def get_session_memory(
        self, session_id: int, session: AsyncSession
    ) -> ConversationSummaryBufferMemory:
        """Return a summary-buffer memory pre-filled with the session history.

        Args:
            session_id: Identifier of the chat session.
            session: Database session forwarded to ``load_chat_history``.

        Returns:
            A new ``ConversationSummaryBufferMemory`` using ``self.llm`` and
            ``self.token_limit``, with each loaded message replayed into its
            chat memory in order.
        """
        memory = ConversationSummaryBufferMemory(
            llm=self.llm, max_token_limit=self.token_limit
        )

        messages = await self.load_chat_history(session_id, session)
        for msg in messages:
            if isinstance(msg, HumanMessage):
                memory.chat_memory.add_user_message(msg.content)
            elif isinstance(msg, AIMessage):
                memory.chat_memory.add_ai_message(msg.content)
        return memory
|