Skip to content
Merged

V3 #2

Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
code clean-ups
  • Loading branch information
andreashappe committed Sep 14, 2023
commit 4e508b3c1ab3f94521b0fbcbee48a4988276bc92
15 changes: 9 additions & 6 deletions db_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ def create_new_run(self, model, context_size):
self.cursor.execute("INSERT INTO runs (model, context_size) VALUES (?, ?)", (model, context_size))
return self.cursor.lastrowid

def add_log_query(self, run_id, round, cmd, result, duration=0, tokens_query=0, tokens_response=0):
    # NOTE(review): pre-refactor variant shown as *removed* in this diff —
    # cost columns arrive as individually defaulted arguments instead of
    # being read from an answer object as in the replacement version.
    self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.query_cmd_id, cmd, result, duration, tokens_query, tokens_response))
def add_log_query(self, run_id, round, cmd, result, answer):
    """Log one executed command for a run in the `queries` table.

    Stores the command and its output together with the cost metadata
    carried by *answer* (duration, tokens_query, tokens_response — see
    LLMResult) under this logger's query command id.
    """
    sql = (
        "INSERT INTO queries "
        "(run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    )
    row = (
        run_id,
        round,
        self.query_cmd_id,
        cmd,
        result,
        answer.duration,
        answer.tokens_query,
        answer.tokens_response,
    )
    self.cursor.execute(sql, row)

def add_log_analyze_response(self, run_id, round, cmd, result, duration=0, tokens_query=0, tokens_response=0):
    # NOTE(review): pre-refactor variant shown as *removed* in this diff —
    # superseded by the answer-object version directly below in the diff.
    self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.analyze_response_id, cmd, result, duration, tokens_query, tokens_response))
def add_log_analyze_response(self, run_id, round, cmd, result, answer):
    """Log the LLM's analysis of a command's output.

    Identical row layout to add_log_query, but tagged with this logger's
    analyze-response command id; cost metadata is read from *answer*.
    """
    sql = (
        "INSERT INTO queries "
        "(run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    )
    row = (
        run_id,
        round,
        self.analyze_response_id,
        cmd,
        result,
        answer.duration,
        answer.tokens_query,
        answer.tokens_response,
    )
    self.cursor.execute(sql, row)

def add_log_update_state(self, run_id, round, cmd, result, duration=0, tokens_query=0, tokens_response=0):
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, duration, tokens_query, tokens_response))
def add_log_update_state(self, run_id, round, cmd, result, answer):
    """Log a state-update step for a run in the `queries` table.

    *answer* may be None when the state update did not involve an LLM call;
    in that case the cost columns are logged as zero.
    """
    # PEP 8: compare against None with `is not`, not `!=`; also dedupe the
    # two previously copy-pasted INSERT statements into a single execute.
    if answer is not None:
        duration = answer.duration
        tokens_query = answer.tokens_query
        tokens_response = answer.tokens_response
    else:
        duration, tokens_query, tokens_response = 0, 0, 0
    self.cursor.execute(
        "INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
        (run_id, round, self.state_update_id, cmd, result, duration, tokens_query, tokens_response),
    )

def get_round_data(self, run_id, round):
rows = self.cursor.execute("select cmd_id, query, response, duration, tokens_query, tokens_response from queries where run_id = ? and round = ?", (run_id, round)).fetchall()
Expand Down
26 changes: 18 additions & 8 deletions llm_with_state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import json
import time
import typing

from dataclasses import dataclass
from mako.template import Template
from helper import *

@dataclass
class LLMResult:
    """Result of a single LLM round trip.

    Bundles the parsed answer with timing and token-usage metadata so the
    database logger can store cost data alongside each response.
    """

    # Parsed JSON answer from the model (see create_and_ask_prompt).
    result: typing.Any
    # Wall-clock seconds the query took; defaults to 0 for results that did
    # not involve an LLM call (e.g. update_state).
    duration: float = 0
    # Token counts reported for the prompt and the response, respectively.
    tokens_query: int = 0
    tokens_response: int = 0

