11import abc
22import datetime
3- import re
43
54from dataclasses import dataclass
65from mako .template import Template
1110from hackingBuddyGPT .utils .openai .openai_llm import OpenAIConnection
1211from hackingBuddyGPT .utils .logging import log_conversation , Logger , log_param , log_section
1312from hackingBuddyGPT .utils .capability_manager import CapabilityManager
14- from hackingBuddyGPT .utils .shell_root_detection import got_root
15- from typing import List , Optional
13+ from typing import List
1614
1715
1816@dataclass
@@ -37,10 +35,7 @@ class CommandStrategy(UseCase, abc.ABC):
def before_run(self):
    """Hook invoked by ``run`` once, before the first round starts.

    The base implementation does nothing and returns ``None``.
    """
    return None
3937
40- def after_run (self ):
41- pass
42-
43- def after_round (self , cmd , result , got_root ):
def after_command_execution(self, cmd, result, got_root):
    """Hook invoked by ``perform_round`` after each executed command.

    Receives the command, its result and the verdict of ``check_success``
    (passed in as ``got_root``). The base implementation does nothing and
    returns ``None``.
    """
    return None
4540
4641 def get_token_overhead (self ) -> int :
@@ -51,13 +46,63 @@ def init(self):
5146
5247 self ._capabilities = CapabilityManager (self .log )
5348
49+ # TODO: make this more beautiful by just configuring a History-Instance
5450 if self .disable_history :
5551 self ._history = HistoryNone ()
5652 else :
5753 if self .enable_compressed_history :
5854 self ._history = HistoryCmdOnly ()
5955 else :
6056 self ._history = HistoryFull ()
57+
@log_conversation("Starting run...")
def run(self, configuration):
    """Drive one complete run: set up, execute rounds, record the outcome.

    Stores ``configuration`` on the instance, opens the run in the log,
    fills the ``capabilities`` template parameter and calls ``before_run``.
    Rounds are then executed through ``perform_round`` until one reports
    success or ``max_turns`` rounds have been started.

    Returns the value reported by the last executed round (``False`` when
    no round ran). An exception raised while executing rounds is recorded
    as a failed run and then re-raised unchanged.
    """
    self.configuration = configuration
    self.log.start_run(self.get_name(), self.serialize_configuration(configuration))

    self._template_params["capabilities"] = self._capabilities.get_capability_block()

    self.before_run()

    succeeded = False
    turn = 1
    try:
        # stop on the first successful round or once the turn budget is used up
        while not succeeded and turn <= self.max_turns:
            with self.log.section(f"round { turn } "):
                self.log.console.log(f"[yellow]Starting turn { turn } of { self .max_turns } ")
                succeeded = self.perform_round(turn)
                turn += 1
    except Exception:
        # imported lazily: only needed on the failure path
        import traceback
        self.log.run_was_failure("exception occurred", details=f":\n \n { traceback .format_exc ()} ")
        raise

    # write the final result to the database and console
    if not succeeded:
        self.log.run_was_failure("maximum turn number reached")
        return succeeded

    self.log.run_was_success()
    return succeeded
87+
@log_conversation("Asking LLM for a new command(s)...")
def perform_round(self, turn: int) -> bool:
    """Execute one round: fetch the next answer, run its commands, check each result.

    ``turn`` is accepted but not read by this method.

    Returns ``True`` as soon as ``check_success`` accepts the result of a
    command (remaining commands of the round are then not executed), and
    ``False`` when no command of this round was accepted.
    """
    # get the next command and run it
    answer, message_id = self.get_next_command()

    for command in self.postprocess_commands(answer):
        output = self.run_command(command, message_id)
        # store the results in our local history
        self._history.append(command, output)

        solved = self.check_success(command, output)
        self.after_command_execution(command, output, solved)
        if solved:
            return True

    # signal that none of this round's commands solved the task
    return False
61106
62107 @log_section ("Asking LLM for a new command..." )
63108 def get_next_command (self ) -> tuple [str , int ]:
@@ -74,84 +119,24 @@ def get_next_command(self) -> tuple[str, int]:
74119 return cmd .result , message_id
75120
@log_section("Executing that command...")
def run_command(self, cmd, message_id) -> str:
    """Hand ``cmd`` to the capability text handler and log the outcome as a tool call.

    The handler returns a success flag followed by its payload. On failure
    the first payload element is logged (with an empty function name and a
    duration of ``0``) and returned. On success the single payload element
    is unpacked into capability name, command and result; the call is
    logged with the elapsed time and the result is returned.
    """
    _, handler = capabilities_to_simple_text_handler(
        self._capabilities._capabilities,
        default_capability=self._capabilities._default_capability,
    )
    started_at = datetime.datetime.now()
    ok, *payload = handler(cmd)

    if not ok:
        failure_text = payload[0]
        self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=cmd, result_text=failure_text, duration=0)
        return failure_text

    assert len(payload) == 1
    capability, executed_cmd, result = payload[0]
    elapsed = datetime.datetime.now() - started_at
    # log the command as reported back by the handler, not the raw input
    self.log.add_tool_call(message_id, tool_call_id=0, function_name=capability, arguments=executed_cmd, result_text=result, duration=elapsed)

    return result
136+
@abc.abstractmethod
def check_success(self, cmd: str, result: str) -> bool:
    """Decide whether ``result``, produced by running ``cmd``, means the task is solved.

    Declared abstract; this base body always answers ``False``.
    """
    return False
97140
def postprocess_commands(self, cmd: str) -> List[str]:
    """Expand ``cmd`` into the list of commands to execute.

    This implementation returns a one-element list holding ``cmd`` unchanged.
    """
    commands = [cmd]
    return commands
100-
101- @log_conversation ("Asking LLM for a new command..." )
102- def perform_round (self , turn : int ) -> bool :
103- # get the next command and run it
104- cmd , message_id = self .get_next_command ()
105-
106- cmds = self .postprocess_commands (cmd )
107- for cmd in cmds :
108- result , task_successful = self .run_command (cmd , message_id )
109- # store the results in our local history
110- self ._history .append (cmd , result )
111-
112- # maybe move the 'got root' detection here?
113- # TODO: also can I use llm-as-judge for that? or do I have to do this
114- # on a per-action base (maybe add a .task_successful(cmd, result, options) -> boolean to the action?
115- task_successful2 = self .check_success (cmd , result )
116- assert (task_successful == task_successful2 )
117-
118- self .after_round (cmd , result , task_successful )
119-
120- # signal if we were successful in our task
121- return task_successful
122-
123- @log_conversation ("Starting run..." )
124- def run (self , configuration ):
125-
126- self .configuration = configuration
127- self .log .start_run (self .get_name (), self .serialize_configuration (configuration ))
128-
129- self ._template_params ["capabilities" ] = self ._capabilities .get_capability_block ()
130-
131- self .before_run ()
132-
133- got_root = False
134-
135- turn = 1
136- try :
137- while turn <= self .max_turns and not got_root :
138- with self .log .section (f"round { turn } " ):
139- self .log .console .log (f"[yellow]Starting turn { turn } of { self .max_turns } " )
140-
141- got_root = self .perform_round (turn )
142-
143- turn += 1
144-
145- self .after_run ()
146-
147- # write the final result to the database and console
148- if got_root :
149- self .log .run_was_success ()
150- else :
151- self .log .run_was_failure ("maximum turn number reached" )
152-
153- return got_root
154- except Exception :
155- import traceback
156- self .log .run_was_failure ("exception occurred" , details = f":\n \n { traceback .format_exc ()} " )
157- raise
0 commit comments