Skip to content

Commit 394ebf2

Browse files
authored
Merge pull request ipa-lab#3 from ipa-lab/v4
V4
2 parents fc557af + af2c8fe commit 394ebf2

File tree

13 files changed

+311
-138
lines changed

13 files changed

+311
-138
lines changed

‎db_storage.py‎

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,31 @@ def insert_or_select_cmd(self, name:str) -> int:
2222

2323
def setup_db(self):
2424
# create tables
25-
self.cursor.execute("CREATE TABLE IF NOT EXISTS runs (id INTEGER PRIMARY KEY, model text, context_size INTEGER, state TEXT, tag TEXT)")
25+
self.cursor.execute("CREATE TABLE IF NOT EXISTS runs (id INTEGER PRIMARY KEY, model text, context_size INTEGER, state TEXT, tag TEXT, started_at text, stopped_at text, rounds INTEGER)")
2626
self.cursor.execute("CREATE TABLE IF NOT EXISTS commands (id INTEGER PRIMARY KEY, name string unique)")
27-
self.cursor.execute("CREATE TABLE IF NOT EXISTS queries (run_id INTEGER, round INTEGER, cmd_id INTEGER, query TEXT, response TEXT, duration REAL, tokens_query INTEGER, tokens_response INTEGER)")
27+
self.cursor.execute("CREATE TABLE IF NOT EXISTS queries (run_id INTEGER, round INTEGER, cmd_id INTEGER, query TEXT, response TEXT, duration REAL, tokens_query INTEGER, tokens_response INTEGER, prompt TEXT, answer TEXT)")
2828

2929
# insert commands
3030
self.query_cmd_id = self.insert_or_select_cmd('query_cmd')
3131
self.analyze_response_id = self.insert_or_select_cmd('analyze_response')
3232
self.state_update_id = self.insert_or_select_cmd('update_state')
3333

3434
def create_new_run(self, model, context_size, tag=''):
35-
self.cursor.execute("INSERT INTO runs (model, context_size, state, tag) VALUES (?, ?, ?, ?)", (model, context_size, "in progress", tag))
35+
self.cursor.execute("INSERT INTO runs (model, context_size, state, tag, started_at) VALUES (?, ?, ?, ?, datetime('now'))", (model, context_size, "in progress", tag))
3636
return self.cursor.lastrowid
3737

3838
def add_log_query(self, run_id, round, cmd, result, answer):
39-
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.query_cmd_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
39+
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.query_cmd_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))
4040

4141
def add_log_analyze_response(self, run_id, round, cmd, result, answer):
42-
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.analyze_response_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
42+
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.analyze_response_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))
4343

4444
def add_log_update_state(self, run_id, round, cmd, result, answer):
4545

4646
if answer != None:
47-
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response))
47+
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, answer.duration, answer.tokens_query, answer.tokens_response, answer.prompt, answer.answer))
4848
else:
49-
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, 0, 0, 0))
49+
self.cursor.execute("INSERT INTO queries (run_id, round, cmd_id, query, response, duration, tokens_query, tokens_response, prompt, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", (run_id, round, self.state_update_id, cmd, result, 0, 0, 0, '', ''))
5050

5151
def get_round_data(self, run_id, round):
5252
rows = self.cursor.execute("select cmd_id, query, response, duration, tokens_query, tokens_response from queries where run_id = ? and round = ?", (run_id, round)).fetchall()
@@ -61,8 +61,11 @@ def get_round_data(self, run_id, round):
6161
reason = row[2]
6262
analyze_time = f"{row[3]:.4f}"
6363
analyze_token = f"{row[4]}/{row[5]}"
64+
if row[0] == self.state_update_id:
65+
state_time = f"{row[3]:.4f}"
66+
state_token = f"{row[4]}/{row[5]}"
6467

65-
result = [duration, tokens, cmd, size_resp, analyze_time, analyze_token, reason]
68+
result = [duration, tokens, cmd, size_resp, analyze_time, analyze_token, reason, state_time, state_token]
6669
return result
6770

