@@ -47,9 +47,8 @@ def test_load_csv(mocker, dbsession, csv_file: str):
4747 [0.2 ] * 1536 ,
4848 ] # 1536 is the expected dimension of the embeddings
4949
50- load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
50+ result = load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
5151
52- result = dbsession .query (load_csv .CompanyContent ).all ()
5352 assert len (result ) == 2
5453
5554
@@ -66,7 +65,7 @@ def test_multiple_columns_embedding(mocker, dbsession, csv_file: str):
6665
6766 mock_generate_embeddings .assert_called_with (
6867 [
69- "category: cat1\n subcategory: subcat1\n content: content1" ,
68+ "category: cat1\n subcategory: subcat1\n content: content1" ,
7069 "category: cat2\n subcategory: subcat2\n content: content2"
7170 ],
7271 embedding_llm_instance = load_csv .EMBEDDINGS_LLM ,
@@ -80,20 +79,17 @@ def test_clear_db(mocker, dbsession, csv_file: str):
8079 [0.2 ] * 1536 ,
8180 ] # 1536 is the expected dimension of the embeddings
8281
83- load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
84- initial_run = dbsession .query (load_csv .CompanyContent ).all ()
82+ initial_run = load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
8583
86- load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
87- clear_db_run = dbsession .query (load_csv .CompanyContent ).all ()
84+ clear_db_run = load_csv .load_csv_and_generate_embeddings (csv_file , cleardb = True )
8885
8986 other_csv_file = _create_csv (
9087 data = [
9188 ["cat3" , "subcat3" , "q3" , "content3" , "dataset3" ],
9289 ["cat4" , "subcat4" , "q4" , "content4" , "dataset4" ],
9390 ]
9491 )
95- load_csv .load_csv_and_generate_embeddings (other_csv_file , cleardb = False )
96- dont_clear_db_run = dbsession .query (load_csv .CompanyContent ).all ()
92+ dont_clear_db_run = load_csv .load_csv_and_generate_embeddings (other_csv_file , cleardb = False )
9793
9894 assert len (initial_run ) == 2
9995 assert len (clear_db_run ) == 2
@@ -122,7 +118,7 @@ def test_documents_to_company_content():
122118 "link" : "http://test_link"
123119 }
124120 )
125-
121+
126122 # Define a mock embedding
127123 embedding = [0.1 ] * 1536 # Example embedding
128124
@@ -155,7 +151,7 @@ def test_get_document_pk():
155151 "link" : "http://test_link"
156152 }
157153 )
158-
154+
159155 # Define the fields to be used for primary key generation
160156 pk_metadata_fields = ["category" , "subcategory" , "question" ]
161157
@@ -172,7 +168,7 @@ def test_get_document_pk():
172168def test_load_csv_with_metadata (csv_file : str ):
173169 metadata_columns = ["category" , "subcategory" , "question" , "dataset" ]
174170 embed_columns = ["content" ]
175-
171+
176172 # Call the function to test
177173 docs = load_csv .load_csv_with_metadata (csv_file , embed_columns , metadata_columns )
178174
@@ -195,7 +191,7 @@ def test_load_csv_with_metadata(csv_file: str):
195191 "content" : "content2" ,
196192 }
197193
198- def test_retrieve_docs_from_vectordb (mocker ):
194+ def test_retrieve_docs_from_vectordb (mocker , dbsession ):
199195 # Create mock CompanyContent objects
200196 mock_company_contents = [
201197 CompanyContent (
0 commit comments