Skip to content
Prev Previous commit
Next Next commit
fix metadata cols as fields for pks
  • Loading branch information
Luan Fernandes committed Jun 6, 2024
commit baf09e77071b07e832621ae735ce343c47412de2
12 changes: 7 additions & 5 deletions src/load_csv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
from typing import List
from typing import List, Optional
import hashlib
import csv

Expand Down Expand Up @@ -68,12 +68,14 @@ def get_document_pk(doc: Document, pk_metadata_fields: List[str]) -> str:
return hashlib.md5(concatened_fields.encode()).hexdigest()


def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("content",)):
def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns: Optional[list[str]]=None):
"""
Load the knowledge base CSV, get their embeddings and store them into the vector store.
"""
if not embed_columns:
embed_columns = ["content"]

metadata_columns = [col for col in _get_csv_cols(path) if col not in embed_columns]
pk_metadata_fields = ["category", "subcategory", "question"]

loader = CSVLoader(path, metadata_columns=metadata_columns)
docs: List[Document] = loader.load()
Expand All @@ -95,14 +97,14 @@ def load_csv_and_generate_embeddings(path, cleardb=False, embed_columns=("conten
docs_in_db: List[Document] = retrieve_docs_from_vectordb()
logging.info(f"Existing docs: {len(docs_in_db)}")
existing_pks: List[str] = [
get_document_pk(doc, pk_metadata_fields) for doc in docs_in_db
get_document_pk(doc, metadata_columns) for doc in docs_in_db
]

# Keep only new keys
docs_filtered: List[Document] = [
doc
for doc in docs
if get_document_pk(doc, pk_metadata_fields) not in existing_pks
if get_document_pk(doc, metadata_columns) not in existing_pks
]
if len(docs_filtered) == 0:
return
Expand Down