# ai-hackaton-backend/rag/memory.py
#
# 103 lines
# 3.6 KiB
# Python

import json
import redis
from langchain.memory import ConversationSummaryBufferMemory
from langchain.schema import AIMessage, HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from rag.settings import settings
class ChatMemoryManager:
    """Build LangChain conversation memory from chat history cached in Redis.

    Chat messages are handled as plain dicts. Each dict carries at least a
    ``content`` value, a ``created_at`` timestamp and, once annotated, an
    ``is_ai`` flag telling AI messages apart from human ones.

    NOTE(review): ``redis.Redis`` is the synchronous client, yet it is called
    from ``async`` methods without ``await``; those calls presumably block
    the event loop while they run. Confirm whether ``redis.asyncio`` should
    be used instead.
    """

    # Default lifetime, in seconds, of a cached chat history entry.
    DEFAULT_CACHE_TTL = 3600

    def __init__(self, llm, token_limit=3000, cache_ttl=DEFAULT_CACHE_TTL):
        """Create the Redis client and remember the memory configuration.

        Args:
            llm: Language model handed to ``ConversationSummaryBufferMemory``.
            token_limit: Value passed as the memory's ``max_token_limit``.
            cache_ttl: Expiry, in seconds, applied to cached chat histories.
        """
        # NOTE(review): ``redis_cache_url`` is passed as ``host``; confirm the
        # setting holds a bare hostname rather than a full ``redis://`` URL.
        self.redis = redis.Redis(
            host=settings.redis_cache_url,
            port=settings.redis_cache_port,
            db=settings.redis_cache_db,
        )
        self.llm = llm
        self.token_limit = token_limit
        self.cache_ttl = cache_ttl

    def _convert_to_langchain(self, messages: list[dict]):
        """Turn annotated message dicts into LangChain message objects.

        A dict whose ``is_ai`` value is truthy becomes an ``AIMessage``;
        every other dict becomes a ``HumanMessage``.
        """
        return [
            AIMessage(content=msg["content"])
            if msg["is_ai"]
            else HumanMessage(content=msg["content"])
            for msg in messages
        ]

    def _annotate_messages(self, messages: list) -> list[dict]:
        """Return copies of ``messages`` with an added boolean ``is_ai`` key.

        A message counts as AI-authored when its ``user_type`` is ``"AI"`` or
        its ``username`` is ``"SOMMELIER"``. The input dicts are not mutated.
        """
        return [
            {
                **msg,
                "is_ai": msg.get("user_type") == "AI"
                or msg.get("username") == "SOMMELIER",
            }
            for msg in messages
        ]

    def _serialize_messages(self, messages: list[dict]) -> list[dict]:
        """Return copies of ``messages`` whose ``created_at`` is JSON-safe.

        Values exposing ``isoformat()`` (dates, datetimes) are converted to
        ISO 8601 strings. Anything else, such as a timestamp that is already
        a string, is kept unchanged instead of raising ``AttributeError``.

        Raises:
            KeyError: If a message has no ``created_at`` key.
        """
        serialized = []
        for msg in messages:
            created_at = msg["created_at"]
            if hasattr(created_at, "isoformat"):
                created_at = created_at.isoformat()
            serialized.append({**msg, "created_at": created_at})
        return serialized

    def _cache_key(self, session_id: int) -> str:
        """Return the Redis key under which a session's history is cached."""
        return f"chat_memory:{session_id}"

    def _read_cached_messages(self, cache_key: str) -> list[dict]:
        """Return the message list cached under ``cache_key``, or ``[]``.

        A payload that is not valid JSON, or that does not decode to a list,
        is treated as a cache miss and removed, so one corrupt entry cannot
        break every request for the session until its TTL expires.
        """
        serialized = self.redis.get(cache_key)
        if not serialized:
            return []
        try:
            cached = json.loads(serialized)
        except (json.JSONDecodeError, UnicodeDecodeError):
            cached = None
        if not isinstance(cached, list):
            self.redis.delete(cache_key)
            return []
        return cached

    async def load_chat_history(
        self, session_id: int, session: AsyncSession
    ) -> list[HumanMessage | AIMessage]:
        """Load a session's chat history as LangChain messages.

        The Redis cache is consulted first. On a hit, the cached messages
        (plus any newer database messages, once that query exists) are
        returned. On a miss, the history is meant to come from the database
        and be written back to the cache.

        Args:
            session_id: Identifier of the chat session.
            session: Database session. Currently unused because the database
                queries are still placeholders.

        Returns:
            The messages in stored order, or ``[]`` when nothing is available.
        """
        cache_key = self._cache_key(session_id)
        cached_messages = self._read_cached_messages(cache_key)

        if cached_messages:
            # TODO: Replace with actual Message model query when available
            # (SQLModel/SQLAlchemy). It should fetch only messages newer than
            # the ``created_at`` of the last cached message.
            new_messages = []  # Placeholder for actual DB query
            if not new_messages:
                return self._convert_to_langchain(cached_messages)

            annotated_messages = self._annotate_messages(new_messages)
            all_messages = cached_messages + self._serialize_messages(
                annotated_messages
            )
            self.redis.setex(cache_key, self.cache_ttl, json.dumps(all_messages))
            return self._convert_to_langchain(all_messages)

        # TODO: Replace with actual Message model query when available
        # (SQLModel/SQLAlchemy).
        db_messages = []  # Placeholder for actual DB query
        if not db_messages:
            return []

        annotated_messages = self._annotate_messages(db_messages)
        self.redis.setex(
            cache_key,
            self.cache_ttl,
            json.dumps(self._serialize_messages(annotated_messages)),
        )
        return self._convert_to_langchain(annotated_messages)

    async def get_session_memory(
        self, session_id: int, session: AsyncSession
    ) -> ConversationSummaryBufferMemory:
        """Return a summary-buffer memory pre-filled with the session history.

        Args:
            session_id: Identifier of the chat session.
            session: Database session forwarded to ``load_chat_history``.

        Returns:
            A new ``ConversationSummaryBufferMemory`` configured with this
            manager's ``llm`` and ``token_limit``, with every loaded message
            appended to its ``chat_memory`` in order.
        """
        memory = ConversationSummaryBufferMemory(
            llm=self.llm, max_token_limit=self.token_limit
        )
        messages = await self.load_chat_history(session_id, session)
        for msg in messages:
            if isinstance(msg, HumanMessage):
                memory.chat_memory.add_user_message(msg.content)
            elif isinstance(msg, AIMessage):
                memory.chat_memory.add_ai_message(msg.content)
        return memory