Skip to content

Commit f702d74

Browse files
committed
updates to retrieval class
1 parent 56f0823 commit f702d74

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

‎llmware/retrieval.py‎

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
1212
# implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
14-
"""The retrieval module implements the Query class.
1514

16-
The query class executes queries against vector databases and depends on a library object.
17-
"""
15+
"""The retrieval module implements the Query class. The Query class provides a high-level interface for executing
16+
a variety of queries on a Library collection, whether instantiated on Mongo, Postgres, or SQLite.
17+
18+
The Query class includes both text retrieval strategies, which operate directly as queries on the text collection
19+
database, as well as vector embedding semantic retrieval strategies, which require the use of o vector DB and that the
20+
embeddings were previously created for the Library. There are also a number of convenience methods that provide
21+
'hybrid' strategies combining elements of semantic and text querying."""
1822

1923

2024
import logging
@@ -33,14 +37,15 @@
3337

3438

3539
class Query:
36-
"""Implements the query capabilities against a ``library``.
40+
41+
"""Implements the query capabilities against a ``Library` object`.
3742
3843
Query is responsible for executing queries against an indexed library. The library can be semantic, text, custom,
3944
or hybrid. A query object requires a library object as input, which will be the source of the query.
4045
4146
Parameters
4247
----------
43-
library : object
48+
library : Library object
4449
A ``library`` object.
4550
4651
embedding_model : object, default=None
@@ -103,6 +108,7 @@ class Query:
103108
"Federal Constitutional Law of 1920. The political system of the Second Republic with its nine federal "
104109
"states is based on the constitution of 1920, amended in 1929, which was re-enacted on 1 May 1945. [108] "
105110
"""
111+
106112
def __init__(self, library, embedding_model=None, tokenizer=None, vector_db_api_key=None,
107113
query_id=None, from_hf=False, from_sentence_transformer=False,embedding_model_name=None,
108114
save_history=True, query_mode=None, vector_db=None, model_api_key=None):
@@ -407,7 +413,6 @@ def query(self, query, query_type="text", result_count=20, results_only=True):
407413

408414
return output_result
409415

410-
# basic simple text query method - only requires entering the query
411416
def text_query (self, query, exact_mode=False, result_count=20, exhaust_full_cursor=False, results_only=True):
412417

413418
""" Execute a basic text query. """
@@ -693,7 +698,6 @@ def _cursor_to_qr (self, query, cursor_results, result_count=20, exhaust_full_cu
693698

694699
return qr_dict
695700

696-
# basic semantic query
697701
def semantic_query(self, query, result_count=20, embedding_distance_threshold=None, custom_filter=None, results_only=True):
698702

699703
""" Main method to execute a semantic query - only required parameter is the query. """
@@ -850,7 +854,8 @@ def similar_blocks_embedding(self, block, result_count=20, embedding_distance_th
850854

851855
return results_dict
852856

853-
def dual_pass_query(self, query, result_count=20, primary="text", safety_check=True, custom_filter=None, results_only=True):
857+
def dual_pass_query(self, query, result_count=20, primary="text",
858+
safety_check=True, custom_filter=None, results_only=True):
854859

855860
""" Executes a combination of text and semantic queries and attempts to interweave and re-rank based on
856861
correspondence between the two query attempts. """
@@ -872,12 +877,14 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
872877
# run dual pass - text + semantic
873878
# Choose appropriate text query method based on custom_filter
874879
if custom_filter:
875-
retrieval_dict_text = self.text_query_with_custom_filter(query, custom_filter, result_count=result_count, results_only=True)
880+
retrieval_dict_text = self.text_query_with_custom_filter(query, custom_filter,
881+
result_count=result_count, results_only=True)
876882
else:
877883
retrieval_dict_text = self.text_query(query, result_count=result_count, results_only=True)
878884

879885
# Semantic query with custom filter
880-
retrieval_dict_semantic = self.semantic_query(query, result_count=result_count, custom_filter=custom_filter, results_only=True)
886+
retrieval_dict_semantic = self.semantic_query(query, result_count=result_count,
887+
custom_filter=custom_filter, results_only=True)
881888

882889
if primary == "text":
883890
first_list = retrieval_dict_text
@@ -930,17 +937,16 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
930937
doc_fn_list.append(qr["file_source"])
931938

932939
retrieval_dict = {"results": merged_results,
933-
"text_results": retrieval_dict_text,
934-
"semantic_results": retrieval_dict_semantic,
935-
"doc_ID": doc_id_list,
936-
"file_source": doc_fn_list}
940+
"text_results": retrieval_dict_text,
941+
"semantic_results": retrieval_dict_semantic,
942+
"doc_ID": doc_id_list,
943+
"file_source": doc_fn_list}
937944

938945
if results_only:
939946
return merged_results
940947

941948
return retrieval_dict
942949

943-
944950
def augment_qr (self, query_result, query_topic, augment_query="semantic"):
945951

946952
""" Augments the set of query results using alternative retrieval strategy. """
@@ -1038,7 +1044,9 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10381044
if "doc_ID" in doc_id_list:
10391045
doc_id_list = doc_id_list["doc_ID"]
10401046
else:
1041-
logging.warning("warning: could not recognize doc id list requested. by default, will set to all documents in the library collection.")
1047+
logging.warning("warning: could not recognize doc id list requested. by default, "
1048+
"will set to all documents in the library collection.")
1049+
10421050
doc_id_list = self.list_doc_id()
10431051

10441052
if not page_list:
@@ -1051,7 +1059,8 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10511059
else:
10521060
page_dict = {"doc_ID": {"$in":doc_id_list}, "master_index": {"$in": page_list}}
10531061

1054-
cursor_results = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key_dict(page_dict)
1062+
cursor_results = CollectionRetrieval(self.library_name,
1063+
account_name=self.account_name).filter_by_key_dict(page_dict)
10551064

10561065
output = []
10571066

@@ -1064,12 +1073,12 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10641073

10651074
return output
10661075

1067-
# new method to extract whole library
10681076
def get_whole_library(self, selected_keys=None):
10691077

10701078
""" Gets the whole library - and will return as a list in-memory. """
10711079

1072-
match_results_cursor = CollectionRetrieval(self.library_name, account_name=self.account_name).get_whole_collection()
1080+
match_results_cursor = CollectionRetrieval(self.library_name,
1081+
account_name=self.account_name).get_whole_collection()
10731082

10741083
match_results = match_results_cursor.pull_all()
10751084

@@ -1098,7 +1107,6 @@ def get_whole_library(self, selected_keys=None):
10981107

10991108
return qr
11001109

1101-
# new method to generate csv files for each table entry
11021110
def export_all_tables(self, query="", output_fp=None):
11031111

11041112
""" Exports all tables, with query option to limit the list from a library. """
@@ -1110,7 +1118,8 @@ def export_all_tables(self, query="", output_fp=None):
11101118

11111119
if not query:
11121120

1113-
match_results = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key("content_type","table")
1121+
match_results = CollectionRetrieval(self.library_name,
1122+
account_name=self.account_name).filter_by_key("content_type","table")
11141123

11151124
else:
11161125
kv_dict = {"content_type": "table"}
@@ -1274,7 +1283,8 @@ def list_doc_fn(self):
12741283

12751284
""" Utility function - returns list of all document names in the library. """
12761285

1277-
doc_fn_raw_list = CollectionRetrieval(self.library_name, account_name=self.account_name).get_distinct_list("file_source")
1286+
doc_fn_raw_list = CollectionRetrieval(self.library_name,
1287+
account_name=self.account_name).get_distinct_list("file_source")
12781288

12791289
doc_fn_out = []
12801290
for i, file in enumerate(doc_fn_raw_list):
@@ -1712,15 +1722,18 @@ def expand_text_result_after(self, block, window_size=400):
17121722
return output
17131723

17141724
def generate_csv_report(self):
1725+
17151726
"""Generates a csv report from the current query status. """
1727+
17161728
output = QueryState(self).generate_query_report_current_state()
17171729
return output
17181730

17191731
def filter_by_key_value_range(self, key, value_range, results_only=True):
17201732

17211733
""" Executes a filter by key value range. """
17221734

1723-
cursor = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key_value_range(key,value_range)
1735+
cursor = CollectionRetrieval(self.library_name,
1736+
account_name=self.account_name).filter_by_key_value_range(key,value_range)
17241737

17251738
query= ""
17261739
result_dict = self._cursor_to_qr(query, cursor, exhaust_full_cursor=True)

0 commit comments

Comments
 (0)