Skip to content

Commit bf49ca6

Browse files
committed
fix: context injected, not session object
1 parent bd59cdc commit bf49ca6

File tree

8 files changed

+105
-89
lines changed

8 files changed

+105
-89
lines changed

‎dialog_lib/agents/abstract.py‎

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
parent_session_id=None,
2626
dataset=None,
2727
llm_api_key=None,
28-
dbsession=get_session(),
28+
dbsession=get_session,
2929
):
3030
"""
3131
:param config: Configuration dictionary
@@ -132,12 +132,13 @@ def document_prompt(self):
132132

133133
@property
134134
def retriever(self):
135-
return DialogRetriever(
136-
session=self.dbsession,
137-
embedding_llm=self.embedding_llm,
138-
threshold=self.cosine_similarity_threshold,
139-
top_k=self.top_k
140-
)
135+
with self.dbsession() as session:
136+
return DialogRetriever(
137+
session=session,
138+
embedding_llm=self.embedding_llm,
139+
threshold=self.cosine_similarity_threshold,
140+
top_k=self.top_k
141+
)
141142

142143
@property
143144
def model(self):
@@ -203,20 +204,22 @@ def answer_runnable(self):
203204

204205
@property
205206
def memory(self):
206-
return get_memory_instance(
207-
session_id=self.session_id,
208-
sqlalchemy_session=self.dbsession,
209-
database_url=self.config.get("database_url")
210-
)
207+
with self.dbsession() as session:
208+
return get_memory_instance(
209+
session_id=self.session_id,
210+
sqlalchemy_session=session,
211+
database_url=self.config.get("database_url")
212+
)
211213

212214
def get_session_history(self, something):
213-
return CustomPostgresChatMessageHistory(
214-
connection_string=self.config.get("database_url"),
215-
session_id=self.session_id,
216-
parent_session_id=self.parent_session_id,
217-
table_name="chat_messages",
218-
dbsession=self.dbsession,
219-
)
215+
with self.dbsession() as session:
216+
return CustomPostgresChatMessageHistory(
217+
connection_string=self.config.get("database_url"),
218+
session_id=self.session_id,
219+
parent_session_id=self.parent_session_id,
220+
table_name="chat_messages",
221+
dbsession=session,
222+
)
220223

221224
def chain_router(self, input):
222225
return self.answer_runnable if len(input["relevant_contents"]) > 0 else self.fallback_chain

‎dialog_lib/db/memory.py‎

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(
1515
self,
1616
*args,
1717
parent_session_id=None,
18-
dbsession=get_session(),
18+
dbsession=get_session,
1919
chats_model=Chat,
2020
chat_messages_model=ChatMessages,
2121
ssl_mode=None,
@@ -48,10 +48,10 @@ def create_tables(self) -> None:
4848

4949
def add_tags(self, tags: str) -> None:
5050
"""Add tags for a given session_id/uuid on chats table"""
51-
self.dbsession.query(self.chats_model).where(
52-
self.chats_model.session_id == self._session_id
53-
).update({getattr(self.chats_model, "tags"): tags})
54-
self.dbsession.commit()
51+
with self.dbsession() as session:
52+
session.query(self.chats_model).where(
53+
self.chats_model.session_id == self._session_id
54+
).update({getattr(self.chats_model, "tags"): tags})
5555

5656
def add_message(self, message: BaseMessage) -> None:
5757
"""Append the message to the record in PostgreSQL"""
@@ -61,13 +61,12 @@ def add_message(self, message: BaseMessage) -> None:
6161
if self.parent_session_id:
6262
message.parent = self.parent_session_id
6363
self.dbsession.add(message)
64-
self.dbsession.commit()
6564

6665

6766
def generate_memory_instance(
6867
session_id,
6968
parent_session_id=None,
70-
dbsession=get_session(),
69+
dbsession=get_session,
7170
database_url=None,
7271
chats_model=Chat,
7372
chat_messages_model=ChatMessages,
@@ -88,29 +87,31 @@ def generate_memory_instance(
8887

8988

9089
def add_user_message_to_message_history(
91-
session_id, message, memory=None, dbsession=get_session(), database_url=None
90+
session_id, message, memory=None, dbsession=get_session, database_url=None
9291
):
9392
"""
9493
Add a user message to the message history and returns the updated
9594
memory instance
9695
"""
97-
if not memory:
98-
memory = generate_memory_instance(
99-
session_id, dbsession=dbsession, database_url=database_url
100-
)
96+
with dbsession() as session:
97+
if not memory:
98+
memory = generate_memory_instance(
99+
session_id, dbsession=session, database_url=database_url
100+
)
101101

102-
memory.add_user_message(message)
103-
return memory
102+
memory.add_user_message(message)
103+
return memory
104104

105105

106-
def get_messages(session_id, dbsession=get_session(), database_url=None):
106+
def get_messages(session_id, dbsession=get_session, database_url=None):
107107
"""
108108
Get all messages for a given session_id
109109
"""
110-
memory = generate_memory_instance(
111-
session_id, dbsession=dbsession, database_url=database_url
112-
)
113-
return memory.messages
110+
with dbsession() as session:
111+
memory = generate_memory_instance(
112+
session_id, dbsession=session, database_url=database_url
113+
)
114+
return memory.messages
114115

115116
def get_memory_instance(session_id, sqlalchemy_session, database_url):
116117
return generate_memory_instance(

‎dialog_lib/db/session.py‎

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
import os
22

33
import sqlalchemy as sa
4-
from sqlalchemy.orm import Session
4+
from sqlalchemy.orm import Session, sessionmaker
55

6+
from contextlib import contextmanager
7+
8+
engine = sa.create_engine(os.environ.get("DATABASE_URL"))
9+
Session = sessionmaker(bind=engine)
10+
11+
@contextmanager
612
def get_session():
7-
engine = sa.create_engine(os.environ.get("DATABASE_URL"))
8-
session = Session(engine)
9-
yield session
10-
session.close()
13+
session = Session()
14+
try:
15+
yield session
16+
session.flush()
17+
session.commit()
18+
except Exception as e:
19+
session.rollback()
20+
raise e
21+
finally:
22+
session.close()

‎dialog_lib/db/utils.py‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from .models import Chat
44

55

6-
def create_chat_session(identifier=None, dbsession=get_session(), model=Chat):
6+
def create_chat_session(identifier=None, dbsession=get_session, model=Chat):
77
if identifier is None:
88
identifier = uuid.uuid4().hex
99

10-
chat = dbsession.query(model).filter_by(session_id=identifier).first()
11-
if not chat:
12-
chat = model(session_id=identifier)
13-
dbsession.add(chat)
14-
dbsession.commit()
10+
with dbsession() as session:
11+
chat = session.query(model).filter_by(session_id=identifier).first()
12+
if not chat:
13+
chat = model(session_id=identifier)
14+
session.add(chat)
1515

1616
return {"chat_id": chat.session_id}

‎dialog_lib/loaders/csv.py‎

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def load_csv(
10-
file_path, dbsession=get_session(), embeddings_model_instance=None,
10+
file_path, dbsession=get_session, embeddings_model_instance=None,
1111
embedding_llm_model=None, embedding_llm_api_key=None, company_id=None
1212
):
1313

@@ -20,22 +20,22 @@ def load_csv(
2020
else:
2121
raise ValueError("Invalid embeddings model")
2222

23-
for csv_content in contents:
24-
content = {}
23+
with dbsession() as session:
24+
for csv_content in contents:
25+
content = {}
2526

26-
for line in csv_content.page_content.split("\n"):
27-
values = line.split(": ")
28-
content[values[0]] = values[1]
27+
for line in csv_content.page_content.split("\n"):
28+
values = line.split(": ")
29+
content[values[0]] = values[1]
2930

30-
company_content = CompanyContent(
31-
category="csv",
32-
subcategory="csv-content",
33-
question=content["question"],
34-
content=content["content"],
35-
dataset=company_id,
36-
embedding=generate_embedding(csv_content.page_content, embeddings_model_instance)
37-
)
38-
dbsession.add(company_content)
31+
company_content = CompanyContent(
32+
category="csv",
33+
subcategory="csv-content",
34+
question=content["question"],
35+
content=content["content"],
36+
dataset=company_id,
37+
embedding=generate_embedding(csv_content.page_content, embeddings_model_instance)
38+
)
39+
session.add(company_content)
3940

40-
dbsession.commit()
4141
return company_content

‎dialog_lib/loaders/web.py‎

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44
from langchain_community.document_loaders import WebBaseLoader
55

66

7-
def load_webpage(url, embeddings_model_instance, dbsession=get_session(), company_id=None):
7+
def load_webpage(url, embeddings_model_instance, dbsession=get_session, company_id=None):
88
loader = WebBaseLoader(url)
99
contents = loader.load()
1010

11-
for url_content in contents:
12-
company_content = CompanyContent(
13-
link=url,
14-
category="web",
15-
subcategory="website-content",
16-
question=url_content.metadata["title"],
17-
content=url_content.page_content,
18-
dataset=company_id,
19-
embedding=generate_embedding(url_content.page_content, embeddings_model_instance)
20-
)
21-
dbsession.add(company_content)
22-
dbsession.flush()
11+
with dbsession() as session:
12+
for url_content in contents:
13+
company_content = CompanyContent(
14+
link=url,
15+
category="web",
16+
subcategory="website-content",
17+
question=url_content.metadata["title"],
18+
content=url_content.page_content,
19+
dataset=company_id,
20+
embedding=generate_embedding(url_content.page_content, embeddings_model_instance)
21+
)
22+
session.add(company_content)
2323

2424
return company_content

‎dialog_lib/tests/conftest.py‎

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
from aioresponses import aioresponses
66
from sqlalchemy.orm import Session
77
from dialog_lib.db.models import Base
8+
from dialog_lib.db import get_session
89

910

1011
@pytest.fixture
1112
def db_engine():
1213
return sqlalchemy.create_engine(os.environ.get('DATABASE_URL'))
1314

15+
1416
@pytest.fixture
15-
def db_session(db_engine):
16-
Base.metadata.create_all(db_engine)
17-
session = Session(db_engine)
18-
yield session
19-
session.rollback()
20-
session.close()
17+
def db_session():
18+
return get_session
19+
2120

2221
@pytest.fixture
2322
def mock_aioresponse():

‎dialog_lib/tests/loaders/test_web_loader.py‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ def test_load_web_content(mock_aioresponse, db_session, mocker):
99
mocker.patch('dialog_lib.loaders.web.generate_embedding', return_value=[0] * 1536)
1010
mock_aioresponse.get('http://example.com', body='Hello, world!')
1111

12-
content = load_webpage('http://example.com', None, db_session, 1)
13-
assert content.question == "Example Domain"
14-
assert content.embedding == [0] * 1536
12+
load_webpage('http://example.com', None, db_session, 1)
13+
14+
with db_session() as session:
15+
content = session.query(CompanyContent).first()
16+
assert content.question == "Example Domain"
17+
assert content.embedding.tolist() == [0]*1536
1518

16-
content = db_session.query(CompanyContent).all()
17-
assert len(content) == 1

0 commit comments

Comments
 (0)