@@ -121,6 +121,8 @@ def __init__(self, api_key=None, verbose=True, analyze_mode=True):
121121 self .api_key = api_key
122122 self .api_exec = False
123123
124+ self .sql_query = None
125+
124126 # check for llmware path & create if not already set up, e.g., "first time use"
125127 if not os .path .exists (LLMWareConfig .get_llmware_path ()):
126128 LLMWareConfig .setup_llmware_workspace ()
@@ -1104,7 +1106,57 @@ def sql(self, query, table_schema):
11041106
11051107 return output_response
11061108
1107- def query_custom_table (self , query , db = None ,table = None ,table_schema = None ,db_name = "llmware" ):
def sql_checker(self, sql_query, custom_sql_checker=None):

    """ Implements a basic post-processing check on text-2-sql generation to confirm that
    the query is a SELECT statement and not a form of DB WRITE command.

    By passing a custom_sql_checker function, you can enhance (or fully replace) this basic check.

    The custom_sql_checker function should accept a string sql_query as input,
    and return two outputs:

        1-  confirmation: a boolean truth value of True/False to indicate whether to move ahead
        2-  sql_query_updated: a return string that may be identical/modification of original sql query

    Default check (used only when no custom_sql_checker is passed):

        -- a single statement that starts with SELECT (any letter case, leading whitespace
           ignored) is accepted as-is
        -- any other query - including "stacked" queries in which a ';' is followed by further
           sql text - is scanned word-by-word, and is rejected if any word is a known
           WRITE command

    The scan is deliberately conservative: a WRITE keyword that appears inside a quoted string
    literal of a scanned query will also trigger a rejection.

    Returns a tuple of (confirmation, sql_query_updated).  The default check never modifies the
    query, so sql_query_updated is the original sql_query in that case.
    """

    # local import so that the method does not rely on the file-level import block
    import re

    # if no red-flags identified, then will return True and original sql_query
    confirmation = True
    sql_query_updated = sql_query

    logger.debug(f"LLMfx - sql_checker - {sql_query} - being reviewed.")

    if custom_sql_checker:
        # caller-supplied checker takes full responsibility for the decision
        confirmation, sql_query_updated = custom_sql_checker(sql_query)
        return confirmation, sql_query_updated

    statement = sql_query.strip()

    # generated sql is not guaranteed to be upper-case or free of leading whitespace
    starts_with_select = statement.upper().startswith("SELECT")

    # a ';' that is followed by more text means more than one statement, e.g., "SELECT 1; DROP TABLE t"
    # - trailing terminators and whitespace are stripped first so that "SELECT 1;" is a single statement
    is_stacked = ";" in statement.rstrip("; \t\r\n")

    if starts_with_select and not is_stacked:
        return confirmation, sql_query_updated

    if not starts_with_select:
        logger.warning(f"LLMfx - sql_checker - sql query statement does not start "
                       f"with SELECT statement - {sql_query}")
    else:
        logger.warning(f"LLMfx - sql_checker - sql query statement appears to contain "
                       f"more than one statement - {sql_query}")

    # this list can be enhanced
    basic_write_commands = {"DROP", "INSERT", "CREATE", "DELETE", "ALTER", "UPDATE", "TRUNCATE"}

    # extract whole words, so keywords are found regardless of the kind of whitespace used or of
    # adjacent punctuation ("DROP;"), while identifiers such as "created_at" do not match "CREATE"
    for toks in re.findall(r"[A-Za-z_]\w*", sql_query):

        if toks.upper() in basic_write_commands:
            logger.warning(f"LLMfx - sql_checker - sql query statement appears to create "
                           f"WRITE elements - {toks} - stopping.")

            confirmation = False
            break

    return confirmation, sql_query_updated
