Skip to content

Commit 8540bdf

Browse files
committed
Fix testing structure for load_csv and removes unnecessary commits from session
1 parent 824ad09 commit 8540bdf

File tree

6 files changed

+13
-20
lines changed

6 files changed

+13
-20
lines changed

‎src/dialog/db/__init__.py‎

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@ def session_scope():
1515
with Session(bind=engine) as session:
1616
try:
1717
yield session
18-
session.commit()
1918
except Exception as exc:
2019
session.rollback()
2120
raise exc
22-
finally:
23-
session.close()
2421

2522
def get_session():
2623
with session_scope() as session:

‎src/dialog/routers/openai.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def ask_question_to_llm(message: OpenAIChat, session: Session = Depends(ge
5555
session_id = Settings().OPENWEB_UI_SESSION,
5656
)
5757
session.add(new_chat)
58-
session.flush()
58+
session.commit()
5959
else:
6060
logging.info("Using old chat entity")
6161
new_chat = chat_entity
@@ -72,7 +72,7 @@ async def ask_question_to_llm(message: OpenAIChat, session: Session = Depends(ge
7272
message=_message.content,
7373
)
7474
session.add(new_message)
75-
session.flush()
75+
session.commit()
7676

7777
process_user_message_args = {
7878
"message": non_empty_messages[-1].content,

‎src/load_csv.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def load_csv_and_generate_embeddings(
150150
]
151151
session.add_all(company_contents)
152152
session.commit()
153+
return session.query(CompanyContent).all()
153154

154155

155156
if __name__ == "__main__":

‎src/tests/conftest.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def dbsession(mocker):
3030

3131
with Session() as session:
3232
yield session
33-
session.rollback()
3433

3534
Base.metadata.drop_all(bind=engine)
3635

‎src/tests/test_load_csv.py‎

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ def test_load_csv(mocker, dbsession, csv_file: str):
4747
[0.2] * 1536,
4848
] # 1536 is the expected dimension of the embeddings
4949

50-
load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
50+
result = load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
5151

52-
result = dbsession.query(load_csv.CompanyContent).all()
5352
assert len(result) == 2
5453

5554

@@ -66,7 +65,7 @@ def test_multiple_columns_embedding(mocker, dbsession, csv_file: str):
6665

6766
mock_generate_embeddings.assert_called_with(
6867
[
69-
"category: cat1\nsubcategory: subcat1\ncontent: content1",
68+
"category: cat1\nsubcategory: subcat1\ncontent: content1",
7069
"category: cat2\nsubcategory: subcat2\ncontent: content2"
7170
],
7271
embedding_llm_instance=load_csv.EMBEDDINGS_LLM,
@@ -80,20 +79,17 @@ def test_clear_db(mocker, dbsession, csv_file: str):
8079
[0.2] * 1536,
8180
] # 1536 is the expected dimension of the embeddings
8281

83-
load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
84-
initial_run = dbsession.query(load_csv.CompanyContent).all()
82+
initial_run = load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
8583

86-
load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
87-
clear_db_run = dbsession.query(load_csv.CompanyContent).all()
84+
clear_db_run = load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
8885

8986
other_csv_file = _create_csv(
9087
data=[
9188
["cat3", "subcat3", "q3", "content3", "dataset3"],
9289
["cat4", "subcat4", "q4", "content4", "dataset4"],
9390
]
9491
)
95-
load_csv.load_csv_and_generate_embeddings(other_csv_file, cleardb=False)
96-
dont_clear_db_run = dbsession.query(load_csv.CompanyContent).all()
92+
dont_clear_db_run = load_csv.load_csv_and_generate_embeddings(other_csv_file, cleardb=False)
9793

9894
assert len(initial_run) == 2
9995
assert len(clear_db_run) == 2
@@ -122,7 +118,7 @@ def test_documents_to_company_content():
122118
"link": "http://test_link"
123119
}
124120
)
125-
121+
126122
# Define a mock embedding
127123
embedding = [0.1] * 1536 # Example embedding
128124

@@ -155,7 +151,7 @@ def test_get_document_pk():
155151
"link": "http://test_link"
156152
}
157153
)
158-
154+
159155
# Define the fields to be used for primary key generation
160156
pk_metadata_fields = ["category", "subcategory", "question"]
161157

@@ -172,7 +168,7 @@ def test_get_document_pk():
172168
def test_load_csv_with_metadata(csv_file: str):
173169
metadata_columns = ["category", "subcategory", "question", "dataset"]
174170
embed_columns = ["content"]
175-
171+
176172
# Call the function to test
177173
docs = load_csv.load_csv_with_metadata(csv_file, embed_columns, metadata_columns)
178174

@@ -195,7 +191,7 @@ def test_load_csv_with_metadata(csv_file: str):
195191
"content": "content2",
196192
}
197193

198-
def test_retrieve_docs_from_vectordb(mocker):
194+
def test_retrieve_docs_from_vectordb(mocker, dbsession):
199195
# Create mock CompanyContent objects
200196
mock_company_contents = [
201197
CompanyContent(

‎src/tests/test_views.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_ask_question_no_session_id(client, mocker, llm_mock, dbsession):
3131
def test_get_chat_content(client, chat_session, dbsession):
3232
chat = ChatMessages(session_id=chat_session["chat_id"], message="Hello")
3333
dbsession.add(chat)
34-
dbsession.flush()
34+
dbsession.commit()
3535
response = client.get(f"/chat/{chat_session['chat_id']}")
3636
assert response.status_code == 200
3737
assert "message" in response.json()

0 commit comments

Comments
 (0)