Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ celerybeat.pid
*.sage.py

# Environments
.env
.env*
!.env.sample
.venv
env/
venv/
Expand Down Expand Up @@ -166,4 +167,5 @@ requirements.txt
!src/tests/fixtures/*.csv
!src/tests/fixtures/*.toml
!sample_data/*.csv
!sample_data/*.toml
!sample_data/*.toml
.vscode
7 changes: 1 addition & 6 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
version: '3.3'
services:
db:
image: pgvector/pgvector:pg15
Expand All @@ -7,7 +6,6 @@ services:
- '5432:5432'
volumes:
- ./etc/db-ext-vector-test.sql:/docker-entrypoint-initdb.d/init.sql
- postgres_data:/var/lib/postgresql/data/
environment:
POSTGRES_USER: talkdai
POSTGRES_PASSWORD: talkdai
Expand All @@ -31,7 +29,4 @@ services:
env_file:
- ./src/tests/.env.testing
volumes:
- ./src/tests/:/app/src/tests/

volumes:
postgres_data:
- ./src/tests/:/app/src/tests/
7 changes: 3 additions & 4 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
version: '3.3'
services:
db:
image: pgvector/pgvector:pg15
Expand All @@ -10,8 +9,8 @@ services:
POSTGRES_PASSWORD: talkdai
POSTGRES_DB: talkdai
volumes:
- ./data/db:/var/lib/postgresql/data
- ./psql/db-ext-vector.sql:/docker-entrypoint-initdb.d/db-ext-vector.sql
- db-data:/var/lib/postgresql/data
- ./ext/db-ext-vector.sql:/docker-entrypoint-initdb.d/db-ext-vector.sql
healthcheck:
test: ["CMD", "pg_isready", "-d", "talkdai", "-U", "talkdai"]
interval: 10s
Expand All @@ -25,7 +24,6 @@ services:
tty: true
volumes:
- ./:/app
- ./.empty:/app/data/db
- ./static:/app/static
- ./sample_data:/app/src/sample_data
ports:
Expand All @@ -41,3 +39,4 @@ services:

volumes:
open-webui:
db-data:
4 changes: 3 additions & 1 deletion src/load_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

session = next(get_session())

def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("content",)):
def load_csv_and_generate_embeddings(path: str, cleardb=False, embed_columns: None | list[str] = None) -> None:
if not embed_columns:
embed_columns = ["content"]
df = pd.read_csv(path)
necessary_cols = ["category", "subcategory", "question", "content"]
for col in necessary_cols:
Expand Down
9 changes: 9 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,17 @@
from dialog_lib.db.utils import create_chat_session
from dialog.db import get_session

import dotenv

# Pull variables from a local .env file so TEST_DATABASE_URL can be set
# per-checkout instead of exported in the shell.
dotenv.load_dotenv()

# Default targets the docker-compose test database; a non-empty
# TEST_DATABASE_URL environment variable takes precedence.
TEST_DATABASE_URL = os.getenv('TEST_DATABASE_URL')
SQLALCHEMY_DATABASE_URL = (
    TEST_DATABASE_URL or "postgresql://talkdai:talkdai@db/test_talkdai"
)


@pytest.fixture
def dbsession(mocker):
engine = create_engine(SQLALCHEMY_DATABASE_URL)
Expand Down
104 changes: 104 additions & 0 deletions src/tests/test_load_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import csv
import pytest
import tempfile

import load_csv

from unittest.mock import Mock, patch


@pytest.fixture
def csv_file() -> str:
    """Path to a freshly written CSV with the default test columns and rows."""
    return _create_csv()


def _create_csv(
columns: list[str] | None = None, data: list[list[str]] | None = None
) -> str:
temp_file = tempfile.NamedTemporaryFile(
prefix="test-dialog", suffix=".csv", delete=False
)

if not columns:
columns = ["category", "subcategory", "question", "content", "dataset"]

if not data:
data = [
["cat1", "subcat1", "q1", "content1", "dataset1"],
["cat2", "subcat2", "q2", "content2", "dataset2"],
]

with open(temp_file.name, "w", newline="\n") as f:
writer = csv.writer(f)
writer.writerow(columns)
writer.writerows(data)
return temp_file.name


def test_load_csv(mocker, dbsession, csv_file: str):
    """Loading the default CSV inserts one CompanyContent row per data row."""
    embeddings_mock: Mock = mocker.patch("load_csv.generate_embeddings")
    # Two fake vectors, one per CSV row; 1536 is the expected embedding dimension.
    embeddings_mock.return_value = [[0.1] * 1536, [0.2] * 1536]

    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)

    rows = dbsession.query(load_csv.CompanyContent).all()
    assert len(rows) == 2


def test_multiple_columns_embedding(mocker, dbsession, csv_file: str):
    """Values of embed_columns are newline-joined per row before embedding."""
    embeddings_mock: Mock = mocker.patch("load_csv.generate_embeddings")
    # 1536 is the expected embedding dimension.
    embeddings_mock.return_value = [[0.1] * 1536, [0.2] * 1536]

    selected_columns = ["category", "subcategory", "content"]
    load_csv.load_csv_and_generate_embeddings(
        csv_file, cleardb=True, embed_columns=selected_columns
    )

    expected_texts = ["cat1\nsubcat1\ncontent1", "cat2\nsubcat2\ncontent2"]
    embeddings_mock.assert_called_with(
        expected_texts,
        embedding_llm_instance=load_csv.EMBEDDINGS_LLM,
    )


def test_clear_db(mocker, dbsession, csv_file: str):
    """cleardb=True replaces existing rows; cleardb=False appends to them."""
    embeddings_mock: Mock = mocker.patch("load_csv.generate_embeddings")
    # 1536 is the expected embedding dimension.
    embeddings_mock.return_value = [[0.1] * 1536, [0.2] * 1536]

    def row_count() -> int:
        return len(dbsession.query(load_csv.CompanyContent).all())

    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
    assert row_count() == 2

    # A second cleardb run replaces, rather than duplicates, the rows.
    load_csv.load_csv_and_generate_embeddings(csv_file, cleardb=True)
    assert row_count() == 2

    extra_csv = _create_csv(
        data=[
            ["cat3", "subcat3", "q3", "content3", "dataset3"],
            ["cat4", "subcat4", "q4", "content4", "dataset4"],
        ]
    )
    # Without cleardb the new rows are appended to the existing two.
    load_csv.load_csv_and_generate_embeddings(extra_csv, cleardb=False)
    assert row_count() == 4


def test_ensure_necessary_columns():
    """A CSV missing a required column (here: content) must raise."""
    incomplete_csv = _create_csv(
        columns=["category", "subcategory", "question"],
        data=[["cat1", "subcat1", "q1"]],
    )
    with pytest.raises(Exception):
        load_csv.load_csv_and_generate_embeddings(incomplete_csv, cleardb=True)