Skip to content
Merged
29 changes: 13 additions & 16 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,16 +580,16 @@ async def __query_collection(
For best hybrid search performance, consider creating a TSV column
and adding GIN index.
"""
if not k:
k = (
max(
self.k,
self.hybrid_search_config.primary_top_k,
self.hybrid_search_config.secondary_top_k,
)
if self.hybrid_search_config
else self.k
)
hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)

final_k = k if k is not None else self.k

dense_limit = final_k
if hybrid_search_config:
dense_limit = hybrid_search_config.primary_top_k

operator = self.distance_strategy.operator
search_function = self.distance_strategy.search_function

Expand Down Expand Up @@ -617,9 +617,9 @@ async def __query_collection(
embedding_data_string = ":query_embedding"
where_filters = f"WHERE {safe_filter}" if safe_filter else ""
dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :dense_limit;
"""
param_dict = {"query_embedding": query_embedding, "k": k}
param_dict = {"query_embedding": query_embedding, "dense_limit": dense_limit}
if filter_dict:
param_dict.update(filter_dict)
if self.index_query_options:
Expand All @@ -637,16 +637,13 @@ async def __query_collection(
result_map = result.mappings()
dense_results = result_map.fetchall()

hybrid_search_config = kwargs.get(
"hybrid_search_config", self.hybrid_search_config
)
fts_query = (
hybrid_search_config.fts_query
if hybrid_search_config and hybrid_search_config.fts_query
else kwargs.get("fts_query", "")
)
if hybrid_search_config and fts_query:
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = final_k
# do the sparse query
lang = (
f"'{hybrid_search_config.tsv_lang}',"
Expand Down