Skip to content

Commit d5c8b93

Browse files
committed
restructure code, remove helper.py
1 parent 004ff0e commit d5c8b93

File tree

13 files changed

+193
-239
lines changed

13 files changed

+193
-239
lines changed

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ series = {ESEC/FSE 2023}
4141

4242
# Example runs
4343

44-
- more can be seen at [history notes](https://github.com/ipa-lab/hackingBuddyGPT/blob/v3/history_notes.md)
44+
- more can be seen at [history notes](https://github.com/ipa-lab/hackingBuddyGPT/blob/v3/docs/history_notes.md)
4545

4646
## updated version using GPT-4
4747

‎args.py‎

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import argparse
2+
import json
3+
import os
4+
5+
from dotenv import load_dotenv
6+
from llms.llm_connection import get_potential_llm_connections
7+
8+
def parse_args_and_env():
    """Parse CLI arguments, using values from a local .env file as defaults.

    Returns the populated ``argparse.Namespace``. Command-line flags always
    override the corresponding environment variables.
    """
    # setup dotenv so the os.getenv() calls below can see values from .env
    load_dotenv()

    # perform argument parsing
    # for defaults we are using .env but allow overwrite through cli arguments
    parser = argparse.ArgumentParser(description='Run an LLM vs a SSH connection.')
    parser.add_argument('--enable-explanation', help="let the LLM explain each round's result", action="store_true")
    parser.add_argument('--enable-update-state', help='ask the LLM to keep a multi-round state with findings', action="store_true")
    parser.add_argument('--log', type=str, help='sqlite3 db for storing log files', default=os.getenv("LOG_DESTINATION") or ':memory:')
    parser.add_argument('--target-ip', type=str, help='ssh hostname to use to connect to target system', default=os.getenv("TARGET_IP") or '127.0.0.1')
    parser.add_argument('--target-hostname', type=str, help='safety: what hostname to expect at the target IP', default=os.getenv("TARGET_HOSTNAME") or "debian")
    parser.add_argument('--target-user', type=str, help='ssh username to use to connect to target system', default=os.getenv("TARGET_USER") or 'lowpriv')
    parser.add_argument('--target-password', type=str, help='ssh password to use to connect to target system', default=os.getenv("TARGET_PASSWORD") or 'trustno1')
    # NOTE: the fallback must be applied *before* int() — int(os.getenv(...))
    # raises TypeError when the environment variable is unset
    parser.add_argument('--max-rounds', type=int, help='how many cmd-rounds to execute at max', default=int(os.getenv("MAX_ROUNDS") or 10))
    parser.add_argument('--llm-connection', type=str, help='which LLM driver to use', choices=get_potential_llm_connections(), default=os.getenv("LLM_CONNECTION") or "openai_rest")
    parser.add_argument('--target-os', type=str, help='What is the target operating system?', choices=["linux", "windows"], default="linux")
    parser.add_argument('--model', type=str, help='which LLM to use', default=os.getenv("MODEL") or "gpt-3.5-turbo")
    parser.add_argument('--llm-server-base-url', type=str, help='which LLM server to use', default=os.getenv("LLM_SERVER_BASE_URL") or "https://api.openai.com")
    parser.add_argument('--tag', type=str, help='tag run with string', default="")
    # same TypeError fix as --max-rounds above
    parser.add_argument('--context-size', type=int, help='model context size to use', default=int(os.getenv("CONTEXT_SIZE") or 4096))
    parser.add_argument('--hints', type=argparse.FileType('r', encoding='latin-1'), help='json file with a hint per tested hostname', default=None)

    return parser.parse_args()
32+
33+
34+
def get_hint(args, console):
    """Look up a hint for the current target hostname.

    ``args.hints`` is an optional, already-opened JSON file (see the
    ``--hints`` argument) mapping hostnames to hint strings. Returns the
    hint for ``args.target_hostname``, or None when no hint file was
    given, the file cannot be parsed, or no entry matches the hostname.
    """
    if args.hints:
        try:
            hints = json.load(args.hints)
            if args.target_hostname in hints:
                hint = hints[args.target_hostname]
                console.print(f"[bold green]Using the following hint: '{hint}'")
                return hint
        # was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; Exception still covers JSON and I/O errors
        except Exception:
            console.print("[yellow]Was not able to load hint file")
    return None
File renamed without changes.
File renamed without changes.

‎handlers.py‎

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,42 @@
33
from targets.ssh import SSHHostConn
44

55
def handle_cmd(conn, input):
    """Sanitize the LLM-proposed command and execute it over *conn*.

    Returns a (command, output, got_root) triple, where command is the
    sanitized version that was actually executed.
    """
    # strip the wrapper characters / "$ " prefix the LLM tends to add
    sanitized = cmd_output_fixer(input)
    output, got_root = conn.run(sanitized)
    return sanitized, output, got_root
89

910

1011
def handle_ssh(target_host, target_hostname, input):
    """Test the credentials from a ``test_credentials <user> <password>`` line.

    Opens a fresh SSH connection with the given credentials and checks
    (via ``whoami``) whether they yield root. Returns a triple of
    (original input, human-readable result message, got_root flag).
    """
    parts = input.split(" ")
    assert(parts[0] == "test_credentials")

    login_user = parts[1]
    login_password = parts[2]

    probe = SSHHostConn(target_host, target_hostname, login_user, login_password)
    try:
        probe.connect()
        # check which user the credentials actually log us in as
        whoami = probe.run("whoami")[0].strip('\n\r ')
        if whoami == "root":
            return input, "Login as root was successful\n", True
        return input, "Authentication successful, but user is not root\n", False

    except paramiko.ssh_exception.AuthenticationException:
        return input, "Authentication error, credentials are wrong\n", False
29+
30+
31+
def remove_wrapping_characters(cmd, wrappers):
    """Recursively strip matching wrapper characters from both ends of *cmd*.

    A wrapper is removed only when the first and last character are the
    same and appear in *wrappers* (e.g. backticks or quotes around a
    command).
    """
    # the length guard fixes two crashes in the original: an IndexError on
    # the empty string, and infinite descent into "" for a one-character
    # input that is itself a wrapper character
    if len(cmd) >= 2 and cmd[0] == cmd[-1] and cmd[0] in wrappers:
        print("will remove a wrapper from: " + cmd)
        return remove_wrapping_characters(cmd[1:-1], wrappers)
    return cmd


# often the LLM produces a wrapped command
def cmd_output_fixer(cmd):
    """Clean up an LLM-produced command line.

    Removes wrapping backticks/quotes and a leading shell-prompt
    marker ("$ ") so the remainder can be executed verbatim.
    """
    cmd = remove_wrapping_characters(cmd, "`'\"")

    if cmd.startswith("$ "):
        cmd = cmd[2:]

    return cmd

‎helper.py‎

Lines changed: 0 additions & 101 deletions
This file was deleted.

‎llm_with_state.py‎

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import time
2+
import tiktoken
23
import typing
34

5+
from db_storage import DbStorage
46
from dataclasses import dataclass
57
from mako.template import Template
6-
from helper import *
78

89
@dataclass
910
class LLMResult:
@@ -14,33 +15,33 @@ class LLMResult:
1415
tokens_query: int = 0
1516
tokens_response: int = 0
1617

17-
def get_empty_result():
18-
return LLMResult('', '', '', 0, 0, 0)
19-
2018
class LLMWithState:
21-
def __init__(self, run_id, llm_connection, history, initial_user, initial_password, update_state, target_os, hint=None):
19+
def __init__(self, run_id, llm_connection, history, args, hint):
2220
self.llm_connection = llm_connection
23-
self.initial_user = initial_user
24-
self.initial_password = initial_password
21+
self.initial_user = args.target_user
22+
self.initial_password = args.target_password
2523
self.db = history
2624
self.run_id = run_id
27-
self.enable_update_state = update_state
28-
self.target_os = target_os
25+
self.enable_update_state = args.enable_update_state
26+
self.target_os = args.target_os
27+
self.hint = hint
2928
self.state = f"""
30-
- this is a linux system
31-
- your low privilege user credentials are {initial_user}:{initial_password}
29+
- this is a {self.target_os} system
30+
- your low privilege user credentials are {self.initial_user}:{self.initial_password}
3231
"""
33-
self.hint = hint
32+
33+
def get_state_size(self, model):
    """Token count of the current state prompt, or 0 when state updates are disabled."""
    if not self.enable_update_state:
        return 0
    return num_tokens_from_string(model, self.state)
3438

3539
def get_next_cmd(self):
3640

3741
template_file = 'query_next_command.txt'
3842
model = self.llm_connection.get_model()
3943

40-
if self.enable_update_state:
41-
state_size = num_tokens_from_string(model, self.state)
42-
else:
43-
state_size = 0
44+
state_size = self.get_state_size(model)
4445

4546
template = Template(filename='templates/' + template_file)
4647
template_size = num_tokens_from_string(model, template.source)
@@ -52,22 +53,7 @@ def get_next_cmd(self):
5253
else:
5354
target_user = "Administrator"
5455

55-
result = self.create_and_ask_prompt_text(template_file, user=self.initial_user, password=self.initial_password, history=history, state=self.state, hint=self.hint, update_state=self.enable_update_state, target_os=self.target_os, target_user=target_user)
56-
57-
# make result backwards compatible
58-
if result.result.startswith("test_credentials"):
59-
result.result = {
60-
"type" : "ssh",
61-
"username" : result.result.split(" ")[1],
62-
"password" : result.result.split(" ")[2]
63-
}
64-
else:
65-
result.result = {
66-
"type" : "cmd",
67-
"cmd" : cmd_output_fixer(result.result)
68-
}
69-
70-
return result
56+
return self.create_and_ask_prompt_text(template_file, user=self.initial_user, password=self.initial_password, history=history, state=self.state, hint=self.hint, update_state=self.enable_update_state, target_os=self.target_os, target_user=target_user)
7157

7258
def analyze_result(self, cmd, result):
7359

@@ -109,6 +95,43 @@ def create_and_ask_prompt_text(self, template_file, **params):
10995

11096
return LLMResult(result, prompt, result, toc - tic, tok_query, tok_res)
11197

98+
def num_tokens_from_string(model: str, string: str) -> int:
    """Returns the number of tokens in a text string.

    Uses tiktoken, so counts are exact only for OpenAI models.
    """
    # I know this is crappy for all non-openAI models but sadly this
    # has to be good enough for now. EAFP instead of prefix-sniffing:
    # tiktoken raises KeyError for any model name it does not recognize
    # (including newer "gpt-" prefixed ones), which the original
    # startswith("gpt-") check would have let crash.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoding.encode(string))
108+
109+
# minimum number of tokens to cut per trimming step
STEP_CUT_TOKENS : int = 32
# tokens kept free as a buffer against token-count estimation errors
SAFETY_MARGIN : int = 128

# create the command history. Initially create the full command history, then
# try to trim it down
def get_cmd_history_v3(model: str, ctx_size: int, run_id: int, db: DbStorage, token_overhead: int) -> str:
    """Build the '$ cmd\\noutput' history text for *run_id* from the db,
    trimmed so that history + token_overhead + SAFETY_MARGIN fits ctx_size.
    """
    result: str = ""

    # get commands from db
    cmds = db.get_cmd_history(run_id)

    # create the full history
    for itm in cmds:
        result = result + '$ ' + itm[0] + "\n" + itm[1]

    # trim it down if too large
    cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN

    while cur_size > ctx_size:
        # guard: if the overhead alone exceeds the context size, the history
        # is already empty and slicing cannot shrink it further — without
        # this break the loop would never terminate
        if not result:
            break
        diff = cur_size - ctx_size
        step = int((diff + STEP_CUT_TOKENS)/2)
        result = result[:-step]
        cur_size = num_tokens_from_string(model, result) + token_overhead + SAFETY_MARGIN

    return result
134+
112135
def wrap_it_for_llama(prompt):
113136
return f"""### System:
114137
you are a concise but helful learning tool that aids students trying to find security vulnerabilities

‎llms/llm_connection.py‎

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ def exec_query(self, query):
3333

3434
def get_context_size(self):
    """Return the context window size configured for this connection."""
    return self.context_size
36-
37-
def output_metadata(self):
38-
return f"connection: {self.conn} using {self.model} with context-size {self.context_size}"
3936

4037
def get_model(self) -> str:
    """Return the model identifier used by this connection."""
    return self.model

‎targets/psexec.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
def get_smb_connection(ip, hostname, username, password):
44
return SMBHostConn(ip, hostname, username, password)
55

6+
# read https://pypi.org/project/pypsexec/
7+
# - TODO: why is timeout not working?
68
class SMBHostConn:
79

810
def __init__(self, host, hostname, username, password):
@@ -18,5 +20,5 @@ def connect(self):
1820

1921
def run(self, cmd):
    """Execute *cmd* through cmd.exe on the target host.

    Returns a (stdout_text, got_root) pair; got_root is always False
    here. timeout_seconds presumably bounds how long a hanging command
    can block — TODO confirm it actually takes effect (see class note).
    """
    stdout, stderr, rc = self.client.run_executable(
        "cmd.exe", arguments=f"/c {cmd}", timeout_seconds=2)
    return str(stdout), False

‎targets/ssh.py‎

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def run(self, cmd):
4747
try:
4848
resp = self.conn.run(cmd, pty=True, warn=True, out_stream=out, watchers=[sudopass], timeout=10)
4949
except Exception as e:
50-
print("TIMEOUT!")
50+
print("TIMEOUT! Could we have become root?")
5151
out.seek(0)
5252
tmp = ""
5353
lastline = ""
@@ -61,6 +61,13 @@ def run(self, cmd):
6161
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
6262
lastline = ansi_escape.sub('', lastline)
6363

64+
stupidity = re.compile(r"^[ \n\r]*```.*\n(.*)\n```$", re.MULTILINE)
65+
if stupidity.fullmatch(tmp):
66+
print("this would have been captured by the multi-line regex 1")
67+
stupidity = re.compile(r"^[ \n\r]*~~~.*\n(.*)\n~~~$", re.MULTILINE)
68+
if stupidity.fullmatch(tmp):
69+
print("this would have been captured by the multi-line regex 2")
70+
6471
for i in GOT_ROOT_REXEXPs:
6572
if i.fullmatch(lastline):
6673
gotRoot = True

0 commit comments

Comments
 (0)