6871
def get_cmd_history(self, run_id):
@@ -75,12 +78,12 @@ def get_cmd_history(self, run_id):
7578

7679
return result
7780

78-
def run_was_success(self, run_id):
79-
self.cursor.execute("update runs set state=? where id = ?", ("got root", run_id))
81+
def run_was_success(self, run_id, round):
82+
self.cursor.execute("update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?", ("got root", round, run_id))
8083
self.db.commit()
8184

82-
def run_was_failure(self, run_id):
83-
self.cursor.execute("update runs set state=? where id = ?", ("reached max runs", run_id))
85+
def run_was_failure(self, run_id, round):
86+
self.cursor.execute("update runs set state=?, stopped_at=datetime('now'), rounds=? where id = ?", ("reached max runs", round, run_id))
8487
self.db.commit()
8588

8689
def commit(self):

‎handlers.py‎

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,20 @@ def handle_cmd(conn, input):
77
return input["cmd"], result, gotRoot
88

99

10-
def handle_ssh(target_host, input):
10+
def handle_ssh(target_host, target_hostname, input):
1111
user = input["username"]
1212
password = input["password"]
1313

14-
cmd = "tried ssh with username " + user + " and password " + password
14+
cmd = f"test_credentials {user} {password}\n"
1515

16-
test = SSHHostConn(target_host, user, password)
16+
test = SSHHostConn(target_host, target_hostname, user, password)
1717
try:
1818
test.connect()
19-
user = test.run("whoami")
20-
19+
user = test.run("whoami")[0].strip('\n\r ')
2120
if user == "root":
22-
return cmd, "Login as root was successful"
21+
return cmd, "Login as root was successful\n", True
2322
else:
24-
return cmd, "Authentication successful, but user is not root"
23+
return cmd, "Authentication successful, but user is not root\n", False
2524

2625
except paramiko.ssh_exception.AuthenticationException:
27-
return cmd, "Authentication error, credentials are wrong"
26+
return cmd, "Authentication error, credentials are wrong\n", False

‎helper.py‎

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
def num_tokens_from_string(model: str, string: str) -> int:
77
"""Returns the number of tokens in a text string."""
8-
encoding = tiktoken.encoding_for_model(model)
8+
9+
# I know this is crappy for all non-openAI models but sadly this
10+
# has to be good enough for now
11+
if model.startswith("gpt-"):
12+
encoding = tiktoken.encoding_for_model(model)
13+
else:
14+
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
915
return len(encoding.encode(string))
1016

1117
def get_history_table(run_id: int, db: DbStorage, round: int) -> Table:
@@ -14,15 +20,18 @@ def get_history_table(run_id: int, db: DbStorage, round: int) -> Table:
1420
table.add_column("Tokens", style="dim")
1521
table.add_column("Cmd")
1622
table.add_column("Resp. Size", justify="right")
17-
table.add_column("ThinkTime", style="dim")
23+
table.add_column("ThinkingTime", style="dim")
1824
table.add_column("Tokens", style="dim")
1925
table.add_column("Reason")
26+
table.add_column("StateTime", style="dim")
27+
table.add_column("StateTokens", style="dim")
2028

2129
for i in range(0, round+1):
2230
table.add_row(*db.get_round_data(run_id, i))
2331

2432
return table
2533