1157+
1158+ def query_custom_table (self , query , db = None ,table = None ,table_schema = None ,db_name = "llmware" ,
1159+ custom_sql_checker = None ):
11081160
11091161 """ Executes a text-to-sql query on a CustomTable database table. """
11101162
@@ -1117,30 +1169,46 @@ def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name
11171169 # step 1 - convert question into sql
11181170
11191171 if not table_schema :
1120- logging .warning ("update: LLMfx - query_db - could not identify table schema - can not proceed" )
1172+ logging .warning ("LLMfx - query_db - could not identify table schema - can not proceed" )
11211173 return - 1
11221174
11231175 # run inference with query and table schema to get SQL query response
11241176 response = self .sql (query , table_schema )
11251177
11261178 # step 2 - run query
11271179 sql_query = response ["llm_response" ]
1180+ self .sql_query = sql_query
1181+
1182+ # basic sql verification checker
1183+ confirmation , self .sql_query = self .sql_checker (self .sql_query , custom_sql_checker = custom_sql_checker )
1184+
1185+ if not confirmation :
1186+ logger .warning (f"LLMfx - query_custom_db - sql query generated appears to be potentially unsafe - "
1187+ f"{ self .sql_query } so not moving ahead with query." )
1188+
1189+ empty_result = {"step" : self .step , "tool" : "sql" , "db_response" : [],
1190+ "sql_query" : self .sql_query + "-NOT_EXECUTED" ,
1191+ "query" : query , "db" : db , "work_item" : table_schema }
1192+
1193+ self .research_list .append (empty_result )
1194+
1195+ return empty_result
11281196
11291197 # initial journal update
11301198 journal_update = f"executing research call - executing query on db\n "
11311199 journal_update += f"\t \t \t \t -- db - { db } \n "
1132- journal_update += f"\t \t \t \t -- sql_query - { sql_query } "
1200+ journal_update += f"\t \t \t \t -- sql_query - { self . sql_query } "
11331201 self .write_to_journal (journal_update )
11341202
1135- db_output = custom_table .custom_lookup (response [ "llm_response" ] )
1203+ db_output = custom_table .custom_lookup (self . sql_query )
11361204
11371205 output = []
11381206 db_response = list (db_output )
11391207
11401208 for rows in db_response :
11411209 output .append (rows )
11421210
1143- result = {"step" : self .step , "tool" : "sql" , "db_response" : output , "sql_query" : response [ "llm_response" ] ,
1211+ result = {"step" : self .step , "tool" : "sql" , "db_response" : output , "sql_query" : self . sql_query ,
11441212 "query" : query ,"db" : db , "work_item" : table_schema }
11451213
11461214 self .research_list .append (result )
@@ -1155,7 +1223,8 @@ def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name
11551223
11561224 return result
11571225
1158- def query_db (self , query , table = None , table_schema = None , db = None , db_name = None ):
1226+ def query_db (self , query , table = None , table_schema = None , db = None , db_name = None ,
1227+ custom_sql_checker = None ):
11591228
11601229 """ Executes two steps - converts input query into SQL, and then executes the SQL query on the DB. """
11611230
@@ -1168,31 +1237,47 @@ def query_db(self, query, table=None, table_schema=None, db=None, db_name=None):
11681237 # step 1 - convert question into sql
11691238
11701239 if not table_schema :
1171- logging .warning ("update: LLMfx - query_db - could not identify table schema - can not proceed" )
1240+ logging .warning ("LLMfx - query_db - could not identify table schema - can not proceed" )
11721241 return - 1
11731242
11741243 # run inference with query and table schema to get SQL query response
11751244 response = self .sql (query , table_schema )
11761245
11771246 # step 2 - run query
11781247 sql_query = response ["llm_response" ]
1248+ self .sql_query = sql_query
11791249 sql_db_name = sql_db .db_file
11801250
1251+ # basic sql safety check
1252+ confirmation , self .sql_query = self .sql_checker (self .sql_query , custom_sql_checker = custom_sql_checker )
1253+
1254+ if not confirmation :
1255+ logger .warning (f"LLMfx - query_db - sql query generated appears to be potentially unsafe - "
1256+ f"{ self .sql_query } so not moving ahead with query." )
1257+
1258+ empty_result = {"step" : self .step , "tool" : "sql" , "db_response" : [],
1259+ "sql_query" : self .sql_query + "-NOT_EXECUTED" ,
1260+ "query" : query , "db" : db , "work_item" : table_schema }
1261+
1262+ self .research_list .append (empty_result )
1263+
1264+ return empty_result
1265+
11811266 # initial journal update
11821267 journal_update = f"executing research call - executing query on db\n "
11831268 journal_update += f"\t \t \t \t -- db - { sql_db_name } \n "
1184- journal_update += f"\t \t \t \t -- sql_query - { sql_query } "
1269+ journal_update += f"\t \t \t \t -- sql_query - { self . sql_query } "
11851270 self .write_to_journal (journal_update )
11861271
1187- db_output = sql_db .query_db (response [ "llm_response" ] )
1272+ db_output = sql_db .query_db (self . sql_query )
11881273
11891274 output = []
11901275 db_response = list (db_output )
11911276
11921277 for rows in db_response :
11931278 output .append (rows )
11941279
1195- result = {"step" : self .step , "tool" : "sql" , "db_response" : output , "sql_query" : response [ "llm_response" ] ,
1280+ result = {"step" : self .step , "tool" : "sql" , "db_response" : output , "sql_query" : self . sql_query ,
11961281 "query" : query ,"db" : sql_db_name , "work_item" : table_schema }
11971282
11981283 self .research_list .append (result )
0 commit comments