Skip to content
Prev Previous commit
Next Next commit
optimizeded code
  • Loading branch information
DianaStrauss committed Aug 6, 2024
commit e4ef23a1f39e6d9af86a287c738ce56017e99cb2
106 changes: 49 additions & 57 deletions src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,45 @@ def get_http_action_template(self, method):
else:
return (
f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests.")


def get_initial_steps(self, common_steps):
return [
"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}",
"Note down the response structures, status codes, and headers for each endpoint.",
"For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps

def get_phase_steps(self, phase, common_steps):
if phase != "DELETE":
return [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
return [
"Check for all endpoints the DELETE method. Delete the first instance for all endpoints.",
self.get_http_action_template(phase)
] + common_steps

def get_endpoints_needing_help(self):
endpoints_needing_help = []
endpoints_and_needed_methods = {}
http_methods_set = {"GET", "POST", "PUT", "DELETE"}

for endpoint, methods in self.endpoint_methods.items():
missing_methods = http_methods_set - set(methods)
if len(methods) < 4:
endpoints_needing_help.append(endpoint)
endpoints_and_needed_methods[endpoint] = list(missing_methods)

if endpoints_needing_help:
first_endpoint = endpoints_needing_help[0]
needed_method = endpoints_and_needed_methods[first_endpoint][0]
return [
f"For endpoint {first_endpoint} find this missing method: {needed_method}. If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search."]
return []
def chain_of_thought(self, doc=False, hint=""):
"""
Generates a prompt using the chain-of-thought strategy.
If 'doc' is True, it follows a detailed documentation-oriented prompt strategy based on the round number.
If 'doc' is False, it provides general guidance for early round numbers and focuses on HTTP methods for later rounds.

Args:
doc (bool): Determines whether the documentation-oriented chain of thought should be used.
Expand All @@ -126,70 +158,30 @@ def chain_of_thought(self, doc=False, hint=""):
"Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes."
]

http_methods = [ "PUT", "DELETE"]
http_phase = {
5: http_methods[0],
10: http_methods[1]
}

http_methods = ["PUT", "DELETE"]
http_phase = {10: http_methods[0], 15: http_methods[1]}
if doc:
if self.round < 5:

chain_of_thought_steps = [
f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", f"Note down the response structures, status codes, and headers for each endpoint.",
f"For each endpoint, document the following details: URL, HTTP method, "
f"query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
] + common_steps
if self.round <= 5:
chain_of_thought_steps = self.get_initial_steps(common_steps)
elif self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
chain_of_thought_steps = self.get_phase_steps(phase, common_steps)
else:
if self.round <= 10:
phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys())))
print(f'phase:{phase}')
if phase != "DELETE":
chain_of_thought_steps = [
f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
self.get_http_action_template(phase)
] + common_steps
else:
chain_of_thought_steps = [
f"Check for all endpoints the DELETE method. Delete the first instance for all endpoints. ",
self.get_http_action_template(phase)
] + common_steps
else:
endpoints_needing_help = []
endpoints_and_needed_methods = {}

# Standard HTTP methods
http_methods = {"GET", "POST", "PUT", "DELETE"}

for endpoint in self.endpoint_methods:
# Calculate the missing methods for the current endpoint
missing_methods = http_methods - set(self.endpoint_methods[endpoint])

if len(self.endpoint_methods[endpoint]) < 4:
endpoints_needing_help.append(endpoint)
# Add the missing methods to the dictionary
endpoints_and_needed_methods[endpoint] = list(missing_methods)

print(f'endpoints_and_needed_methods: {endpoints_and_needed_methods}')
print(f'first endpoint in list: {endpoints_needing_help[0]}')
print(f'methods needed for first endpoint: {endpoints_and_needed_methods[endpoints_needing_help[0]][0]}')

chain_of_thought_steps = [f"For enpoint {endpoints_needing_help[0]} find this missing method :{endpoints_and_needed_methods[endpoints_needing_help[0]][0]} "
f"If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search. ",]

chain_of_thought_steps = self.get_endpoints_needing_help()
else:
if self.round == 0:
chain_of_thought_steps = ["Let's think step by step."] # Zero shot prompt
chain_of_thought_steps = ["Let's think step by step."]
elif self.round <= 20:
focus_phases = ["endpoints", "HTTP method GET", "HTTP method POST and PUT", "HTTP method DELETE"]
focus_phase = focus_phases[self.round // 5]
chain_of_thought_steps = [f"Just focus on the {focus_phase} for now."]
else:
chain_of_thought_steps = ["Look for exploits."]

print(f'chain of thought steps: {chain_of_thought_steps}')
prompt = self.check_prompt(self.previous_prompt,
chain_of_thought_steps + [hint] if hint else chain_of_thought_steps)
if hint:
chain_of_thought_steps.append(hint)

prompt = self.check_prompt(self.previous_prompt, chain_of_thought_steps)
return prompt

def token_count(self, text):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _setup_initial_prompt(self):
response_handler=self.response_handler)


def all_http_methods_found(self):
def all_http_methods_found(self,turn):
print(f'found endpoints:{self.documentation_handler.endpoint_methods.items()}')
print(f'found endpoints values:{self.documentation_handler.endpoint_methods.values()}')

Expand All @@ -83,17 +83,20 @@ def all_http_methods_found(self):
print(f'found endpoints:{found_endpoints}')
print(f'expected endpoints:{expected_endpoints}')
print(f'correct? {found_endpoints== expected_endpoints}')
if found_endpoints== expected_endpoints or found_endpoints == expected_endpoints -1:
if found_endpoints > 0 and (found_endpoints== expected_endpoints) :
return True
else:
if turn == 20:
if found_endpoints > 0 and (found_endpoints == expected_endpoints):
return True
return False

def perform_round(self, turn: int):
prompt = self.prompt_engineer.generate_prompt(doc=True)
response, completion = self.llm_handler.call_llm(prompt)
return self._handle_response(completion, response)
return self._handle_response(completion, response, turn)

def _handle_response(self, completion, response):
def _handle_response(self, completion, response, turn):
message = completion.choices[0].message
tool_call_id = message.tool_calls[0].id
command = pydantic_core.to_json(response).decode()
Expand All @@ -106,7 +109,6 @@ def _handle_response(self, completion, response):
result_str = self.response_handler.parse_http_status_line(result)
self._prompt_history.append(tool_message(result_str, tool_call_id))
invalid_flags = ["recorded","Not a valid HTTP method", "404" ,"Client Error: Not Found"]
print(f'result_str:{result_str}')
if not result_str in invalid_flags or any(item in result_str for item in invalid_flags):
self.prompt_engineer.found_endpoints = self.documentation_handler.update_openapi_spec(response, result)
self.documentation_handler.write_openapi_to_yaml()
Expand All @@ -120,8 +122,7 @@ def _handle_response(self, completion, response):
http_methods_dict[method].append(endpoint)
self.prompt_engineer.endpoint_found_methods = http_methods_dict
self.prompt_engineer.endpoint_methods = self.documentation_handler.endpoint_methods
print(f'SCHEMAS:{self.prompt_engineer.schemas}')
return self.all_http_methods_found()
return self.all_http_methods_found(turn)



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def parse_http_status_line(self, status_line):
"""
if status_line == "Not a valid HTTP method":
return status_line

status_line = status_line.split('\r\n')[0]
# Regular expression to match valid HTTP status lines
match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line)
if match:
Expand Down