Skip to content

Commit 71ff378

Browse files
committed
Changes session
1 parent a123b30 commit 71ff378

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

‎src/dialog/llm/abstract_llm.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from langchain.chains.llm import LLMChain
22
from langchain.memory.chat_memory import BaseChatMemory
33
from langchain.prompts import ChatPromptTemplate
4+
from dialog.models.db import get_session
45

56

67
class AbstractLLM:
@@ -25,7 +26,7 @@ def __init__(self, config, session_id=None, parent_session_id=None, dataset=None
2526
self.dataset = dataset
2627
self.llm_key = llm_key
2728
self.parent_session_id = parent_session_id
28-
self.dbsession = dbsession
29+
self.dbsession = dbsession or next(get_session())
2930

3031
def get_prompt(self, input) -> ChatPromptTemplate:
3132
"""

‎src/dialog/llm/embeddings.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def generate_embedding(document: str):
2424

2525

2626
def get_most_relevant_contents_from_message(message, top=5, dataset=None, session=None):
27+
if session is None:
28+
session = next(get_session())
29+
2730
message_embedding = generate_embedding(message)
2831
filters = [
2932
CompanyContent.embedding.cosine_distance(message_embedding) < Settings().COSINE_SIMILARITY_THRESHOLD,

‎src/dialog/llm/memory.py‎

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class CustomPostgresChatMessageHistory(PostgresChatMessageHistory):
1313

1414
def __init__(self, *args, parent_session_id=None, dbsession=None, **kwargs):
1515
self.parent_session_id = parent_session_id
16-
self.dbsession = dbsession
16+
self.dbsession = dbsession or next(get_session())
1717
super().__init__(*args, **kwargs)
1818

1919
def _create_table_if_not_exists(self) -> None:
@@ -52,6 +52,9 @@ def generate_memory_instance(session_id, parent_session_id=None, dbsession=None)
5252
"""
5353
Generate a memory instance for a given session_id
5454
"""
55+
if not dbsession:
56+
dbsession = next(get_session())
57+
5558
return CustomPostgresChatMessageHistory(
5659
connection_string=Settings().DATABASE_URL,
5760
session_id=session_id,
@@ -69,6 +72,9 @@ def add_user_message_to_message_history(session_id, message, memory=None, dbsess
6972
if not memory:
7073
memory = generate_memory_instance(session_id)
7174

75+
if not dbsession:
76+
dbsession = next(get_session())
77+
7278
memory.add_user_message(message)
7379
return memory
7480

@@ -77,5 +83,8 @@ def get_messages(session_id, dbsession=None):
7783
"""
7884
Get all messages for a given session_id
7985
"""
86+
if not dbsession:
87+
dbsession = next(get_session())
88+
8089
memory = generate_memory_instance(session_id, dbsession=dbsession)
8190
return memory.messages

0 commit comments

Comments
 (0)