34+
# return a list with cmd/result pairs, trimmed to context_size
2635
def get_cmd_history(model: str, run_id: int, db: DbStorage, limit: int) -> list[str]:
2736
result = []
2837
rest = limit
@@ -46,4 +55,45 @@ def get_cmd_history(model: str, run_id: int, db: DbStorage, limit: int) -> list[
4655
"result" : itm[1][:(rest-size_cmd-2)] + ".."
4756
})
4857
return list(reversed(result))
49-
return list(reversed(result))
58+
return list(reversed(result))
59+
60+
STEP_CUT_TOKENS : int = 32
61+
SAFETY_MARGIN : int = 128
62+
63+
# create the command history. Initially create the full command history, then
64+
# try to trim it down
65+
def get_cmd_history_v3(model: str, ctx_size: int, run_id: int, db: DbStorage, token_overhead: int) -> str:
66+
result: str = ""
67+
68+
# get commands from db
69+
cmds = db.get_cmd_history(run_id)
70+
71+
# create the full history
72+
for itm in cmds:
73+
result = result + '$ ' + itm[0] + "\n" + itm[1]
74+
75+
# trim it down if too large
76+
cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN
77+
78+
while cur_size > ctx_size:
79+
diff = cur_size - ctx_size
80+
step = int((diff + STEP_CUT_TOKENS)/2)
81+
result = result[:-step]
82+
cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN
83+
84+
return result
85+
86+
def remove_wrapping_characters(cmd, wrappers):
87+
if cmd[0] == cmd[-1] and cmd[0] in wrappers:
88+
print("will remove a wrapper from: " + cmd)
89+
return remove_wrapping_characters(cmd[1:-1], wrappers)
90+
return cmd
91+
92+
# often the LLM produces a wrapped command
93+
def cmd_output_fixer(cmd):
94+
cmd = remove_wrapping_characters(cmd, "`'\"")
95+
96+
if cmd.startswith("$ "):
97+
cmd = cmd[2:]
98+
99+
return cmd

‎llm_with_state.py‎

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import time
32
import typing
43

@@ -9,50 +8,106 @@
98
@dataclass
109
class LLMResult:
1110
result: typing.Any
11+
prompt: str
12+
answer: str
1213
duration: float = 0
1314
tokens_query: int = 0
1415
tokens_response: int = 0
1516

17+
def get_empty_result():
18+
return LLMResult('', '', '', 0, 0, 0)
19+
1620
class LLMWithState:
17-
def __init__(self, run_id, llm_connection, history, initial_user, initial_password):
21+
def __init__(self, run_id, llm_connection, history, initial_user, initial_password, hints=None):
1822
self.llm_connection = llm_connection
1923
self.initial_user = initial_user
2024
self.initial_password = initial_password
2125
self.db = history
2226
self.run_id = run_id
23-
self.state = """
27+
self.state = f"""
2428
- this is a linux system
25-
- your low privilege user credentials are lowpriv:trustno1
29+
- your low privilege user credentials are {initial_user}:{initial_password}
2630
"""
31+
self.hints = hints
32+
33+
def get_next_cmd(self, hostname=''):
34+
35+
template_file = 'query_next_command.txt'
36+
model = self.llm_connection.get_model()
2737

28-
def get_next_cmd(self):
29-
state_size = num_tokens_from_string(self.llm_connection.get_model(), self.state)
38+
state_size = num_tokens_from_string(model, self.state)
39+
40+
template = Template(filename='templates/' + template_file)
41+
template_size = num_tokens_from_string(model, template.source)
3042

31-
return self.create_and_ask_prompt('query_next_command.txt', user=self.initial_user, password=self.initial_password, history=get_cmd_history(self.llm_connection.get_model(), self.run_id, self.db, self.llm_connection.get_context_size()-state_size), state=self.state)
43+
history = get_cmd_history_v3(model, self.llm_connection.get_context_size(), self.run_id, self.db, state_size+template_size)
44+
45+
if self.hints != None:
46+
hint = self.hints[hostname]
47+
else:
48+
hint =''
49+
result = self.create_and_ask_prompt_text(template_file, user=self.initial_user, password=self.initial_password, history=history, state=self.state, hint=hint)
50+
51+
# make result backwards compatible
52+
if result.result.startswith("test_credentials"):
53+
result.result = {
54+
"type" : "ssh",
55+
"username" : result.result.split(" ")[1],
56+
"password" : result.result.split(" ")[2]
57+
}
58+
else:
59+
result.result = {
60+
"type" : "cmd",
61+
"cmd" : cmd_output_fixer(result.result)
62+
}
63+
64+
return result
3265

3366
def analyze_result(self, cmd, result):
34-
result = self.create_and_ask_prompt('successfull.txt', cmd=cmd, resp=result, facts=self.state)
3567

36-
self.tmp_state = result.result["facts"]
68+
model = self.llm_connection.get_model()
69+
ctx = self.llm_connection.get_context_size()
70+
71+
# ugly, but cut down result to fit context size
72+
# don't do this linearly as this can take too long
73+
CUTOFF_STEP = 128
74+
current_size = num_tokens_from_string(model, result)
75+
while current_size > (ctx + 512):
76+
cut_off = int(((current_size - (ctx + 512)) + CUTOFF_STEP)/2)
77+
result = result[cut_off:]
78+
current_size = num_tokens_from_string(model, result)
79+
80+
result = self.create_and_ask_prompt_text('analyze_cmd.txt', cmd=cmd, resp=result, facts=self.state)
3781
return result
3882

39-
def update_state(self):
40-
self.state = "\n".join(map(lambda x: "- " + x, self.tmp_state))
41-
return LLMResult(self.state, 0, 0, 0)
83+
def update_state(self, cmd, result):
84+
result = self.create_and_ask_prompt_text('update_state.txt', cmd=cmd, resp=result, facts=self.state)
85+
self.state = result.result
86+
return result
4287

4388
def get_current_state(self):
4489
return self.state
4590

46-
def create_and_ask_prompt(self, template_file, **params):
91+
def create_and_ask_prompt_text(self, template_file, **params):
92+
# prepare the prompt
4793
template = Template(filename='templates/' + template_file)
4894
prompt = template.render(**params)
95+
96+
if not self.llm_connection.get_model().startswith("gpt-"):
97+
prompt = wrap_it_for_llama(prompt)
98+
99+
# and execute it
49100
tic = time.perf_counter()
50101
result, tok_query, tok_res = self.llm_connection.exec_query(self.llm_connection.get_model(), self.llm_connection.get_context_size(), prompt)
51102
toc = time.perf_counter()
52-
try:
53-
json_answer = json.loads(result)
54-
except Exception as e:
55-
print("there as an exception with JSON parsing: " + str(e))
56-
print("debug[the plain result]: " + str(result))
57103

58-
return LLMResult(json_answer, toc - tic, tok_query, tok_res)
104+
return LLMResult(result, prompt, result, toc - tic, tok_query, tok_res)
105+
106+
def wrap_it_for_llama(prompt):
107+
return f"""### System:
108+
you are a concise but helpful learning tool that aids students trying to find security vulnerabilities
109+
110+
### User:
111+
{prompt}
112+
113+
### Assistant:"""

