Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion usecases/agents.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict

from capabilities.capability import Capability, capabilities_to_simple_text_handler
from usecases.common_patterns import RoundBasedUseCase

from mako.template import Template
from rich.panel import Panel
from utils import llm_util

@dataclass
class Agent(RoundBasedUseCase, ABC):
Expand All @@ -25,3 +28,61 @@ def get_capability(self, name: str) -> Capability:
def get_capability_block(self) -> str:
capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities)
return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values())

@dataclass
class AgentWorldview(ABC):

@abstractmethod
def to_template(self):
pass

@abstractmethod
def update(self, capability, cmd, result):
pass

class TemplatedAgent(Agent):

_state: AgentWorldview = None
_template: Template = None
_template_size: int = 0

def init(self):
super().init()

def set_initial_state(self, initial_state):
print("setting state!")
self._state = initial_state

def set_template(self, template):
self._template = Template(filename=template)
self._template_size = self.llm.count_tokens(self._template.source)

def perform_round(self, turn):
got_root : bool = False

with self.console.status("[bold green]Asking LLM for a new command..."):
# TODO output/log state
options = self._state.to_template()
options.update({
'capabilities': self.get_capability_block()
})

print(str(options))

# get the next command from the LLM
answer = self.llm.get_response(self._template, **options)
cmd = llm_util.cmd_output_fixer(answer.result)

with self.console.status("[bold green]Executing that command..."):
self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:"))
capability = self.get_capability(cmd.split(" ", 1)[0])
result, got_root = capability(cmd)

# log and output the command and its result
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer)
self._state.update(capability, cmd, result)
# TODO output/log new state
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))

# if we got root, we can stop the loop
return got_root
3 changes: 2 additions & 1 deletion usecases/minimal/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .minimal import MinimalLinuxPrivesc
from .agent import MinimalLinuxPrivesc
from .agent_with_state import MinimalLinuxTemplatedPrivesc
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def perform_round(self, turn):
self.console.print(Panel(result, title=f"[bold cyan]{cmd}"))

# if we got root, we can stop the loop
return got_root
return got_root
50 changes: 50 additions & 0 deletions usecases/minimal/agent_with_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

import pathlib
from dataclasses import dataclass

from capabilities import SSHRunCommand, SSHTestCredential
from utils import SSHConnection, llm_util
from usecases.base import use_case
from usecases.agents import TemplatedAgent, AgentWorldview
from utils.cli_history import SlidingCliHistory

@dataclass
class MinimalLinuxTemplatedPrivescState(AgentWorldview):
sliding_history: SlidingCliHistory = None
max_history_size: int = 0

conn: SSHConnection = None

def __init__(self, conn, llm, max_history_size):
self.sliding_history = SlidingCliHistory(llm)
self.max_history_size = max_history_size
self.conn = conn

def update(self, capability, cmd, result):
self.sliding_history.add_command(cmd, result)

def to_template(self):
return {
'history': self.sliding_history.get_history(self.max_history_size),
'conn': self.conn
}

@use_case("minimal_linux_templated_agent", "Showcase Minimal Linux Priv-Escalation")
@dataclass
class MinimalLinuxTemplatedPrivesc(TemplatedAgent):

conn: SSHConnection = None

def init(self):
super().init()

# setup default template
self.set_template(str(pathlib.Path(__file__).parent / "next_cmd.txt"))

# setup capabilities
self.add_capability(SSHRunCommand(conn=self.conn), default=True)
self.add_capability(SSHTestCredential(conn=self.conn))

# setup state
max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size
self.set_initial_state(MinimalLinuxTemplatedPrivescState(self.conn, self.llm, max_history_size))