Merged

V4 #3

28 commits
5771d46 try to improve token counting (andreashappe, Sep 16, 2023)
4c0eb6d add timestamps to runs (andreashappe, Sep 16, 2023)
4ca10b6 allow for different hostnames during root detection (andreashappe, Sep 16, 2023)
0423702 fix: history in next-cmd, SSH root detection; add: logging (andreashappe, Sep 18, 2023)
b8ed6bf do not reuse host SSH keys (andreashappe, Sep 18, 2023)
9d220fc remove some newlines (andreashappe, Sep 18, 2023)
9ad8bad add a hint for each virtual machine (andreashappe, Sep 18, 2023)
55b42db increase SSH timeout to allow for docker operations (andreashappe, Sep 18, 2023)
4830017 split up analyze_response into response/state (andreashappe, Sep 18, 2023)
bacf3df make code a bit more readable (andreashappe, Sep 19, 2023)
772e05e add hints for two new test VMs (andreashappe, Sep 19, 2023)
1414390 fix: status code checking for openai connection (andreashappe, Sep 19, 2023)
6fefda2 fix: actually perform back-off in case of rate-limiting (andreashappe, Sep 19, 2023)
4fdee6e colorize important stuff on console output (andreashappe, Sep 19, 2023)
18a1fb1 switch from JSON to text-based prompt format (andreashappe, Sep 19, 2023)
8b2f665 chg: make root detection more resistant with a regexp (andreashappe, Sep 20, 2023)
cc92546 try to remove more weird wrapping from LLM results (andreashappe, Sep 20, 2023)
f67b903 output the command before it is executed (andreashappe, Sep 20, 2023)
11a1d2b fix: array index for hints (andreashappe, Sep 20, 2023)
3269080 make openai connection more configurable (andreashappe, Sep 20, 2023)
e5c773f fix whitespace (andreashappe, Sep 20, 2023)
3c995b4 remove unused code (andreashappe, Sep 20, 2023)
3cdf85a del: remove openai lib based interface, we're using the REST interface (andreashappe, Sep 20, 2023)
d275421 make LLM server url configurable to allow for running local LLMs (andreashappe, Sep 20, 2023)
ff957be oobabooga can use existing llm server config too (andreashappe, Sep 20, 2023)
ab735a4 try to allow for non-openAI tokenizers (andreashappe, Sep 20, 2023)
d564e4f use openai_rest as default connection (andreashappe, Sep 20, 2023)
af2c8fe wrap llama2 prompts to get better results (andreashappe, Sep 20, 2023)
27 changes: 15 additions & 12 deletions db_storage.py
@@ -22,31 +22,31 @@ def insert_or_select_cmd(self, name:str) -> int:

     def setup_db(self):
         # create tables
-        self.cursor.execute("CREATE TABLE IF NOT EXISTS runs (id INTEGER PRIMARY KEY, model text, context_size INTEGER, state TEXT, tag TEXT)")
+        self.cursor.execute("CREATE TABLE IF NOT EXISTS runs (id INTEGER PRIMARY KEY, model text, context_size INTEGER, state TEXT, tag TEXT, started_at text, stopped_at text, rounds INTEGER)")
         self.cursor.execute("CREATE TABLE IF NOT EXISTS commands (id INTEGER PRIMARY KEY, name string unique)")
-        self.cursor.execute("CREATE TABLE IF NOT EXISTS queries (run_id INTEGER, round INTEGER, cmd_id INTEGER, query TEXT, response TEXT, duration REAL, tokens_query INTEGER, tokens_response INTEGER)")
+        self.cursor.execute("CREATE TABLE IF NOT EXISTS queries (run_id INTEGER, round INTEGER, cmd_id INTEGER, query TEXT, response TEXT, duration REAL, tokens_query INTEGER, tokens_response INTEGER, prompt TEXT, answer TEXT)")

         # insert commands
         self.query_cmd_id = self.insert_or_select_cmd('query_cmd')
         self.analyze_response_id = self.insert_or_select_cmd('analyze_response')
         self.state_update_id = self.insert_or_select_cmd('update_state')

     def create_new_run(self, model, context_size, tag=''):
-        self.cursor.execute("INSERT INTO runs (model, context_size, state, tag) VALUES (?, ?, ?, ?)", (model, context_size, "in progress", tag))
+        self.cursor.execute("INSERT INTO runs (model, context_size, state, tag, started_at) VALUES (?, ?, ?, ?, datetime('now'))", (model, context_size, "in progress", tag))
         return self.cursor.lastrowid

     def add_log_query(self, run_id, round, cmd, result, answer):
-        self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.query_cmd_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
+        self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.query_cmd_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))

     def add_log_analyze_response(self, run_id, round, cmd, result, answer):
-        self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.analyze_response_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
+        self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.analyze_response_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))

     def add_log_update_state(self, run_id, round, cmd, result, answer):

         if answer != None:
-            self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
+            self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))
         else:
-            self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, 0, 0, 0))
+            self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, 0, 0, 0, '', ''))

     def get_round_data(self, run_id, round):
         rows = self.cursor.execute("select cmd_id, query, response, duration, tokens_query, tokens_response from queries where run_id = ? and round = ?", (run_id, round)).fetchall()

@@ -61,8 +61,11 @@ def get_round_data(self, run_id, round):
             reason = row[2]
             analyze_time = f"{row[3]:.4f}"
             analyze_token = f"{row[4]}/{row[5]}"
+        if row[0] == self.state_update_id:
+            state_time = f"{row[3]:.4f}"
+            state_token = f"{row[4]}/{row[5]}"

-        result = [duration, tokens, cmd, size_resp, analyze_time, analyze_token, reason]
+        result = [duration, tokens, cmd, size_resp, analyze_time, analyze_token, reason, state_time, state_token]
         return result

     def get_cmd_history(self, run_id):

@@ -75,12 +78,12 @@ def get_cmd_history(self, run_id):

         return result

-    def run_was_success(self, run_id):
-        self.cursor.execute("update runs set state=? where id = ?", ("got root", run_id))
+    def run_was_success(self, run_id, round):
+        self.cursor.execute("update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?", ("got root", round, run_id))
         self.db.commit()

-    def run_was_failure(self, run_id):
-        self.cursor.execute("update runs set state=? where id = ?", ("reached max runs", run_id))
+    def run_was_failure(self, run_id, round):
+        self.cursor.execute("update runs set state=?, stopped_at=datetime('now'), rounds=? where id = ?", ("reached max runs", round, run_id))
         self.db.commit()

     def commit(self):
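
Note: with started_at, stopped_at, and rounds on the runs table, run duration becomes directly queryable from SQLite. A minimal sketch of such a query (the database file name is hypothetical; use whatever path DbStorage is constructed with):

import sqlite3

# hypothetical file name; DbStorage's actual path is configured elsewhere
db = sqlite3.connect("runs.db")
rows = db.execute(
    "SELECT id, state, rounds, "
    "round((julianday(stopped_at) - julianday(started_at)) * 1440, 1) "
    "FROM runs WHERE stopped_at IS NOT NULL")
for run_id, state, rounds, minutes in rows:
    print(f"run {run_id}: {state} after {rounds} rounds in {minutes} minutes")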
15 changes: 7 additions & 8 deletions handlers.py
@@ -7,21 +7,20 @@ def handle_cmd(conn, input):
     return input["cmd"], result, gotRoot


-def handle_ssh(target_host, input):
+def handle_ssh(target_host, target_hostname, input):
     user = input["username"]
     password = input["password"]

-    cmd = "tried ssh with username " + user + " and password " + password
+    cmd = f"test_credentials {user} {password}\n"

-    test = SSHHostConn(target_host, user, password)
+    test = SSHHostConn(target_host, target_hostname, user, password)
     try:
         test.connect()
-        user = test.run("whoami")
-
+        user = test.run("whoami")[0].strip('\n\r ')
         if user == "root":
-            return cmd, "Login as root was successful"
+            return cmd, "Login as root was successful\n", True
         else:
-            return cmd, "Authentication successful, but user is not root"
+            return cmd, "Authentication successful, but user is not root\n", False

     except paramiko.ssh_exception.AuthenticationException:
-        return cmd, "Authentication error, credentials are wrong"
+        return cmd, "Authentication error, credentials are wrong\n", False
56 changes: 53 additions & 3 deletions helper.py
@@ -5,7 +5,13 @@

 def num_tokens_from_string(model: str, string: str) -> int:
     """Returns the number of tokens in a text string."""
-    encoding = tiktoken.encoding_for_model(model)
+
+    # I know this is crappy for all non-openAI models but sadly this
+    # has to be good enough for now
+    if model.startswith("gpt-"):
+        encoding = tiktoken.encoding_for_model(model)
+    else:
+        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
     return len(encoding.encode(string))
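
Note: the fallback means token counts for non-OpenAI models are only approximated with the gpt-3.5-turbo encoding (cl100k_base). A standalone way to sanity-check that behaviour, assuming tiktoken is installed:

import tiktoken

def count_tokens(model: str, text: str) -> int:
    # mirrors the fallback above: unknown models borrow gpt-3.5-turbo's encoding
    name = model if model.startswith("gpt-") else "gpt-3.5-turbo"
    return len(tiktoken.encoding_for_model(name).encode(text))

print(count_tokens("gpt-4", "uname -a"))         # exact count
print(count_tokens("llama-2-13b", "uname -a"))   # rough estimate only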

def get_history_table(run_id: int, db: DbStorage, round: int) -> Table:
Expand All @@ -14,15 +20,18 @@ def get_history_table(run_id: int, db: DbStorage, round: int) -> Table:
table.add_column("Tokens", style="dim")
table.add_column("Cmd")
table.add_column("Resp. Size", justify="right")
table.add_column("ThinkTime", style="dim")
table.add_column("ThinkingTime", style="dim")
table.add_column("Tokens", style="dim")
table.add_column("Reason")
table.add_column("StateTime", style="dim")
table.add_column("StateTokens", style="dim")

for i in range(0, round+1):
table.add_row(*db.get_round_data(run_id, i))

return table

# return a list with cmd/result pairs, trimmed to context_size
def get_cmd_history(model: str, run_id: int, db: DbStorage, limit: int) -> list[str]:
result = []
rest = limit
@@ -46,4 +55,45 @@ def get_cmd_history(model: str, run_id: int, db: DbStorage, limit: int) -> list[str]:
             "result" : itm[1][:(rest-size_cmd-2)] + ".."
         })
-    return list(reversed(result))
-    return list(reversed(result))
+    return list(reversed(result))
+
+STEP_CUT_TOKENS : int = 32
+SAFETY_MARGIN : int = 128
+
+# create the command history. Initially create the full command history, then
+# try to trim it down
+def get_cmd_history_v3(model: str, ctx_size: int, run_id: int, db: DbStorage, token_overhead: int) -> str:
+    result: str = ""
+
+    # get commands from db
+    cmds = db.get_cmd_history(run_id)
+
+    # create the full history
+    for itm in cmds:
+        result = result + '$ ' + itm[0] + "\n" + itm[1]
+
+    # trim it down if too large
+    cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN
+
+    while cur_size > ctx_size:
+        diff = cur_size - ctx_size
+        step = int((diff + STEP_CUT_TOKENS)/2)
+        result = result[:-step]
+        cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN
+
+    return result
+
+def remove_wrapping_characters(cmd, wrappers):
+    if cmd[0] == cmd[-1] and cmd[0] in wrappers:
+        print("will remove a wrapper from: " + cmd)
+        return remove_wrapping_characters(cmd[1:-1], wrappers)
+    return cmd
+
+# often the LLM produces a wrapped command
+def cmd_output_fixer(cmd):
+    cmd = remove_wrapping_characters(cmd, "`'\"")
+
+    if cmd.startswith("$ "):
+        cmd = cmd[2:]
+
+    return cmd
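
Note: get_cmd_history_v3 measures in tokens but cuts in characters, so each iteration removes a character slice sized from the token overshoot and re-counts until the history fits. A standalone illustration of the same pattern (assumes tiktoken; the limit and sample text are arbitrary):

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
history = "$ cat /etc/passwd\nroot:x:0:0:root:/root:/bin/bash\n" * 200
ctx_size = 512

while len(enc.encode(history)) > ctx_size:
    overshoot = len(enc.encode(history)) - ctx_size
    history = history[:-int((overshoot + 32) / 2)]  # cut characters, then re-count tokens

And cmd_output_fixer, given the helpers above, behaves roughly like this (remove_wrapping_characters also prints a debug line per stripped wrapper):

print(cmd_output_fixer("`ls -la`"))           # -> ls -la
print(cmd_output_fixer("'\"id\"'"))           # -> id (wrappers are stripped recursively)
print(cmd_output_fixer("$ cat /etc/shadow"))  # -> cat /etc/shadow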
93 changes: 74 additions & 19 deletions llm_with_state.py
@@ -1,4 +1,3 @@
-import json
 import time
 import typing

@@ -9,50 +8,106 @@
 @dataclass
 class LLMResult:
     result: typing.Any
+    prompt: str
+    answer: str
     duration: float = 0
     tokens_query: int = 0
     tokens_response: int = 0

+def get_empty_result():
+    return LLMResult('', '', '', 0, 0, 0)
+
 class LLMWithState:
-    def __init__(self, run_id, llm_connection, history, initial_user, initial_password):
+    def __init__(self, run_id, llm_connection, history, initial_user, initial_password, hints=None):
         self.llm_connection = llm_connection
         self.initial_user = initial_user
         self.initial_password = initial_password
         self.db = history
         self.run_id = run_id
-        self.state = """
+        self.state = f"""
 - this is a linux system
-- your low privilege user credentials are lowpriv:trustno1
+- your low privilege user credentials are {initial_user}:{initial_password}
 """
+        self.hints = hints

-    def get_next_cmd(self):
-        state_size = num_tokens_from_string(self.llm_connection.get_model(), self.state)
+    def get_next_cmd(self, hostname=''):
+
+        template_file = 'query_next_command.txt'
+        model = self.llm_connection.get_model()
+
+        state_size = num_tokens_from_string(model, self.state)
+
+        template = Template(filename='templates/' + template_file)
+        template_size = num_tokens_from_string(model, template.source)

-        return self.create_and_ask_prompt('query_next_command.txt', user=self.initial_user, password=self.initial_password, history=get_cmd_history(self.llm_connection.get_model(), self.run_id, self.db, self.llm_connection.get_context_size()-state_size), state=self.state)
+        history = get_cmd_history_v3(model, self.llm_connection.get_context_size(), self.run_id, self.db, state_size+template_size)
+
+        if self.hints != None:
+            hint = self.hints[hostname]
+        else:
+            hint =''
+        result = self.create_and_ask_prompt_text(template_file, user=self.initial_user, password=self.initial_password, history=history, state=self.state, hint=hint)
+
+        # make result backwards compatible
+        if result.result.startswith("test_credentials"):
+            result.result = {
+                "type" : "ssh",
+                "username" : result.result.split(" ")[1],
+                "password" : result.result.split(" ")[2]
+            }
+        else:
+            result.result = {
+                "type" : "cmd",
+                "cmd" : cmd_output_fixer(result.result)
+            }
+
+        return result
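
Note: after the backwards-compatibility shim, callers of get_next_cmd always receive result.result as a dict — {"type": "ssh", "username": ..., "password": ...} when the model answered with test_credentials, and {"type": "cmd", "cmd": ...} otherwise. A standalone re-implementation of that branch for illustration (the helper name is hypothetical; the real code additionally runs cmd_output_fixer on plain commands):

def parse_llm_answer(text: str) -> dict:
    # mirrors the test_credentials / plain-command split in get_next_cmd
    if text.startswith("test_credentials"):
        parts = text.split(" ")
        return {"type": "ssh", "username": parts[1], "password": parts[2]}
    return {"type": "cmd", "cmd": text}

assert parse_llm_answer("test_credentials root toor") == {"type": "ssh", "username": "root", "password": "toor"}
assert parse_llm_answer("sudo -l")["type"] == "cmd"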

     def analyze_result(self, cmd, result):
-        result = self.create_and_ask_prompt('successfull.txt', cmd=cmd, resp=result, facts=self.state)
-
-        self.tmp_state = result.result["facts"]
+        model = self.llm_connection.get_model()
+        ctx = self.llm_connection.get_context_size()
+
+        # ugly, but cut down result to fit context size
+        # don't do this linearly as this can take too long
+        CUTOFF_STEP = 128
+        current_size = num_tokens_from_string(model, result)
+        while current_size > (ctx + 512):
+            cut_off = int(((current_size - (ctx + 512)) + CUTOFF_STEP)/2)
+            result = result[cut_off:]
+            current_size = num_tokens_from_string(model, result)
+
+        result = self.create_and_ask_prompt_text('analyze_cmd.txt', cmd=cmd, resp=result, facts=self.state)
         return result

-    def update_state(self):
-        self.state = "\n".join(map(lambda x: "- " + x, self.tmp_state))
-        return LLMResult(self.state, 0, 0, 0)
+    def update_state(self, cmd, result):
+        result = self.create_and_ask_prompt_text('update_state.txt', cmd=cmd, resp=result, facts=self.state)
+        self.state = result.result
+        return result

     def get_current_state(self):
         return self.state

-    def create_and_ask_prompt(self, template_file, **params):
+    def create_and_ask_prompt_text(self, template_file, **params):
         # prepare the prompt
         template = Template(filename='templates/' + template_file)
         prompt = template.render(**params)

+        if not self.llm_connection.get_model().startswith("gpt-"):
+            prompt = wrap_it_for_llama(prompt)
+
         # and execute it
         tic = time.perf_counter()
         result, tok_query, tok_res = self.llm_connection.exec_query(self.llm_connection.get_model(), self.llm_connection.get_context_size(), prompt)
         toc = time.perf_counter()
-        try:
-            json_answer = json.loads(result)
-        except Exception as e:
-            print("there as an exception with JSON parsing: " + str(e))
-            print("debug[the plain result]: " + str(result))
-
-        return LLMResult(json_answer, toc - tic, tok_query, tok_res)
+        return LLMResult(result, prompt, result, toc - tic, tok_query, tok_res)

+def wrap_it_for_llama(prompt):
+    return f"""### System:
+you are a concise but helful learning tool that aids students trying to find security vulnerabilities
+
+### User:
+{prompt}
+
+### Assistant:"""
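
Note: for non-GPT models every rendered template is wrapped in this System/User/Assistant instruction scaffold before it is sent. For reference, given the function above:

print(wrap_it_for_llama("which linux command should I try next?"))
# ### System:
# you are a concise but helful learning tool that aids students ...
#
# ### User:
# which linux command should I try next?
#
# ### Assistant: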
16 changes: 8 additions & 8 deletions llms/manager.py → llms/llm_connection.py
@@ -10,22 +10,22 @@

 def get_potential_llm_connections():
     return list(map(lambda x: x[0], connections))

-def get_llm_connection(name, model, context_size):
+def get_llm_connection(config):
     for i in connections:
-        if i[0] == name:
-            if i[1]() == True:
-                return LLMConnection(name, model, context_size, i[2])
+        if i[0] == config.llm_connection:
+            if i[1](config) == True:
+                return LLMConnection(config, i[2])
             else:
                 print("Parameter for connection missing")
                 return None
     print("Configured connection not found")
     return None

 class LLMConnection:
-    def __init__(self, conn, model, context_size, exec_query):
-        self.conn = conn
-        self.model = model
-        self.context_size = context_size
+    def __init__(self, config, exec_query):
+        self.conn = config.llm_connection
+        self.model = config.model
+        self.context_size = config.context_size
         self.exec_query = exec_query

     def exec_query(self, query):
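
Note: get_llm_connection now takes a single config object; anything exposing llm_connection, model, and context_size (plus whatever the selected backend's verifier reads, such as llm_server_base_url) should satisfy it. A sketch with a hypothetical stand-in for the project's real config object:

from dataclasses import dataclass

@dataclass
class FakeConfig:
    # field names mirror the attributes read in this file; the defaults are assumptions
    llm_connection: str = "openai_rest"
    model: str = "gpt-3.5-turbo"
    context_size: int = 4096
    llm_server_base_url: str = "https://api.openai.com"

conn = get_llm_connection(FakeConfig())
if conn is None:
    raise SystemExit("no usable LLM connection configured")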
9 changes: 2 additions & 7 deletions llms/oobabooga.py
@@ -1,6 +1,5 @@
 import html
 import json
-import os

 import requests

@@ -101,11 +100,7 @@ def get_openai_response(cmd):

     return html.unescape(result['visible'][-1][1])

-def verify_config():
+def verify_config(config):
     global url

-    url = os.getenv('OOBABOOGA_URL')
-
-    if url == '':
-        raise Exception("please set OOBABOOGA_URL through environmental variables")
-    return True
+    url = f"{config.llm_server_base_url}/api/v1/chat"
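
Note: verify_config now derives the chat endpoint from the shared llm_server_base_url setting instead of a dedicated OOBABOOGA_URL environment variable, so a local oobabooga instance is configured like any other backend. A minimal sketch (the port is an assumption; oobabooga's blocking API commonly listens on 5000):

from types import SimpleNamespace

# hypothetical config pointing the shared base URL at a local oobabooga server
config = SimpleNamespace(llm_server_base_url="http://localhost:5000")
verify_config(config)  # module-level url becomes http://localhost:5000/api/v1/chat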
11 changes: 0 additions & 11 deletions llms/openai.py

This file was deleted.
