Skip to content

Commit a1fdfa7

Browse files
authored
Merge pull request llmware-ai#942 from llmware-ai/agent-text-2-sql-checker
adding-text-2-sql-checker
2 parents 17d5bb8 + 09f5cd1 commit a1fdfa7

File tree

1 file changed

+95
-10
lines changed

1 file changed

+95
-10
lines changed

‎llmware/agents.py‎

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def __init__(self, api_key=None, verbose=True, analyze_mode=True):
121121
self.api_key = api_key
122122
self.api_exec = False
123123

124+
self.sql_query = None
125+
124126
# check for llmware path & create if not already set up, e.g., "first time use"
125127
if not os.path.exists(LLMWareConfig.get_llmware_path()):
126128
LLMWareConfig.setup_llmware_workspace()
@@ -1104,7 +1106,57 @@ def sql(self, query, table_schema):
11041106

11051107
return output_response
11061108

1107-
def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name="llmware"):
1109+
def sql_checker(self, sql_query, custom_sql_checker=None):

    """ Implements a basic post processing check on text-2-sql generation to confirm that
    the query is a SELECT statement and not a form of DB WRITE command.

    By passing a custom_sql_checker function, you can replace this basic check.

    The custom_sql_checker function should accept a string sql_query as input,
    and return two outputs:

        1- confirmation: a boolean truth value of True/False to indicate whether to move ahead
        2- sql_query_updated: a return string that may be identical/modification of original sql query

    Basic check (used when no custom_sql_checker is passed):

        -- quoted string literals are blanked out first, so that a value such as 'please delete'
           (or a ';' inside a literal) is not mistaken for SQL syntax
        -- the query is split on ';' and every statement is reviewed, so a write command stacked
           behind a leading SELECT is also caught
        -- a statement that starts with SELECT (any casing, leading whitespace allowed) is accepted
        -- any other statement is scanned word-by-word for a write command keyword

    Parameters:
        sql_query: the generated sql query string to review
        custom_sql_checker: optional callable(sql_query) -> (confirmation, sql_query_updated)

    Returns:
        tuple of (confirmation, sql_query_updated) - the basic check never modifies the query,
        so sql_query_updated is the original sql_query unless a custom checker changes it.
    """

    # local import - keeps this method self-contained without touching the file-level imports
    import re

    # if no red-flags identified, then will return True and original sql_query
    confirmation = True
    sql_query_updated = sql_query

    logger.debug(f"LLMfx - sql_checker - {sql_query} - being reviewed.")

    if custom_sql_checker:
        confirmation, sql_query_updated = custom_sql_checker(sql_query)
        return confirmation, sql_query_updated

    # a non-string can not be reviewed (and can not be executed) - fail closed rather than raise
    if not isinstance(sql_query, str):
        logger.warning(f"LLMfx - sql_checker - sql query is not a string - {sql_query} - stopping.")
        return False, sql_query_updated

    # this list can be enhanced
    basic_write_commands = frozenset(("DROP", "INSERT", "CREATE", "DELETE", "ALTER",
                                      "UPDATE", "TRUNCATE"))

    # blank out single-quoted literals ('' is the escaped quote inside a literal)
    scrubbed_query = re.sub(r"'(?:[^']|'')*'", "''", sql_query)

    for statement in scrubbed_query.split(";"):

        # skip the empty remainder after a trailing ';'
        if not statement.strip():
            continue

        # reviews any SQL statement that does not start with SELECT
        if re.match(r"\s*SELECT\b", statement, re.IGNORECASE):
            continue

        logger.warning(f"LLMfx - sql_checker - sql query statement does not start "
                       f"with SELECT statement - {sql_query}")

        # \w+ tokenizing is independent of the whitespace/punctuation around a keyword,
        # e.g., "drop\ntable" and "DROP;" are both seen as the token DROP
        for toks in re.findall(r"\w+", statement):

            if toks.upper() in basic_write_commands:
                logger.warning(f"LLMfx - sql_checker - sql query statement appears to create "
                               f"WRITE elements - {toks} - stopping.")

                confirmation = False
                break

        if not confirmation:
            break

    return confirmation, sql_query_updated
1157+
1158+
def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name="llmware",
1159+
custom_sql_checker=None):
11081160

11091161
""" Executes a text-to-sql query on a CustomTable database table. """
11101162

@@ -1117,30 +1169,46 @@ def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name
11171169
# step 1 - convert question into sql
11181170

11191171
if not table_schema:
1120-
logging.warning("update: LLMfx - query_db - could not identify table schema - can not proceed")
1172+
logging.warning("LLMfx - query_db - could not identify table schema - can not proceed")
11211173
return -1
11221174

11231175
# run inference with query and table schema to get SQL query response
11241176
response = self.sql(query, table_schema)
11251177

