Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.memory.buffer import ConversationBufferMemory
from learn.idf import categorize_conversation_history
from llm.memory import generate_memory_instance
from models import CompanyContent
Expand Down Expand Up @@ -90,20 +91,23 @@ def process_user_intent(session_id, message):
messages=prompt_templating
)

psql_memory = generate_memory_instance(session_id)
chat_memory = generate_memory_instance(session_id)
memory = ConversationBufferMemory(
chat_memory=chat_memory,
memory_key="chat_history",
return_messages=True
)
conversation = LLMChain(
llm=CHAT_LLM,
prompt=prompt,
verbose=VERBOSE_LLM
verbose=VERBOSE_LLM,
memory=memory
)
ai_message = conversation({
"user_message": message,
"chat_history": psql_memory.messages
})
user_message = psql_memory.add_user_message(message)
psql_memory.add_ai_message(ai_message["text"], user_message.id)

# categorize conversation history in background
asyncio.create_task(categorize_conversation_history(psql_memory))
asyncio.create_task(categorize_conversation_history(chat_memory))

return ai_message
45 changes: 12 additions & 33 deletions src/llm/memory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@

from langchain.memory import PostgresChatMessageHistory
from langchain.schema.messages import (
AIMessage,
BaseMessage,
HumanMessage,
_message_to_dict,
)
from models import Chat, ChatMessages
Expand All @@ -15,6 +13,10 @@ class CustomPostgresChatMessageHistory(PostgresChatMessageHistory):
"""
Custom chat message history for LLM
"""
def __init__(self, *args, parent_session_id=None, **kwargs):
self.parent_session_id = parent_session_id
super().__init__(*args, **kwargs)

def _create_table_if_not_exists(self) -> None:
"""
create table if it does not exist
Expand All @@ -34,46 +36,23 @@ def add_tags(self, tags: str) -> None:
session.query(Chat).where(Chat.uuid == self.session_id).update({Chat.tags: tags})
session.commit()

def add_message(self, message: BaseMessage, parent_id: int = None) -> ChatMessages:
"""Append the message to the record in PostgreSQL
returning the ChatMessages created for use in the parent logic
"""
values = {"session_id": self.session_id, "message": _message_to_dict(message)}
if parent_id:
values["parent"] = parent_id
new_message = ChatMessages(**values)
session.add(new_message)
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
message = ChatMessages(session_id=self.session_id, message=_message_to_dict(message))
if self.parent_session_id:
message.parent = self.parent_session_id
session.add(message)
session.commit()
return new_message

def add_user_message(self, message: str) -> ChatMessages:
"""Convenience method for adding a human message string to the store.

Args:
message: The string contents of a human message.
"""
new_message = self.add_message(HumanMessage(content=message))
return new_message

def add_ai_message(self, message: str, parent_id: int) -> None:
"""Convenience method for adding an AI message string to the store.

Args:
message: The string contents of an AI message.
"""
values = {"message": AIMessage(content=message)}
if parent_id:
values["parent_id"] = parent_id
self.add_message(**values)


def generate_memory_instance(session_id):
def generate_memory_instance(session_id, parent_session_id=None):
"""
Generate a memory instance for a given session_id
"""
return CustomPostgresChatMessageHistory(
connection_string=DATABASE_URL,
session_id=session_id,
parent_session_id=parent_session_id,
table_name="chat_messages"
)

Expand Down