Skip to content

Commit 509b534

Browse files
committed
simplify llm_with_state a bit
1 parent a35aece commit 509b534

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

‎llm_with_state.py‎

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ class LLMResult:
1515
tokens_query: int = 0
1616
tokens_response: int = 0
1717

18+
19+
TPL_NEXT = Template(filename='templates/query_next_command.txt')
20+
TPL_ANALYZE = Template(filename="templates/analyze_cmd.txt")
21+
TPL_STATE = Template(filename="templates/update_state.txt")
22+
1823
class LLMWithState:
1924
def __init__(self, run_id, llm_connection, history, config):
2025
self.llm_connection = llm_connection
@@ -35,13 +40,10 @@ def get_state_size(self, model):
3540

3641
def get_next_cmd(self):
3742

38-
template_file = 'query_next_command.txt'
3943
model = self.llm_connection.get_model()
4044

4145
state_size = self.get_state_size(model)
42-
43-
template = Template(filename='templates/' + template_file)
44-
template_size = num_tokens_from_string(model, template.source)
46+
template_size = num_tokens_from_string(model, TPL_NEXT.source)
4547

4648
history = get_cmd_history_v3(model, self.llm_connection.get_context_size(), self.run_id, self.db, state_size+template_size)
4749

@@ -50,7 +52,7 @@ def get_next_cmd(self):
5052
else:
5153
target_user = "Administrator"
5254

53-
return self.create_and_ask_prompt_text(template_file, history=history, state=self.state, target=self.target, update_state=self.enable_update_state, target_user=target_user)
55+
return self.create_and_ask_prompt_text(TPL_NEXT, history=history, state=self.state, target=self.target, update_state=self.enable_update_state, target_user=target_user)
5456

5557
def analyze_result(self, cmd, result):
5658

@@ -66,20 +68,19 @@ def analyze_result(self, cmd, result):
6668
result = result[cut_off:]
6769
current_size = num_tokens_from_string(model, result)
6870

69-
result = self.create_and_ask_prompt_text('analyze_cmd.txt', cmd=cmd, resp=result, facts=self.state)
71+
result = self.create_and_ask_prompt_text(TPL_ANALYZE, cmd=cmd, resp=result, facts=self.state)
7072
return result
7173

7274
def update_state(self, cmd, result):
73-
result = self.create_and_ask_prompt_text('update_state.txt', cmd=cmd, resp=result, facts=self.state)
75+
result = self.create_and_ask_prompt_text(TPL_STATE, cmd=cmd, resp=result, facts=self.state)
7476
self.state = result.result
7577
return result
7678

7779
def get_current_state(self):
7880
return self.state
7981

80-
def create_and_ask_prompt_text(self, template_file, **params):
82+
def create_and_ask_prompt_text(self, template, **params):
8183
# prepare the prompt
82-
template = Template(filename='templates/' + template_file)
8384
prompt = template.render(**params)
8485

8586
if not self.llm_connection.get_model().startswith("gpt-"):

‎wintermute.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
3131

3232
return table
3333

34-
# setup some infrastructure for outputing information
34+
# setup infrastructure for outputing information
3535
console = Console()
3636

3737
# parse arguments
@@ -47,17 +47,17 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
4747
run_id = db.create_new_run(config)
4848

4949
# create the connection to the target
50-
conn = create_target_connection(config.target)
50+
target = create_target_connection(config.target)
5151

52-
# setup LLM connection and internal model representation
52+
# setup the connection to the LLM server
5353
llm_connection = get_llm_connection(config)
5454

5555
# instantiate the concrete LLM model
5656
llm_gpt = LLMWithState(run_id, llm_connection, db, config)
5757

5858
# setup round meta-data
5959
round : int = 0
60-
gotRoot = False
60+
gotRoot : bool = False
6161

6262
# and start everything up
6363
while round < config.max_rounds and not gotRoot:
@@ -71,7 +71,7 @@ def get_history_table(config, run_id: int, db: DbStorage, round: int) -> Table:
7171
cmd, result, gotRoot = handle_ssh(config.target, answer.result)
7272
else:
7373
console.print(Panel(answer.result, title=f"[bold cyan]Got command from LLM:"))
74-
cmd, result, gotRoot = handle_cmd(conn, answer.result)
74+
cmd, result, gotRoot = handle_cmd(target, answer.result)
7575

7676
db.add_log_query(run_id, round, cmd, result, answer)
7777

0 commit comments

Comments
 (0)