‎llms/manager.py‎ renamed to ‎llms/llm_connection.py‎

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@
1010
def get_potential_llm_connections():
1111
return list(map(lambda x: x[0], connections))
1212

13-
def get_llm_connection(name, model, context_size):
13+
def get_llm_connection(config):
1414
for i in connections:
15-
if i[0] == name:
16-
if i[1]() == True:
17-
return LLMConnection(name, model, context_size, i[2])
15+
if i[0] == config.llm_connection:
16+
if i[1](config) == True:
17+
return LLMConnection(config, i[2])
1818
else:
1919
print("Parameter for connection missing")
2020
return None
2121
print("Configured connection not found")
2222
return None
2323

2424
class LLMConnection:
25-
def __init__(self, conn, model, context_size, exec_query):
26-
self.conn = conn
27-
self.model = model
28-
self.context_size = context_size
25+
def __init__(self, config, exec_query):
26+
self.conn = config.llm_connection
27+
self.model = config.model
28+
self.context_size = config.context_size
2929
self.exec_query = exec_query
3030

3131
def exec_query(self, query):

‎llms/oobabooga.py‎

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import html
22
import json
3-
import os
43

54
import requests
65

@@ -101,11 +100,7 @@ def get_openai_response(cmd):
101100

102101
return html.unescape(result['visible'][-1][1])
103102

104-
def verify_config():
103+
def verify_config(config):
105104
global url
106105

107-
url = os.getenv('OOBABOOGA_URL')
108-
109-
if url == '':
110-
raise Exception("please set OOBABOOGA_URL through environmental variables")
111-
return True
106+
url = f"{config.llm_server_base_url}/api/v1/chat"

‎llms/openai.py‎

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments (0)