11261178
# step 2 - run query
11271179
sql_query = response["llm_response"]
1180+
self.sql_query = sql_query
1181+
1182+
# basic sql verification checker
1183+
confirmation, self.sql_query = self.sql_checker(self.sql_query, custom_sql_checker=custom_sql_checker)
1184+
1185+
if not confirmation:
1186+
logger.warning(f"LLMfx - query_custom_db - sql query generated appears to be potentially unsafe - "
1187+
f"{self.sql_query} so not moving ahead with query.")
1188+
1189+
empty_result = {"step": self.step, "tool": "sql", "db_response": [],
1190+
"sql_query": self.sql_query + "-NOT_EXECUTED",
1191+
"query": query, "db": db, "work_item": table_schema}
1192+
1193+
self.research_list.append(empty_result)
1194+
1195+
return empty_result
11281196

11291197
# initial journal update
11301198
journal_update = f"executing research call - executing query on db\n"
11311199
journal_update += f"\t\t\t\t -- db - {db}\n"
1132-
journal_update += f"\t\t\t\t -- sql_query - {sql_query}"
1200+
journal_update += f"\t\t\t\t -- sql_query - {self.sql_query}"
11331201
self.write_to_journal(journal_update)
11341202

1135-
db_output = custom_table.custom_lookup(response["llm_response"])
1203+
db_output = custom_table.custom_lookup(self.sql_query)
11361204

11371205
output = []
11381206
db_response = list(db_output)
11391207

11401208
for rows in db_response:
11411209
output.append(rows)
11421210

1143-
result = {"step": self.step, "tool": "sql", "db_response": output, "sql_query": response["llm_response"],
1211+
result = {"step": self.step, "tool": "sql", "db_response": output, "sql_query": self.sql_query,
11441212
"query": query,"db": db, "work_item": table_schema}
11451213

11461214
self.research_list.append(result)
@@ -1155,7 +1223,8 @@ def query_custom_table(self, query, db=None,table=None,table_schema=None,db_name
11551223

11561224
return result
11571225

1158-
def query_db(self, query, table=None, table_schema=None, db=None, db_name=None):
1226+
def query_db(self, query, table=None, table_schema=None, db=None, db_name=None,
1227+
custom_sql_checker=None):
11591228

11601229
""" Executes two steps - converts input query into SQL, and then executes the SQL query on the DB. """
11611230

@@ -1168,31 +1237,47 @@ def query_db(self, query, table=None, table_schema=None, db=None, db_name=None):
11681237
# step 1 - convert question into sql
11691238

11701239
if not table_schema:
1171-
logging.warning("update: LLMfx - query_db - could not identify table schema - can not proceed")
1240+
logging.warning("LLMfx - query_db - could not identify table schema - can not proceed")
11721241
return -1
11731242

11741243
# run inference with query and table schema to get SQL query response
11751244
response = self.sql(query, table_schema)
11761245

11771246
# step 2 - run query
11781247
sql_query = response["llm_response"]
1248+
self.sql_query = sql_query
11791249
sql_db_name = sql_db.db_file
11801250

1251+
# basic sql safety check
1252+
confirmation, self.sql_query = self.sql_checker(self.sql_query, custom_sql_checker=custom_sql_checker)
1253+
1254+
if not confirmation:
1255+
logger.warning(f"LLMfx - query_db - sql query generated appears to be potentially unsafe - "
1256+
f"{self.sql_query} so not moving ahead with query.")
1257+
1258+
empty_result = {"step": self.step, "tool": "sql", "db_response": [],
1259+
"sql_query": self.sql_query + "-NOT_EXECUTED",
1260+
"query": query, "db": db, "work_item": table_schema}
1261+
1262+
self.research_list.append(empty_result)
1263+
1264+
return empty_result
1265+
11811266
# initial journal update
11821267
journal_update = f"executing research call - executing query on db\n"
11831268
journal_update += f"\t\t\t\t -- db - {sql_db_name}\n"
1184-
journal_update += f"\t\t\t\t -- sql_query - {sql_query}"
1269+
journal_update += f"\t\t\t\t -- sql_query - {self.sql_query}"
11851270
self.write_to_journal(journal_update)
11861271

1187-
db_output = sql_db.query_db(response["llm_response"])
1272+
db_output = sql_db.query_db(self.sql_query)
11881273

11891274
output = []
11901275
db_response = list(db_output)
11911276

11921277
for rows in db_response:
11931278
output.append(rows)
11941279

1195-
result = {"step": self.step, "tool": "sql", "db_response": output, "sql_query": response["llm_response"],
1280+
result = {"step": self.step, "tool": "sql", "db_response": output, "sql_query": self.sql_query,
11961281
"query": query,"db": sql_db_name, "work_item": table_schema}
11971282

11981283
self.research_list.append(result)

0 commit comments

Comments
 (0)