Skip to content

Commit 856dfa9

Browse files
authored
Merge pull request #748 from wwulfric/main
feat: support pgvecto.rs
2 parents 36bdcde + 6453492 commit 856dfa9

File tree

3 files changed

+278
-7
lines changed

3 files changed

+278
-7
lines changed

‎src/vanna/pgvector/__init__.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .pgvector import PG_VectorStore
2+
from .pgvecto_rs import PG_Vecto_rsStore

‎src/vanna/pgvector/pgvecto_rs.py‎

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
import ast
2+
import json
3+
import logging
4+
import uuid
5+
6+
import pandas as pd
7+
from langchain_core.documents import Document
8+
from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs
9+
from sqlalchemy import create_engine, text
10+
11+
from .. import ValidationError
12+
from ..base import VannaBase
13+
from ..types import TrainingPlan, TrainingPlanItem
14+
from ..utils import deterministic_uuid
15+
16+
17+
class PG_Vecto_rsStore(VannaBase):
18+
def __init__(self, config=None):
19+
if not config or "connection_string" not in config:
20+
raise ValueError(
21+
"A valid 'config' dictionary with a 'connection_string' is required.")
22+
23+
VannaBase.__init__(self, config=config)
24+
25+
if config and "connection_string" in config:
26+
self.connection_string = config.get("connection_string")
27+
self.n_results = config.get("n_results", 10)
28+
29+
if config and "embedding_function" in config:
30+
self.embedding_function = config.get("embedding_function")
31+
self.vector_dimension = config.get("vector_dimension")
32+
else:
33+
from langchain_huggingface import HuggingFaceEmbeddings
34+
self.embedding_function = HuggingFaceEmbeddings(
35+
model_name="all-MiniLM-L6-v2")
36+
self.vector_dimension = 384
37+
self.sql_collection = PGVecto_rs(
38+
embedding=self.embedding_function,
39+
collection_name="sql",
40+
db_url=self.connection_string,
41+
dimension=self.vector_dimension,
42+
)
43+
self.ddl_collection = PGVecto_rs(
44+
embedding=self.embedding_function,
45+
collection_name="ddl",
46+
db_url=self.connection_string,
47+
dimension=self.vector_dimension,
48+
)
49+
self.documentation_collection = PGVecto_rs(
50+
embedding=self.embedding_function,
51+
collection_name="documentation",
52+
db_url=self.connection_string,
53+
dimension=self.vector_dimension,
54+
)
55+
56+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
57+
question_sql_json = json.dumps(
58+
{
59+
"question": question,
60+
"sql": sql,
61+
},
62+
ensure_ascii=False,
63+
)
64+
id = deterministic_uuid(question_sql_json) + "-sql"
65+
createdat = kwargs.get("createdat")
66+
doc = Document(
67+
page_content=question_sql_json,
68+
metadata={"id": id, "createdat": createdat},
69+
)
70+
self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])
71+
72+
return id
73+
74+
def add_ddl(self, ddl: str, **kwargs) -> str:
75+
_id = deterministic_uuid(ddl) + "-ddl"
76+
doc = Document(
77+
page_content=ddl,
78+
metadata={"id": _id},
79+
)
80+
self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
81+
return _id
82+
83+
def add_documentation(self, documentation: str, **kwargs) -> str:
84+
_id = deterministic_uuid(documentation) + "-doc"
85+
doc = Document(
86+
page_content=documentation,
87+
metadata={"id": _id},
88+
)
89+
self.documentation_collection.add_documents([doc],
90+
ids=[doc.metadata["id"]])
91+
return _id
92+
93+
def get_collection(self, collection_name):
94+
match collection_name:
95+
case "sql":
96+
return self.sql_collection
97+
case "ddl":
98+
return self.ddl_collection
99+
case "documentation":
100+
return self.documentation_collection
101+
case _:
102+
raise ValueError("Specified collection does not exist.")
103+
104+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
105+
documents = self.sql_collection.similarity_search(query=question,
106+
k=self.n_results)
107+
return [ast.literal_eval(document.page_content) for document in documents]
108+
109+
def get_related_ddl(self, question: str, **kwargs) -> list:
110+
documents = self.ddl_collection.similarity_search(query=question,
111+
k=self.n_results)
112+
return [document.page_content for document in documents]
113+
114+
def get_related_documentation(self, question: str, **kwargs) -> list:
115+
documents = self.documentation_collection.similarity_search(query=question,
116+
k=self.n_results)
117+
return [document.page_content for document in documents]
118+
119+
def train(
120+
self,
121+
question: str | None = None,
122+
sql: str | None = None,
123+
ddl: str | None = None,
124+
documentation: str | None = None,
125+
plan: TrainingPlan | None = None,
126+
createdat: str | None = None,
127+
):
128+
if question and not sql:
129+
raise ValidationError("Please provide a SQL query.")
130+
131+
if documentation:
132+
logging.info(f"Adding documentation: {documentation}")
133+
return self.add_documentation(documentation)
134+
135+
if sql and question:
136+
return self.add_question_sql(question=question, sql=sql,
137+
createdat=createdat)
138+
139+
if ddl:
140+
logging.info(f"Adding ddl: {ddl}")
141+
return self.add_ddl(ddl)
142+
143+
if plan:
144+
for item in plan._plan:
145+
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
146+
self.add_ddl(item.item_value)
147+
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
148+
self.add_documentation(item.item_value)
149+
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
150+
self.add_question_sql(question=item.item_name, sql=item.item_value)
151+
152+
def get_training_data(self, **kwargs) -> pd.DataFrame:
153+
# Establishing the connection
154+
engine = create_engine(self.connection_string)
155+
156+
# Querying the 'langchain_pg_embedding' table
157+
query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
158+
df_embedding = pd.read_sql(query_embedding, engine)
159+
160+
# List to accumulate the processed rows
161+
processed_rows = []
162+
163+
# Process each row in the DataFrame
164+
for _, row in df_embedding.iterrows():
165+
custom_id = row["cmetadata"]["id"]
166+
document = row["document"]
167+
training_data_type = "documentation" if custom_id[
168+
-3:] == "doc" else custom_id[-3:]
169+
170+
if training_data_type == "sql":
171+
# Convert the document string to a dictionary
172+
try:
173+
doc_dict = ast.literal_eval(document)
174+
question = doc_dict.get("question")
175+
content = doc_dict.get("sql")
176+
except (ValueError, SyntaxError):
177+
logging.info(
178+
f"Skipping row with custom_id {custom_id} due to parsing error.")
179+
continue
180+
elif training_data_type in ["documentation", "ddl"]:
181+
question = None # Default value for question
182+
content = document
183+
else:
184+
# If the suffix is not recognized, skip this row
185+
logging.info(
186+
f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
187+
continue
188+
189+
# Append the processed data to the list
190+
processed_rows.append(
191+
{"id": custom_id, "question": question, "content": content,
192+
"training_data_type": training_data_type}
193+
)
194+
195+
# Create a DataFrame from the list of processed rows
196+
df_processed = pd.DataFrame(processed_rows)
197+
198+
return df_processed
199+
200+
def remove_training_data(self, id: str, **kwargs) -> bool:
201+
# Create the database engine
202+
engine = create_engine(self.connection_string)
203+
204+
# SQL DELETE statement
205+
delete_statement = text(
206+
"""
207+
DELETE FROM langchain_pg_embedding
208+
WHERE cmetadata ->> 'id' = :id
209+
"""
210+
)
211+
212+
# Connect to the database and execute the delete statement
213+
with engine.connect() as connection:
214+
# Start a transaction
215+
with connection.begin() as transaction:
216+
try:
217+
result = connection.execute(delete_statement, {"id": id})
218+
# Commit the transaction if the delete was successful
219+
transaction.commit()
220+
# Check if any row was deleted and return True or False accordingly
221+
return result.rowcount() > 0
222+
except Exception as e:
223+
# Rollback the transaction in case of error
224+
logging.error(f"An error occurred: {e}")
225+
transaction.rollback()
226+
return False
227+
228+
def remove_collection(self, collection_name: str) -> bool:
229+
engine = create_engine(self.connection_string)
230+
231+
# Determine the suffix to look for based on the collection name
232+
suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
233+
suffix = suffix_map.get(collection_name)
234+
235+
if not suffix:
236+
logging.info(
237+
"Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
238+
return False
239+
240+
# SQL query to delete rows based on the condition
241+
query = text(
242+
f"""
243+
DELETE FROM langchain_pg_embedding
244+
WHERE cmetadata->>'id' LIKE '%{suffix}'
245+
"""
246+
)
247+
248+
# Execute the deletion within a transaction block
249+
with engine.connect() as connection:
250+
with connection.begin() as transaction:
251+
try:
252+
result = connection.execute(query)
253+
transaction.commit() # Explicitly commit the transaction
254+
if result.rowcount() > 0:
255+
logging.info(
256+
f"Deleted {result.rowcount()} rows from "
257+
f"langchain_pg_embedding where collection is {collection_name}."
258+
)
259+
return True
260+
else:
261+
logging.info(f"No rows deleted for collection {collection_name}.")
262+
return False
263+
except Exception as e:
264+
logging.error(f"An error occurred: {e}")
265+
transaction.rollback() # Rollback in case of error
266+
return False
267+
268+
def generate_embedding(self, *args, **kwargs):
269+
pass

‎src/vanna/pgvector/pgvector.py‎

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .. import ValidationError
1212
from ..base import VannaBase
1313
from ..types import TrainingPlan, TrainingPlanItem
14+
from ..utils import deterministic_uuid
1415

1516

1617
class PG_VectorStore(VannaBase):
@@ -55,7 +56,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
5556
},
5657
ensure_ascii=False,
5758
)
58-
id = str(uuid.uuid4()) + "-sql"
59+
id = deterministic_uuid(question_sql_json) + "-sql"
5960
createdat = kwargs.get("createdat")
6061
doc = Document(
6162
page_content=question_sql_json,
@@ -66,7 +67,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
6667
return id
6768

6869
def add_ddl(self, ddl: str, **kwargs) -> str:
69-
_id = str(uuid.uuid4()) + "-ddl"
70+
_id = deterministic_uuid(ddl) + "-ddl"
7071
doc = Document(
7172
page_content=ddl,
7273
metadata={"id": _id},
@@ -75,7 +76,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
7576
return _id
7677

7778
def add_documentation(self, documentation: str, **kwargs) -> str:
78-
_id = str(uuid.uuid4()) + "-doc"
79+
_id = deterministic_uuid(documentation) + "-doc"
7980
doc = Document(
8081
page_content=documentation,
8182
metadata={"id": _id},
@@ -94,7 +95,7 @@ def get_collection(self, collection_name):
9495
case _:
9596
raise ValueError("Specified collection does not exist.")
9697

97-
def get_similar_question_sql(self, question: str) -> list:
98+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
9899
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
99100
return [ast.literal_eval(document.page_content) for document in documents]
100101

@@ -203,7 +204,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
203204
# Commit the transaction if the delete was successful
204205
transaction.commit()
205206
# Check if any row was deleted and return True or False accordingly
206-
return result.rowcount > 0
207+
return result.rowcount() > 0
207208
except Exception as e:
208209
# Rollback the transaction in case of error
209210
logging.error(f"An error occurred: {e}")
@@ -235,9 +236,9 @@ def remove_collection(self, collection_name: str) -> bool:
235236
try:
236237
result = connection.execute(query)
237238
transaction.commit() # Explicitly commit the transaction
238-
if result.rowcount > 0:
239+
if result.rowcount() > 0:
239240
logging.info(
240-
f"Deleted {result.rowcount} rows from "
241+
f"Deleted {result.rowcount()} rows from "
241242
f"langchain_pg_embedding where collection is {collection_name}."
242243
)
243244
return True

0 commit comments

Comments
 (0)