11import abc
2- from dataclasses import dataclass
32import datetime
4- from typing import List , Optional
5- import re
63
4+ from dataclasses import dataclass
75from mako .template import Template
8-
9- from hackingBuddyGPT .capabilities .capability import capabilities_to_simple_text_handler
6+ from hackingBuddyGPT .capability import capabilities_to_simple_text_handler
107from hackingBuddyGPT .usecases .base import UseCase
118from hackingBuddyGPT .utils import llm_util
12- from hackingBuddyGPT .utils .cli_history import SlidingCliHistory
9+ from hackingBuddyGPT .utils .histories import HistoryCmdOnly , HistoryFull , HistoryNone
1310from hackingBuddyGPT .utils .openai .openai_llm import OpenAIConnection
1411from hackingBuddyGPT .utils .logging import log_conversation , Logger , log_param , log_section
1512from hackingBuddyGPT .utils .capability_manager import CapabilityManager
16- from hackingBuddyGPT .utils .shell_root_detection import got_root
13+ from typing import List
14+
1715
1816@dataclass
1917class CommandStrategy (UseCase , abc .ABC ):
2018
2119 _capabilities : CapabilityManager = None
2220
23- _sliding_history : SlidingCliHistory = None
24-
25- _max_history_size : int = 0
26-
2721 _template : Template = None
2822
2923 _template_params = {}
@@ -41,125 +35,108 @@ class CommandStrategy(UseCase, abc.ABC):
4135 def before_run (self ):
4236 pass
4337
44- def after_run (self ):
38+ def after_command_execution (self , cmd , result , got_root ):
4539 pass
4640
47- def after_round (self , cmd , result , got_root ):
48- pass
49-
50- def get_space_for_history (self ):
51- pass
41+ def get_token_overhead (self ) -> int :
42+ return 0
5243
5344 def init (self ):
5445 super ().init ()
5546
5647 self ._capabilities = CapabilityManager (self .log )
5748
58- self ._sliding_history = SlidingCliHistory (self .llm )
59-
60- @log_section ("Asking LLM for a new command..." )
61- def get_next_command (self ) -> tuple [str , int ]:
62- history = ""
63- if not self .disable_history :
49+ # TODO: make this more beautiful by just configuring a History-Instance
50+ if self .disable_history :
51+ self ._history = HistoryNone ()
52+ else :
6453 if self .enable_compressed_history :
65- history = self ._sliding_history . get_commands_and_last_output ( self . _max_history_size - self . get_state_size () )
54+ self ._history = HistoryCmdOnly ( )
6655 else :
67- history = self ._sliding_history .get_history (self ._max_history_size - self .get_state_size ())
56+ self ._history = HistoryFull ()
57+
58+ @log_conversation ("Starting run..." )
59+ def run (self , configuration ):
6860
69- self ._template_params .update ({"history" : history })
70- cmd = self .llm .get_response (self ._template , ** self ._template_params )
71- message_id = self .log .call_response (cmd )
61+ self .configuration = configuration
62+ self .log .start_run (self .get_name (), self .serialize_configuration (configuration ))
7263
73- return cmd . result , message_id
64+ self . _template_params [ "capabilities" ] = self . _capabilities . get_capability_block ()
7465
75- @log_section ("Executing that command..." )
76- def run_command (self , cmd , message_id ) -> tuple [Optional [str ], bool ]:
77- _capability_descriptions , parser = capabilities_to_simple_text_handler (self ._capabilities ._capabilities , default_capability = self ._capabilities ._default_capability )
78- start_time = datetime .datetime .now ()
79- success , * output = parser (cmd )
80- if not success :
81- self .log .add_tool_call (message_id , tool_call_id = 0 , function_name = "" , arguments = cmd , result_text = output [0 ], duration = 0 )
82- return output [0 ], False
66+ self .before_run ()
8367
84- assert len (output ) == 1
85- capability , cmd , (result , got_root ) = output [0 ]
86- duration = datetime .datetime .now () - start_time
87- self .log .add_tool_call (message_id , tool_call_id = 0 , function_name = capability , arguments = cmd , result_text = result , duration = duration )
68+ task_successful = False
69+ turn = 1
70+ try :
71+ while turn <= self .max_turns and not task_successful :
72+ with self .log .section (f"round { turn } " ):
73+ self .log .console .log (f"[yellow]Starting turn { turn } of { self .max_turns } " )
74+ task_successful = self .perform_round (turn )
75+ turn += 1
76+ except Exception :
77+ import traceback
78+ self .log .run_was_failure ("exception occurred" , details = f":\n \n { traceback .format_exc ()} " )
79+ raise
8880
89- return result , got_root
81+ # write the final result to the database and console
82+ if task_successful :
83+ self .log .run_was_success ()
84+ else :
85+ self .log .run_was_failure ("maximum turn number reached" )
86+ return task_successful
9087
91- def check_success (self , cmd , result ) -> bool :
92- ansi_escape = re .compile (r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])" )
93- last_line = result .split ("\n " )[- 1 ] if result else ""
94- last_line = ansi_escape .sub ("" , last_line )
95- return got_root (self .conn .hostname , last_line )
96-
97- def postprocess_commands (self , cmd :str ) -> List [str ]:
98- return [cmd ]
99-
100- @log_conversation ("Asking LLM for a new command..." )
88+ @log_conversation ("Asking LLM for a new command(s)..." )
10189 def perform_round (self , turn : int ) -> bool :
10290 # get the next command and run it
10391 cmd , message_id = self .get_next_command ()
10492
10593 cmds = self .postprocess_commands (cmd )
10694 for cmd in cmds :
107- result , task_successful = self .run_command (cmd , message_id )
108-
109- # maybe move the 'got root' detection here?
110- # TODO: also can I use llm-as-judge for that? or do I have to do this
111- # on a per-action base (maybe add a .task_successful(cmd, result, options) -> boolean to the action?
112- task_successful2 = self .check_success (cmd , result )
113- assert (task_successful == task_successful2 )
95+ result = self .run_command (cmd , message_id )
96+ # store the results in our local history
97+ self ._history .append (cmd , result )
11498
115- self .after_round (cmd , result , task_successful )
116-
117- # store the results in our local history
118- if not self .disable_history :
119- if self .enable_compressed_history :
120- self ._sliding_history .add_command_only (cmds , result )
121- else :
122- self ._sliding_history .add_command (cmds , result )
99+ task_successful = self .check_success (cmd , result )
100+ self .after_command_execution (cmd , result , task_successful )
101+ if task_successful :
102+ return True
123103
124104 # signal if we were successful in our task
125- return task_successful
126-
127- @log_conversation ("Starting run..." )
128- def run (self , configuration ):
129-
130- self .configuration = configuration
131- self .log .start_run (self .get_name (), self .serialize_configuration (configuration ))
132-
133- self ._template_params ["capabilities" ] = self ._capabilities .get_capability_block ()
134-
105+ return False
135106
136- # calculate sizes
137- self ._max_history_size = self .llm .context_size - llm_util .SAFETY_MARGIN - self .llm .count_tokens (self ._template .source )
107+ @log_section ("Asking LLM for a new command..." )
108+ def get_next_command (self ) -> tuple [str , int ]:
109+ history = self ._history .get_text_representation ()
138110
139- self .before_run ()
111+ # calculate max history size
112+ max_history_size = self .llm .context_size - llm_util .SAFETY_MARGIN - self .llm .count_tokens (self ._template .source ) - self .get_token_overhead ()
113+ history = llm_util .trim_result_front (self .llm , max_history_size , history )
140114
141- got_root = False
115+ self ._template_params .update ({"history" : history })
116+ cmd = self .llm .get_response (self ._template , ** self ._template_params )
117+ message_id = self .log .call_response (cmd )
142118
143- turn = 1
144- try :
145- while turn <= self .max_turns and not got_root :
146- with self .log .section (f"round { turn } " ):
147- self .log .console .log (f"[yellow]Starting turn { turn } of { self .max_turns } " )
119+ return cmd .result , message_id
148120
149- got_root = self .perform_round (turn )
121+ @log_section ("Executing that command..." )
122+ def run_command (self , cmd , message_id ) -> str :
123+ _capability_descriptions , parser = capabilities_to_simple_text_handler (self ._capabilities ._capabilities , default_capability = self ._capabilities ._default_capability )
124+ start_time = datetime .datetime .now ()
125+ success , * output = parser (cmd )
126+ if not success :
127+ self .log .add_tool_call (message_id , tool_call_id = 0 , function_name = "" , arguments = cmd , result_text = output [0 ], duration = 0 )
128+ return output [0 ]
150129
151- turn += 1
130+ assert len (output ) == 1
131+ capability , cmd , result = output [0 ]
132+ duration = datetime .datetime .now () - start_time
133+ self .log .add_tool_call (message_id , tool_call_id = 0 , function_name = capability , arguments = cmd , result_text = result , duration = duration )
152134
153- self . after_run ()
135+ return result
154136
155- # write the final result to the database and console
156- if got_root :
157- self .log .run_was_success ()
158- else :
159- self .log .run_was_failure ("maximum turn number reached" )
137+ @abc .abstractmethod
138+ def check_success (self , cmd :str , result :str ) -> bool :
139+ return False
160140
161- return got_root
162- except Exception :
163- import traceback
164- self .log .run_was_failure ("exception occurred" , details = f":\n \n { traceback .format_exc ()} " )
165- raise
141+ def postprocess_commands (self , cmd :str ) -> List [str ]:
142+ return [cmd ]
0 commit comments