class LLMWithState:
def __init__(self, run_id, llm_connection, history, initial_user, initial_password):
self.llm_connection = llm_connection
Expand All @@ -19,20 +28,19 @@ def __init__(self, run_id, llm_connection, history, initial_user, initial_passwo
def get_next_cmd(self):
    """Ask the LLM for the next command to execute.

    The command history handed to the prompt template is budgeted to the
    model's context window minus the tokens already consumed by the current
    state, so the rendered prompt should still fit the context size.

    Returns the LLMResult produced by create_and_ask_prompt.
    """
    # Defect fixed: the pre-refactor tuple-unpacking call and its `return`
    # were left above the new return in this diff, making the new
    # single-expression return unreachable; the dead residue is removed.
    state_size = num_tokens_from_string(self.state)

    return self.create_and_ask_prompt('query_next_command.txt', user=self.initial_user, password=self.initial_password, history=get_cmd_history(self.run_id, self.db, self.llm_connection.get_context_size()-state_size), state=self.state)

def analyze_result(self, cmd, result):
    """Ask the LLM whether *cmd*'s output indicates success and harvest facts.

    Renders the 'successfull.txt' template with the command, its output and
    the current fact list, stashes the returned facts in self.tmp_state for
    the next update_state() call, and returns the full LLMResult.
    """
    # Defects fixed: the pre-refactor tuple-unpacking lines left in this diff
    # made the method body contradictory; also bind the LLM answer to a new
    # local instead of rebinding the `result` parameter.
    answer = self.create_and_ask_prompt('successfull.txt', cmd=cmd, resp=result, facts=self.state)

    print("new state: " + str(answer.result["facts"]))
    self.tmp_state = answer.result["facts"]

    return answer

def update_state(self):
    """Rebuild the cached state string from the facts gathered this round.

    Formats each fact in self.tmp_state as a "- fact" bullet line, joins
    them into self.state, and returns the new state wrapped in an LLMResult
    with zeroed cost fields (no LLM call is involved here).
    """
    # Defect fixed: the pre-refactor `return self.state` left in this diff
    # made the LLMResult return unreachable; the residue is removed. The
    # map/lambda is replaced with an equivalent generator expression.
    self.state = "\n".join("- " + fact for fact in self.tmp_state)
    return LLMResult(self.state, 0, 0, 0)

def get_current_state(self):
    """Return the current fact-list state string unchanged."""
    return self.state
Expand All @@ -44,4 +52,6 @@ def create_and_ask_prompt(self, template_file, **params):
result, tok_query, tok_res = self.llm_connection.exec_query(prompt)
toc = time.perf_counter()
print(str(result))
return json.loads(result), str(toc-tic), tok_query, tok_res
json_answer = json.loads(result)

return LLMResult(json_answer, toc-tic, tok_query, tok_res)
21 changes: 10 additions & 11 deletions wintermute.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,25 @@
while round < max_rounds and not gotRoot:

console.log(f"Starting round {round} of {max_rounds}")
answer = llm_gpt.get_next_cmd()

next_cmd, diff, tok_query, tok_res = llm_gpt.get_next_cmd()
if answer.result["type"] == "cmd":
cmd, result, gotRoot = handle_cmd(conn, answer.result)
elif answer.result["type"] == "ssh":
cmd, result = handle_ssh(answer.result)

if next_cmd["type"] == "cmd":
cmd, result, gotRoot = handle_cmd(conn, next_cmd)
elif next_cmd["type"] == "ssh":
cmd, result = handle_ssh(next_cmd)

db.add_log_query(run_id, round, cmd, result, diff, tok_query, tok_res)
db.add_log_query(run_id, round, cmd, result, answer)

# output the command and its result
console.print(Panel(result, title=cmd))

# analyze the result and update your state
resp_success, diff_2, tok_query, tok_resp = llm_gpt.analyze_result(cmd, result)
db.add_log_analyze_response(run_id, round, cmd, resp_success["reason"], diff_2, tok_query, tok_resp)
answer = llm_gpt.analyze_result(cmd, result)
db.add_log_analyze_response(run_id, round, cmd, answer.result["reason"], answer)

state = llm_gpt.update_state()
console.print(Panel(state, title="my new fact list"))
db.add_log_update_state(run_id, round, "", state, 0, 0, 0)
console.print(Panel(state.result, title="my new fact list"))
db.add_log_update_state(run_id, round, "", state.result, None)

# update our command history and output it
console.print(get_history_table(run_id, db, round))
Expand Down