simplify llm_with_state a bit
andreashappe committed Sep 21, 2023
commit 509b534333ab5f3016def0b948fdcd67b5f36b5f
19 changes: 10 additions & 9 deletions llm_with_state.py
@@ -15,6 +15,11 @@ class LLMResult:
     tokens_query: int = 0
     tokens_response: int = 0
 
+
+TPL_NEXT = Template(filename='templates/query_next_command.txt')
+TPL_ANALYZE = Template(filename="templates/analyze_cmd.txt")
+TPL_STATE = Template(filename="templates/update_state.txt")
+
 class LLMWithState:
     def __init__(self, run_id, llm_connection, history, config):
         self.llm_connection = llm_connection
@@ -35,13 +40,10 @@ def get_state_size(self, model):
 
     def get_next_cmd(self):
 
-        template_file = 'query_next_command.txt'
         model = self.llm_connection.get_model()
 
         state_size = self.get_state_size(model)
-
-        template = Template(filename='templates/' + template_file)
-        template_size = num_tokens_from_string(model, template.source)
+        template_size = num_tokens_from_string(model, TPL_NEXT.source)
 
         history = get_cmd_history_v3(model, self.llm_connection.get_context_size(), self.run_id, self.db, state_size+template_size)
 
@@ -50,7 +52,7 @@ def get_next_cmd(self):
         else:
             target_user = "Administrator"
 
-        return self.create_and_ask_prompt_text(template_file, history=history, state=self.state, target=self.target, update_state=self.enable_update_state, target_user=target_user)
+        return self.create_and_ask_prompt_text(TPL_NEXT, history=history, state=self.state, target=self.target, update_state=self.enable_update_state, target_user=target_user)
 
     def analyze_result(self, cmd, result):
 
@@ -66,20 +68,19 @@ def analyze_result(self, cmd, result):
             result = result[cut_off:]
             current_size = num_tokens_from_string(model, result)
 
-        result = self.create_and_ask_prompt_text('analyze_cmd.txt', cmd=cmd, resp=result, facts=self.state)
+        result = self.create_and_ask_prompt_text(TPL_ANALYZE, cmd=cmd, resp=result, facts=self.state)
         return result
 
     def update_state(self, cmd, result):
-        result = self.create_and_ask_prompt_text('update_state.txt', cmd=cmd, resp=result, facts=self.state)
+        result = self.create_and_ask_prompt_text(TPL_STATE, cmd=cmd, resp=result, facts=self.state)
         self.state = result.result
         return result
 
     def get_current_state(self):
         return self.state
 
-    def create_and_ask_prompt_text(self, template_file, **params):
+    def create_and_ask_prompt_text(self, template, **params):
         # prepare the prompt
-        template = Template(filename='templates/' + template_file)
         prompt = template.render(**params)
 
         if not self.llm_connection.get_model().startswith("gpt-"):
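For context, the pattern this commit moves to looks roughly like the sketch below: each Mako template is parsed once at import time, the pre-parsed Template object is passed around instead of a filename, and its .source text is used to reserve token budget before the remaining context window is filled with command history. This is a minimal sketch, not the commit's code; the tiktoken-based counter is an assumption standing in for the repo's num_tokens_from_string helper.

# Sketch only (not part of the commit): module-level templates, parsed once.
from mako.template import Template
import tiktoken

TPL_NEXT = Template(filename='templates/query_next_command.txt')

def num_tokens_from_string(model: str, text: str) -> int:
    # assumed stand-in for the repo's helper of the same name
    return len(tiktoken.encoding_for_model(model).encode(text))

def history_budget(model: str, context_size: int, state_size: int) -> int:
    # tokens left for command history once the template text and the
    # current state summary have been reserved
    template_size = num_tokens_from_string(model, TPL_NEXT.source)
    return context_size - state_size - template_size

Parsing the templates once also lets create_and_ask_prompt_text accept the Template object directly, which is what the get_next_cmd, analyze_result, and update_state changes above do.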
10 changes: 5 additions & 5 deletions wintermute.py
@@ -31,7 +31,7 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
 
     return table
 
-# setup some infrastructure for outputting information
+# setup infrastructure for outputting information
 console = Console()
 
 # parse arguments
@@ -47,17 +47,17 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
 run_id = db.create_new_run(config)
 
 # create the connection to the target
-conn = create_target_connection(config.target)
+target = create_target_connection(config.target)
 
-# setup LLM connection and internal model representation
+# setup the connection to the LLM server
 llm_connection = get_llm_connection(config)
 
 # instantiate the concrete LLM model
 llm_gpt = LLMWithState(run_id, llm_connection, db, config)
 
 # setup round meta-data
 round : int = 0
-gotRoot = False
+gotRoot : bool = False
 
 # and start everything up
 while round < config.max_rounds and not gotRoot:
@@ -71,7 +71,7 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
         cmd, result, gotRoot = handle_ssh(config.target, answer.result)
     else:
         console.print(Panel(answer.result, title=f"[bold cyan]Got command from LLM:"))
-        cmd, result, gotRoot = handle_cmd(conn, answer.result)
+        cmd, result, gotRoot = handle_cmd(target, answer.result)
 
     db.add_log_query(run_id, round, cmd, result, answer)
 
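The wintermute.py changes are renames and annotations for readability: conn becomes target to match what it connects to, and gotRoot gets an explicit bool annotation alongside round. A hedged sketch of the driver loop those lines live in, using the repo's own names; the condition guarding handle_ssh sits outside the shown hunks, so a placeholder flag stands in for it.

# Sketch only (not part of the commit): shape of the main loop.
round: int = 0
gotRoot: bool = False

while round < config.max_rounds and not gotRoot:
    # ask the LLM for the next command to try against the target
    answer = llm_gpt.get_next_cmd()

    if use_ssh:  # placeholder; the real condition is outside the shown hunks
        cmd, result, gotRoot = handle_ssh(config.target, answer.result)
    else:
        cmd, result, gotRoot = handle_cmd(target, answer.result)

    # log the round and move on unless we already got root
    db.add_log_query(run_id, round, cmd, result, answer)
    round += 1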