
Commit 3fb0691

mcbianconi authored and vmesel committed
load_csv tests
1 parent 8a0cfc4 · commit 3fb0691

6 files changed (+124 lines, -13 lines)

.gitignore

Lines changed: 4 additions & 2 deletions

@@ -120,7 +120,8 @@ celerybeat.pid
 *.sage.py
 
 # Environments
-.env
+.env*
+!.env.sample
 .venv
 env/
 venv/
@@ -166,4 +167,5 @@ requirements.txt
 !src/tests/fixtures/*.csv
 !src/tests/fixtures/*.toml
 !sample_data/*.csv
-!sample_data/*.toml
+!sample_data/*.toml
+.vscode

docker-compose.test.yml

Lines changed: 1 addition & 6 deletions

@@ -1,4 +1,3 @@
-version: '3.3'
 services:
   db:
     image: pgvector/pgvector:pg15
@@ -7,7 +6,6 @@ services:
       - '5432:5432'
     volumes:
       - ./etc/db-ext-vector-test.sql:/docker-entrypoint-initdb.d/init.sql
-      - postgres_data:/var/lib/postgresql/data/
     environment:
       POSTGRES_USER: talkdai
       POSTGRES_PASSWORD: talkdai
@@ -31,7 +29,4 @@ services:
     env_file:
       - ./src/tests/.env.testing
     volumes:
-      - ./src/tests/:/app/src/tests/
-
-volumes:
-  postgres_data:
+      - ./src/tests/:/app/src/tests/

docker-compose.yml

Lines changed: 3 additions & 4 deletions

@@ -1,4 +1,3 @@
-version: '3.3'
 services:
   db:
     image: pgvector/pgvector:pg15
@@ -10,8 +9,8 @@ services:
       POSTGRES_PASSWORD: talkdai
       POSTGRES_DB: talkdai
     volumes:
-      - ./data/db:/var/lib/postgresql/data
-      - ./psql/db-ext-vector.sql:/docker-entrypoint-initdb.d/db-ext-vector.sql
+      - db-data:/var/lib/postgresql/data
+      - ./ext/db-ext-vector.sql:/docker-entrypoint-initdb.d/db-ext-vector.sql
     healthcheck:
       test: ["CMD", "pg_isready", "-d", "talkdai", "-U", "talkdai"]
       interval: 10s
@@ -25,7 +24,6 @@ services:
     tty: true
     volumes:
       - ./:/app
-      - ./.empty:/app/data/db
       - ./static:/app/static
       - ./sample_data:/app/src/sample_data
     ports:
@@ -41,3 +39,4 @@ services:
 
 volumes:
   open-webui:
+  db-data:

src/load_csv.py

Lines changed: 3 additions & 1 deletion

@@ -11,7 +11,9 @@
 
 session = next(get_session())
 
-def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("content",)):
+def load_csv_and_generate_embeddings(path: str, cleardb=False, embed_columns: list[str] | None = None) -> None:
+    if not embed_columns:
+        embed_columns = ["content"]
     df = pd.read_csv(path)
     necessary_cols = ["category", "subcategory", "question", "content"]
     for col in necessary_cols:
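For context, a hypothetical usage sketch of the updated signature (the CSV path below is illustrative, not taken from the repository): by default only the "content" column is embedded, while passing embed_columns joins the listed columns per row before generating embeddings, as exercised by the new tests below.

# Hypothetical usage sketch; "sample_data/example.csv" is an assumed path.
from load_csv import load_csv_and_generate_embeddings

# Default behaviour: embed only the "content" column.
load_csv_and_generate_embeddings("sample_data/example.csv")

# Clear existing rows, then embed several columns joined per row.
load_csv_and_generate_embeddings(
    "sample_data/example.csv",
    cleardb=True,
    embed_columns=["category", "subcategory", "content"],
)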

src/tests/conftest.py

Lines changed: 9 additions & 0 deletions

@@ -9,8 +9,17 @@
 from dialog_lib.db.utils import create_chat_session
 from dialog.db import get_session
 
+import dotenv
+
+dotenv.load_dotenv()
+
 SQLALCHEMY_DATABASE_URL = "postgresql://talkdai:talkdai@db/test_talkdai"
 
+TEST_DATABASE_URL = os.getenv('TEST_DATABASE_URL')
+if TEST_DATABASE_URL:
+    SQLALCHEMY_DATABASE_URL = TEST_DATABASE_URL
+
+
 @pytest.fixture
 def dbsession(mocker):
     engine = create_engine(SQLALCHEMY_DATABASE_URL)
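The added lines give TEST_DATABASE_URL precedence over the hard-coded test URL. A minimal equivalent sketch (not part of the commit) using os.getenv with a default:

import os

# Falls back to the hard-coded test database URL when TEST_DATABASE_URL is unset.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL",
    "postgresql://talkdai:talkdai@db/test_talkdai",
)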

src/tests/test_load_csv.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+import csv
+import pytest
+import tempfile
+
+import load_csv
+
+from unittest.mock import Mock, patch
+
+
+@pytest.fixture
+def csv_file() -> str:
+    temp_file = _create_csv()
+    return temp_file
+
+
+def _create_csv(
+    columns: list[str] | None = None, data: list[list[str]] | None = None
+) -> str:
+    temp_file = tempfile.NamedTemporaryFile(
+        prefix="test-dialog", suffix=".csv", delete=False
+    )
+
+    if not columns:
+        columns = ["category", "subcategory", "question", "content", "dataset"]
+
+    if not data:
+        data = [
+            ["cat1", "subcat1", "q1", "content1", "dataset1"],
+            ["cat2", "subcat2", "q2", "content2", "dataset2"],
+        ]
+
+    with open(temp_file.name, "w", newline="\n") as f:
+        writer = csv.writer(f)
+        writer.writerow(columns)
+        writer.writerows(data)
+    return temp_file.name
+
+
+def test_load_csv(mocker, dbsession, csv_file: str):
+    mock_generate_embeddings: Mock = mocker.patch("load_csv.generate_embeddings")
+    mock_generate_embeddings.return_value = [
+        [0.1] * 1536,
+        [0.2] * 1536,
+    ]  # 1536 is the expected dimension of the embeddings
+
+    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
+
+    result = dbsession.query(load_csv.CompanyContent).all()
+    assert len(result) == 2
+
+
+def test_multiple_columns_embedding(mocker, dbsession, csv_file: str):
+    mock_generate_embeddings: Mock = mocker.patch("load_csv.generate_embeddings")
+    mock_generate_embeddings.return_value = [
+        [0.1] * 1536,
+        [0.2] * 1536,
+    ]  # 1536 is the expected dimension of the embeddings
+
+    load_csv.load_csv_and_generate_embeddings(
+        csv_file, cleardb=True, embed_columns=["category", "subcategory", "content"]
+    )
+
+    mock_generate_embeddings.assert_called_with(
+        ["cat1\nsubcat1\ncontent1", "cat2\nsubcat2\ncontent2"],
+        embedding_llm_instance=load_csv.EMBEDDINGS_LLM,
+    )
+
+
+def test_clear_db(mocker, dbsession, csv_file: str):
+    mock_generate_embeddings: Mock = mocker.patch("load_csv.generate_embeddings")
+    mock_generate_embeddings.return_value = [
+        [0.1] * 1536,
+        [0.2] * 1536,
+    ]  # 1536 is the expected dimension of the embeddings
+
+    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
+    initial_run = dbsession.query(load_csv.CompanyContent).all()
+
+    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
+    clear_db_run = dbsession.query(load_csv.CompanyContent).all()
+
+    other_csv_file = _create_csv(
+        data=[
+            ["cat3", "subcat3", "q3", "content3", "dataset3"],
+            ["cat4", "subcat4", "q4", "content4", "dataset4"],
+        ]
+    )
+    load_csv.load_csv_and_generate_embeddings(other_csv_file, cleardb=False)
+    dont_clear_db_run = dbsession.query(load_csv.CompanyContent).all()
+
+    assert len(initial_run) == 2
+    assert len(clear_db_run) == 2
+    assert len(dont_clear_db_run) == 4
+
+
+def test_ensure_necessary_columns():
+    with pytest.raises(Exception):
+        load_csv.load_csv_and_generate_embeddings(
+            _create_csv(
+                columns=["category", "subcategory", "question"],
+                data=[["cat1", "subcat1", "q1"]],
+            ),
+            cleardb=True,
+        )  # missing content column
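The three embedding tests repeat the same mock setup for load_csv.generate_embeddings. A possible follow-up (a sketch, not part of this commit) would factor it into a pytest fixture:

import pytest

@pytest.fixture
def mock_generate_embeddings(mocker):
    # Patch generate_embeddings where load_csv resolves it and return two
    # fixed 1536-dimensional vectors, mirroring the setup repeated above.
    mock = mocker.patch("load_csv.generate_embeddings")
    mock.return_value = [[0.1] * 1536, [0.2] * 1536]
    return mock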
