1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
1212# implied. See the License for the specific language governing
1313# permissions and limitations under the License.
14- """The retrieval module implements the Query class.
1514
16- The query class executes queries against vector databases and depends on a library object.
17- """
15+ """The retrieval module implements the Query class. The Query class provides a high-level interface for executing
16+ a variety of queries on a Library collection, whether instantiated on Mongo, Postgres, or SQLite.
17+
18+ The Query class includes both text retrieval strategies, which operate directly as queries on the text collection
19+ database, as well as vector embedding semantic retrieval strategies, which require the use of a vector DB and that the
20+ embeddings were previously created for the Library. There are also a number of convenience methods that provide
21+ 'hybrid' strategies combining elements of semantic and text querying."""
1822
1923
2024import logging
3337
3438
3539class Query :
36- """Implements the query capabilities against a ``library``.
40+
41+ """Implements the query capabilities against a ``Library`` object.
3742
3843 Query is responsible for executing queries against an indexed library. The library can be semantic, text, custom,
3944 or hybrid. A query object requires a library object as input, which will be the source of the query.
4045
4146 Parameters
4247 ----------
43- library : object
48+ library : Library object
4449 A ``library`` object.
4550
4651 embedding_model : object, default=None
@@ -103,6 +108,7 @@ class Query:
103108 "Federal Constitutional Law of 1920. The political system of the Second Republic with its nine federal "
104109 "states is based on the constitution of 1920, amended in 1929, which was re-enacted on 1 May 1945. [108] "
105110 """
111+
106112 def __init__ (self , library , embedding_model = None , tokenizer = None , vector_db_api_key = None ,
107113 query_id = None , from_hf = False , from_sentence_transformer = False ,embedding_model_name = None ,
108114 save_history = True , query_mode = None , vector_db = None , model_api_key = None ):
@@ -407,7 +413,6 @@ def query(self, query, query_type="text", result_count=20, results_only=True):
407413
408414 return output_result
409415
410- # basic simple text query method - only requires entering the query
411416 def text_query (self , query , exact_mode = False , result_count = 20 , exhaust_full_cursor = False , results_only = True ):
412417
413418 """ Execute a basic text query. """
@@ -693,7 +698,6 @@ def _cursor_to_qr (self, query, cursor_results, result_count=20, exhaust_full_cu
693698
694699 return qr_dict
695700
696- # basic semantic query
697701 def semantic_query (self , query , result_count = 20 , embedding_distance_threshold = None , custom_filter = None , results_only = True ):
698702
699703 """ Main method to execute a semantic query - only required parameter is the query. """
@@ -850,7 +854,8 @@ def similar_blocks_embedding(self, block, result_count=20, embedding_distance_th
850854
851855 return results_dict
852856
853- def dual_pass_query (self , query , result_count = 20 , primary = "text" , safety_check = True , custom_filter = None , results_only = True ):
857+ def dual_pass_query (self , query , result_count = 20 , primary = "text" ,
858+ safety_check = True , custom_filter = None , results_only = True ):
854859
855860 """ Executes a combination of text and semantic queries and attempts to interweave and re-rank based on
856861 correspondence between the two query attempts. """
@@ -872,12 +877,14 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
872877 # run dual pass - text + semantic
873878 # Choose appropriate text query method based on custom_filter
874879 if custom_filter :
875- retrieval_dict_text = self .text_query_with_custom_filter (query , custom_filter , result_count = result_count , results_only = True )
880+ retrieval_dict_text = self .text_query_with_custom_filter (query , custom_filter ,
881+ result_count = result_count , results_only = True )
876882 else :
877883 retrieval_dict_text = self .text_query (query , result_count = result_count , results_only = True )
878884
879885 # Semantic query with custom filter
880- retrieval_dict_semantic = self .semantic_query (query , result_count = result_count , custom_filter = custom_filter , results_only = True )
886+ retrieval_dict_semantic = self .semantic_query (query , result_count = result_count ,
887+ custom_filter = custom_filter , results_only = True )
881888
882889 if primary == "text" :
883890 first_list = retrieval_dict_text
@@ -930,17 +937,16 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
930937 doc_fn_list .append (qr ["file_source" ])
931938
932939 retrieval_dict = {"results" : merged_results ,
933- "text_results" : retrieval_dict_text ,
934- "semantic_results" : retrieval_dict_semantic ,
935- "doc_ID" : doc_id_list ,
936- "file_source" : doc_fn_list }
940+ "text_results" : retrieval_dict_text ,
941+ "semantic_results" : retrieval_dict_semantic ,
942+ "doc_ID" : doc_id_list ,
943+ "file_source" : doc_fn_list }
937944
938945 if results_only :
939946 return merged_results
940947
941948 return retrieval_dict
942949
943-
944950 def augment_qr (self , query_result , query_topic , augment_query = "semantic" ):
945951
946952 """ Augments the set of query results using alternative retrieval strategy. """
@@ -1038,7 +1044,9 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10381044 if "doc_ID" in doc_id_list :
10391045 doc_id_list = doc_id_list ["doc_ID" ]
10401046 else :
1041- logging .warning ("warning: could not recognize doc id list requested. by default, will set to all documents in the library collection." )
1047+ logging .warning ("warning: could not recognize doc id list requested. by default, "
1048+ "will set to all documents in the library collection." )
1049+
10421050 doc_id_list = self .list_doc_id ()
10431051
10441052 if not page_list :
@@ -1051,7 +1059,8 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10511059 else :
10521060 page_dict = {"doc_ID" : {"$in" :doc_id_list }, "master_index" : {"$in" : page_list }}
10531061
1054- cursor_results = CollectionRetrieval (self .library_name , account_name = self .account_name ).filter_by_key_dict (page_dict )
1062+ cursor_results = CollectionRetrieval (self .library_name ,
1063+ account_name = self .account_name ).filter_by_key_dict (page_dict )
10551064
10561065 output = []
10571066
@@ -1064,12 +1073,12 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
10641073
10651074 return output
10661075
1067- # new method to extract whole library
10681076 def get_whole_library (self , selected_keys = None ):
10691077
10701078 """ Gets the whole library - and will return as a list in-memory. """
10711079
1072- match_results_cursor = CollectionRetrieval (self .library_name , account_name = self .account_name ).get_whole_collection ()
1080+ match_results_cursor = CollectionRetrieval (self .library_name ,
1081+ account_name = self .account_name ).get_whole_collection ()
10731082
10741083 match_results = match_results_cursor .pull_all ()
10751084
@@ -1098,7 +1107,6 @@ def get_whole_library(self, selected_keys=None):
10981107
10991108 return qr
11001109
1101- # new method to generate csv files for each table entry
11021110 def export_all_tables (self , query = "" , output_fp = None ):
11031111
11041112 """ Exports all tables, with query option to limit the list from a library. """
@@ -1110,7 +1118,8 @@ def export_all_tables(self, query="", output_fp=None):
11101118
11111119 if not query :
11121120
1113- match_results = CollectionRetrieval (self .library_name , account_name = self .account_name ).filter_by_key ("content_type" ,"table" )
1121+ match_results = CollectionRetrieval (self .library_name ,
1122+ account_name = self .account_name ).filter_by_key ("content_type" ,"table" )
11141123
11151124 else :
11161125 kv_dict = {"content_type" : "table" }
@@ -1274,7 +1283,8 @@ def list_doc_fn(self):
12741283
12751284 """ Utility function - returns list of all document names in the library. """
12761285
1277- doc_fn_raw_list = CollectionRetrieval (self .library_name , account_name = self .account_name ).get_distinct_list ("file_source" )
1286+ doc_fn_raw_list = CollectionRetrieval (self .library_name ,
1287+ account_name = self .account_name ).get_distinct_list ("file_source" )
12781288
12791289 doc_fn_out = []
12801290 for i , file in enumerate (doc_fn_raw_list ):
@@ -1712,15 +1722,18 @@ def expand_text_result_after(self, block, window_size=400):
17121722 return output
17131723
17141724 def generate_csv_report (self ):
1725+
17151726 """Generates a csv report from the current query status. """
1727+
17161728 output = QueryState (self ).generate_query_report_current_state ()
17171729 return output
17181730
17191731 def filter_by_key_value_range (self , key , value_range , results_only = True ):
17201732
17211733 """ Executes a filter by key value range. """
17221734
1723- cursor = CollectionRetrieval (self .library_name , account_name = self .account_name ).filter_by_key_value_range (key ,value_range )
1735+ cursor = CollectionRetrieval (self .library_name ,
1736+ account_name = self .account_name ).filter_by_key_value_range (key ,value_range )
17241737
17251738 query = ""
17261739 result_dict = self ._cursor_to_qr (query , cursor , exhaust_full_cursor = True )
0 commit comments