Skip to content

Commit 569a9f4

Browse files
authored
Merge pull request #62 from ipa-lab/agent_with_worldview
Agent with worldview
2 parents a92428a + 5ded9d4 commit 569a9f4

File tree

4 files changed

+115
-3
lines changed

4 files changed

+115
-3
lines changed

‎usecases/agents.py‎

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
22
from dataclasses import dataclass, field
33
from typing import Dict
44

55
from capabilities.capability import Capability, capabilities_to_simple_text_handler
66
from usecases.common_patterns import RoundBasedUseCase
77

8+
from mako.template import Template
9+
from rich.panel import Panel
10+
from utils import llm_util
811

912
@dataclass
1013
class Agent(RoundBasedUseCase, ABC):
@@ -25,3 +28,61 @@ def get_capability(self, name: str) -> Capability:
2528
def get_capability_block(self) -> str:
2629
capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities)
2730
return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values())
31+
32+
@dataclass
33+
class AgentWorldview(ABC):
34+
35+
@abstractmethod
36+
def to_template(self):
37+
pass
38+
39+
@abstractmethod
40+
def update(self, capability, cmd, result):
41+
pass
42+
43+
class TemplatedAgent(Agent):
44+
45+
_state: AgentWorldview = None
46+
_template: Template = None
47+
_template_size: int = 0
48+
49+
def init(self):
50+
super().init()
51+
52+
def set_initial_state(self, initial_state):
53+
print("setting state!")
54+
self._state = initial_state
55+
56+
def set_template(self, template):
57+
self._template = Template(filename=template)
58+
self._template_size = self.llm.count_tokens(self._template.source)
59+
60+
def perform_round(self, turn):
61+
got_root : bool = False
62+
63+
with self.console.status("[bold green]Asking LLM for a new command..."):
64+
# TODO output/log state
65+
options = self._state.to_template()
66+
options.update({
67+
'capabilities': self.get_capability_block()
68+
})
69+
70+
print(str(options))
71+
72+
# get the next command from the LLM
73+
answer = self.llm.get_response(self._template, **options)
74+
cmd = llm_util.cmd_output_fixer(answer.result)
75+
76+
with self.console.status("[bold green]Executing that command..."):
77+
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
78+
capability = self.get_capability(cmd.split(" ", 1)[0])
79+
result, got_root = capability(cmd)
80+
81+
# log and output the command and its result
82+
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
83+
self._state.update(capability, cmd, result)
84+
# TODO output/log new state
85+
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
86+
87+
# if we got root, we can stop the loop
88+
return got_root

‎usecases/minimal/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .minimal import MinimalLinuxPrivesc
1+
from .agent import MinimalLinuxPrivesc
2+
from .agent_with_state import MinimalLinuxTemplatedPrivesc

‎usecases/minimal/minimal.py‎ renamed to ‎usecases/minimal/agent.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def perform_round(self, turn):
4949
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
5050

5151
# if we got root, we can stop the loop
52-
return got_root
52+
return got_root
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
2+
import pathlib
3+
from dataclasses import dataclass
4+
5+
from capabilities import SSHRunCommand, SSHTestCredential
6+
from utils import SSHConnection, llm_util
7+
from usecases.base import use_case
8+
from usecases.agents import TemplatedAgent, AgentWorldview
9+
from utils.cli_history import SlidingCliHistory
10+
11+
@dataclass
12+
class MinimalLinuxTemplatedPrivescState(AgentWorldview):
13+
sliding_history: SlidingCliHistory = None
14+
max_history_size: int = 0
15+
16+
conn: SSHConnection = None
17+
18+
def __init__(self, conn, llm, max_history_size):
19+
self.sliding_history = SlidingCliHistory(llm)
20+
self.max_history_size = max_history_size
21+
self.conn = conn
22+
23+
def update(self, capability, cmd, result):
24+
self.sliding_history.add_command(cmd, result)
25+
26+
def to_template(self):
27+
return {
28+
'history': self.sliding_history.get_history(self.max_history_size),
29+
'conn': self.conn
30+
}
31+
32+
@use_case("minimal_linux_templated_agent", "Showcase Minimal Linux Priv-Escalation")
33+
@dataclass
34+
class MinimalLinuxTemplatedPrivesc(TemplatedAgent):
35+
36+
conn: SSHConnection = None
37+
38+
def init(self):
39+
super().init()
40+
41+
# setup default template
42+
self.set_template(str(pathlib.Path(__file__).parent / "next_cmd.txt"))
43+
44+
# setup capabilities
45+
self.add_capability(SSHRunCommand(conn=self.conn), default=True)
46+
self.add_capability(SSHTestCredential(conn=self.conn))
47+
48+
# setup state
49+
max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size
50+
self.set_initial_state(MinimalLinuxTemplatedPrivescState(self.conn, self.llm, max_history_size))

0 commit comments

Comments
 (0)