src/load_csv.py: 144 changes (95 additions, 49 deletions)
@@ -1,74 +1,119 @@
 import argparse
 from typing import List
-import hashlib
 
-import pandas as pd
-from sqlalchemy import text
+import csv
 
+from langchain_community.document_loaders.csv_loader import CSVLoader
+from langchain_core.documents import Document
 
 from dialog_lib.embeddings.generate import generate_embeddings
 from dialog.llm.embeddings import EMBEDDINGS_LLM
 from dialog_lib.db.models import CompanyContent
 from dialog.db import get_session
 from dialog.settings import Settings
 
 import logging
 
 logging.basicConfig(
     level=Settings().LOGGING_LEVEL,
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 
 logger = logging.getLogger("make_embeddings")
 
 session = next(get_session())
+NECESSARY_COLS = ["category", "subcategory", "question", "content"]
+
+
+def _get_csv_cols(path: str) -> List[str]:
+    """Get the column names from a csv file."""
+    with open(path) as f:
+        reader = csv.DictReader(f)
+        return reader.fieldnames
+
+
+def retrieve_docs_from_vectordb() -> List[Document]:
+    """Retrieve all documents from the vector store."""
+    company_contents: List[CompanyContent] = session.query(CompanyContent).all()
+    return [
+        Document(
+            page_content=content.content,
+            metadata={
+                "category": content.category,
+                "subcategory": content.subcategory,
+                "question": content.question,
+            },
+        )
+        for content in company_contents
+    ]
+
+
+def documents_to_company_content(doc: Document, embedding: List[float]) -> CompanyContent:
+    """Transform a langchain Document and its embedding into a CompanyContent object."""
+    return CompanyContent(
+        category=doc.metadata.get("category"),
+        subcategory=doc.metadata.get("subcategory"),
+        question=doc.metadata.get("question"),
+        content=doc.page_content,
+        embedding=embedding,
+        dataset=doc.metadata.get("dataset"),
+        link=doc.metadata.get("link"),
+    )
+
+
+def get_document_pk(doc: Document) -> str:
+    return (
+        doc.metadata["category"]
+        + doc.metadata["subcategory"]
+        + doc.metadata["question"]
+    )
Review thread on get_document_pk:

Member (reviewer): My only concern here is that a string with special characters could break this logic at some level on the database side. I would just apply a hash on top of this concatenation, which would prevent anything like that from happening.

Collaborator (author): Oh yes, I forgot to apply the hash again; at some point I missed it.
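Following up on that thread, a minimal sketch of what the hashed variant could look like, assuming the md5 approach from the removed pandas code is reused (this is a suggestion, not part of the diff):

    import hashlib

    from langchain_core.documents import Document


    def get_document_pk(doc: Document) -> str:
        # Hash the concatenated metadata so special characters in any of the
        # fields cannot break key comparisons on the database side.
        raw_key = (
            doc.metadata["category"]
            + doc.metadata["subcategory"]
            + doc.metadata["question"]
        )
        return hashlib.md5(raw_key.encode()).hexdigest()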


 def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("content",)):
-    df = pd.read_csv(path)
-    necessary_cols = ["category", "subcategory", "question", "content"]
-    for col in necessary_cols:
-        if col not in df.columns:
-            raise Exception(f"Column {col} not found in {path}")
-
-    if "dataset" in df.columns:
-        necessary_cols.append("dataset")
-
-    df = df[necessary_cols]
-
-    # Create primary key column using category, subcategory, and question
-    df["primary_key"] = df["category"] + df["subcategory"] + df["question"]
-    df["primary_key"] = df["primary_key"].apply(
-        lambda row: hashlib.md5(row.encode()).hexdigest()
-    )
+    metadata_columns = [col for col in _get_csv_cols(path) if col not in embed_columns]
+
+    loader = CSVLoader(path, metadata_columns=metadata_columns)
+    docs: List[Document] = loader.load()
+
+    logger.info("Metadata columns: %s", metadata_columns)
+    logger.info("Embedding columns: %s", embed_columns)
+    logger.info("Glimpse over the first doc: %s", docs[0].page_content[:100])
+
+    for col in NECESSARY_COLS:
+        if col not in metadata_columns + list(embed_columns):
+            raise Exception(f"Column {col} not found in {path}")
 
     if cleardb:
         logging.info("Clearing vectorstore completely...")
         session.query(CompanyContent).delete()
         session.commit()
 
-    df_in_db = pd.read_sql(
-        text(
-            f"SELECT category, subcategory, question, content, dataset FROM {CompanyContent.__tablename__}"
-        ),
-        session.get_bind(),
-    )
-
-    # Create primary key column using category, subcategory, and question for df_in_db
-    new_keys = set(df["primary_key"])
-    if not df_in_db.empty:
-        df_in_db["primary_key"] = df_in_db["category"] + df_in_db["subcategory"] + df_in_db["question"]
-        df_in_db["primary_key"] = df_in_db["primary_key"].apply(
-            lambda row: hashlib.md5(row.encode()).hexdigest()
-        )
-        new_keys = set(df["primary_key"]) - set(df_in_db["primary_key"])
-
-    # Filter df for keys present in df and not present in df_in_db
-    df_filtered = df[df["primary_key"].isin(new_keys)].copy()
-
-    print("Generating embeddings for new questions...")
-    print("New questions:", len(df_filtered))
-    if len(df_filtered) == 0:
+    # Get existing docs
+    docs_in_db: List[Document] = retrieve_docs_from_vectordb()
+    logging.info(f"Existing docs: {len(docs_in_db)}")
+    existing_pks: List[str] = [
+        get_document_pk(doc) for doc in docs_in_db
+    ]
+
+    # Keep only new keys
+    docs_filtered: List[Document] = [
+        doc for doc in docs if get_document_pk(doc) not in existing_pks
+    ]
+    if len(docs_filtered) == 0:
         return
 
-    print("embed_columns: ", embed_columns)
-    df_filtered.drop(columns=["primary_key"], inplace=True)
-    df_filtered["embedding"] = generate_embeddings(
-        list(df_filtered[embed_columns].agg('\n'.join, axis=1)),
-        embedding_llm_instance=EMBEDDINGS_LLM
-    )
-    df_filtered.to_sql(
-        CompanyContent.__tablename__,
-        session.get_bind(),
-        if_exists="append",
-        index=False,
+    logging.info("Generating embeddings for new questions...")
+    logging.info(f"New questions: {len(docs_filtered)}")
+    logging.info(f"embed_columns: {embed_columns}")
+
+    embedded_docs = generate_embeddings(
+        [doc.page_content for doc in docs_filtered],
+        embedding_llm_instance=EMBEDDINGS_LLM,
     )
+    company_contents: List[CompanyContent] = [
+        documents_to_company_content(doc, embedding)
+        for (doc, embedding) in zip(docs_filtered, embedded_docs)
+    ]
+    session.add_all(company_contents)
 
 
 if __name__ == "__main__":
@@ -79,4 +124,5 @@ def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("content",)):
     args = parser.parse_args()
 
     load_csv_and_generate_embeddings(
-        args.path, args.cleardb, args.embed_columns.split(','))
+        args.path, args.cleardb, args.embed_columns.split(",")
+    )
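For reference, a hypothetical direct invocation of the updated loader; the module import, CSV path, and column choices below are placeholders, not values from this PR:

    from load_csv import load_csv_and_generate_embeddings

    # Only rows whose concatenated (category, subcategory, question) key is not
    # already present in the vector store get embedded and inserted.
    load_csv_and_generate_embeddings(
        "knowledge_base.csv",        # hypothetical CSV with the four required columns
        cleardb=False,               # keep existing CompanyContent rows
        embed_columns=["content"],   # columns kept in page_content and embedded
    )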