Merged

V3 #2

make LLM connection (oobabooga/openai_rest) configurable
andreashappe committed Sep 7, 2023
commit c796b2d99ff288ea4647b80186b2f7930e6856ec
3 changes: 3 additions & 0 deletions .env.example
@@ -8,3 +8,6 @@ TARGET_IP='enter-the-private-ip-of-some-vm.local'
 # exchange with the user for your target VM
 TARGET_USER='bob'
 TARGET_PASSWORD='secret'
+
+# which LLM driver to use (can be openai_rest or oobabooga for now)
+LLM_CONNECTION = "openai_rest"
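
The entry is read back with os.getenv once python-dotenv has loaded the file. A quick sketch of validating the value at startup (a hypothetical helper, not part of this commit):

import os
from dotenv import load_dotenv

load_dotenv()                           # copies .env entries into the process environment
driver = os.getenv("LLM_CONNECTION")
if driver not in ("openai_rest", "oobabooga"):
    raise SystemExit(f"unsupported LLM_CONNECTION: {driver!r}")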
12 changes: 2 additions & 10 deletions config.py
@@ -1,10 +1,5 @@
 import os
 
-from dotenv import load_dotenv
-
-def check_config():
-    load_dotenv()
-
 def model():
     return os.getenv("MODEL")
 
@@ -20,8 +15,5 @@ def target_password():
 def target_user():
     return os.getenv('TARGET_USER')
 
-def openai_key():
-    return os.getenv('OPENAI_KEY')
-
-def oobabooga_url():
-    return os.getenv('OOBABOOGA_URL')
+def llm_connection():
+    return os.getenv("LLM_CONNECTION")
19 changes: 19 additions & 0 deletions llms/manager.py
@@ -0,0 +1,19 @@
+from llms.openai_rest import get_openai_rest_connection_data
+from llms.oobabooga import get_oobabooga_setup
+
+# we do not need something fast (like a map)
+connections = [
+    get_openai_rest_connection_data(),
+    get_oobabooga_setup()
+]
+
+def get_llm_connection(name):
+    for i in connections:
+        if i[0] == name:
+            if i[1]():
+                return i[2]
+            else:
+                print("Parameter for connection missing")
+                return None
+    print("Configured connection not found")
+    return None
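
Each driver module contributes a (name, verify_config, handler) tuple, so adding a backend is a one-line append to connections, and a linear scan over two entries is plenty. A usage sketch, assuming the openai_rest environment variables are set:

from llms.manager import get_llm_connection

handler = get_llm_connection("openai_rest")  # get_openai_response once verify_config passes
if handler is not None:
    print(handler("say hello"))              # handlers take the prompt string and return the reply text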
24 changes: 18 additions & 6 deletions llms/oobabooga.py
@@ -1,20 +1,23 @@
 import html
 import json
 import re
+import os
 
 import requests
 
-from config import oobabooga_url
-
 # For local streaming, the websockets are hosted without ssl - http://
-url = oobabooga_url()
-URI = f'{HOST}/api/v1/chat'
+url : str = 'unknown'
 
 # For reverse-proxied streaming, the remote will likely host with ssl - https://
 # URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'
 
 
+def get_oobabooga_setup():
+    return "oobabooga", verify_config, get_openai_response
+
 def run(user_input, history):
 
+    global url
+
     request = {
         'user_input': user_input,
         'max_new_tokens': 250,
@@ -71,7 +74,7 @@ def run(user_input, history):
         'stopping_strings': []
     }
 
-    response = requests.post(URI, json=request)
+    response = requests.post(url, json=request)
 
     if response.status_code == 200:
         return response.json()['results'][0]['history']
@@ -97,3 +100,12 @@ def get_openai_response(cmd):
         return str(json.dumps({ "type" : "cmd", "cmd" : tmp.replace("\\\"", "\"")}))
 
     return html.unescape(result['visible'][-1][1])
+
+def verify_config():
+    global url
+
+    url = os.getenv('OOBABOOGA_URL')
+
+    if not url:
+        raise Exception("please set OOBABOOGA_URL through environment variables")
+    return True
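
Note that requests.post(url, json=request) now uses OOBABOOGA_URL verbatim, where the old code appended /api/v1/chat to a host; the variable therefore has to carry the full chat endpoint. A hypothetical .env entry for a local text-generation-webui instance:

OOBABOOGA_URL='http://localhost:5000/api/v1/chat'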
27 changes: 20 additions & 7 deletions llms/openai_rest.py
@@ -1,15 +1,28 @@
-import config
+import os
 import requests
 
-def get_openai_response(cmd):
-    if config.model() == '' and config.openai_key() == '':
+openai_model : str = 'gpt-3.5-turbo'
+openai_key : str = 'none'
+
+def get_openai_rest_connection_data():
+    return "openai_rest", verify_config, get_openai_response
+
+def verify_config():
+    global openai_key, openai_model
+
+    openai_key = os.getenv("OPENAI_KEY")
+    openai_model = os.getenv("MODEL")
+
+    if not openai_model or not openai_key:
         raise Exception("please set OPENAI_KEY and MODEL through environment variables!")
-    openapi_key = config.openai_key()
-    openapi_model = config.model()
+
+    return True
+
+def get_openai_response(cmd):
+    global openai_key, openai_model
 
-    headers = {"Authorization": f"Bearer {openapi_key}"}
-    data = {'model': openapi_model, 'messages': [{'role': 'user', 'content': cmd}]}
+    headers = {"Authorization": f"Bearer {openai_key}"}
+    data = {'model': openai_model, 'messages': [{'role': 'user', 'content': cmd}]}
     response = requests.post('https://api.openai.com/v1/chat/completions', headers=headers, json=data).json()
 
     print(str(response))
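
get_openai_response posts directly to the chat completions endpoint; the reply text sits under choices[0].message.content of the returned JSON. A self-contained sketch of the same call, assuming OPENAI_KEY is set in the environment:

import os
import requests

headers = {"Authorization": f"Bearer {os.getenv('OPENAI_KEY')}"}
data = {'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'say hi'}]}
response = requests.post('https://api.openai.com/v1/chat/completions',
                         headers=headers, json=data).json()
print(response['choices'][0]['message']['content'])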
12 changes: 7 additions & 5 deletions wintermute.py
@@ -7,22 +7,24 @@
 
 from history import ResultHistory, num_tokens_from_string
 from targets.ssh import get_ssh_connection, SSHHostConn
-from llms.openai_rest import get_openai_response
 from prompt_helper import LLM
+from llms.manager import get_llm_connection
+from dotenv import load_dotenv
+
+# setup dotenv
+load_dotenv()
 
 # setup some infrastructure
 cmd_history = ResultHistory()
 console = Console()
 
-# read configuration from env and configure system parts
-config.check_config()
-
 # open SSH connection to target
 conn = get_ssh_connection(config.target_ip(), config.target_user(), config.target_password())
 conn.connect()
 
-llm = LLM(get_openai_response)
+# initialize LLM connection
+llm = LLM(get_llm_connection(config.llm_connection()))
 
 context_size = config.context_size()
 print("used model: " + config.model() + " context-size: " + str(config.context_size()))
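
Taken together, startup now reads: load_dotenv() pulls the configuration into the environment, the manager resolves whichever driver LLM_CONNECTION names (running its verify_config along the way), and the returned handler is wrapped in LLM. A condensed sketch of that flow, with error handling omitted:

from dotenv import load_dotenv
import config
from prompt_helper import LLM
from llms.manager import get_llm_connection

load_dotenv()                                          # .env -> process environment
handler = get_llm_connection(config.llm_connection())  # verify_config runs here
if handler is None:
    raise SystemExit("no usable LLM connection configured")
llm = LLM(handler)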