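"""FastAPI backend and Gradio frontend for LlamaResearcher.

Exposes a rate-limited, API-key-protected /chat endpoint that guards the user
prompt and runs the research agent workflow, plus a Gradio chat UI mounted on
the same app to display the agentic process and the final essay.
"""
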
# Standard library
import json
from contextlib import asynccontextmanager

# Third-party
import gradio as gr
import redis.asyncio as redis
import requests
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import ORJSONResponse
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from pydantic import BaseModel

# Local agent workflow and prompt guard
from agent import workflow, guard_prompt, ToolCall, ToolCallResult

class ApiInput(BaseModel):
    """Request body for the /chat endpoint."""
    prompt: str

class ApiOutput(BaseModel):
    """Response body: safety verdict, final response, and a trace of the agentic process."""
    is_safe_prompt: bool
    response: str
    process: str

# Read the internal API key from the mounted Docker secret; strip any trailing
# newline so header comparisons are exact. The `with` block closes the file,
# so no explicit close() is needed.
with open("/run/secrets/internal_key", "r") as f:
    internal_key = f.read().strip()

@asynccontextmanager
async def lifespan(_: FastAPI):
    # fastapi-limiter needs a Redis backend to track request counts;
    # the instance is reachable at the `llama_redis` host.
    redis_connection = redis.from_url("redis://llama_redis:6379", encoding="utf8")
    await FastAPILimiter.init(redis_connection)
    yield
    await FastAPILimiter.close()

async def check_api_key(x_api_key: str = Header(None)):
    # Dependency that validates the x-api-key header against the internal secret
    if x_api_key != internal_key:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return x_api_key

app = FastAPI(default_response_class=ORJSONResponse, lifespan=lifespan)

# Simple health-check endpoint, rate-limited to 10 requests per second
@app.get("/test", dependencies=[Depends(RateLimiter(times=10, seconds=1))])
async def index():
    return {"response": "Hello world!"}

# Main endpoint: guard the prompt, then run the agent workflow while collecting
# a readable trace of its tool calls. Rate-limited to 10 requests per minute.
@app.post("/chat", dependencies=[Depends(RateLimiter(times=10, seconds=60))])
async def chat(inpt: ApiInput, x_api_key: str = Depends(check_api_key)) -> ApiOutput:
    is_safe, r = await guard_prompt(inpt.prompt)
    process = ""
    if not is_safe:
        return ApiOutput(is_safe_prompt=is_safe, response="I cannot produce an essay about this topic", process=r)
    handler = workflow.run(user_msg=inpt.prompt)
    async for event in handler.stream_events():
        if isinstance(event, ToolCall):
            process += (
                f"Calling tool **{event.tool_name}** with arguments:\n"
                f"```json\n{json.dumps(event.tool_kwargs, indent=4)}\n```\n\n"
            )
        if isinstance(event, ToolCallResult):
            process += f"Tool call result for **{event.tool_name}**: {event.tool_output}\n\n"
    response = await handler
    return ApiOutput(is_safe_prompt=is_safe, response=str(response), process=process)

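# Example request against the /chat endpoint (hypothetical topic; the x-api-key
# value must match the secret mounted at /run/secrets/internal_key):
#   curl -X POST http://localhost:80/chat \
#        -H "Content-Type: application/json" -H "x-api-key: $INTERNAL_KEY" \
#        -d '{"prompt": "Renewable energy adoption in Europe"}'
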
def add_message(history: list, message: dict):
    # Append the user message to the chat history and lock the textbox while the bot replies
    if message is not None:
        history.append({"role": "user", "content": message})
    return history, gr.Textbox(value=None, interactive=False)

def bot(history: list):
    # Forward the latest user message to the /chat endpoint of this same app
    headers = {"Content-Type": "application/json", "x-api-key": internal_key}
    response = requests.post(
        "http://localhost:80/chat",
        json=ApiInput(prompt=history[-1]["content"]).model_dump(),
        headers=headers,
    )
    if response.status_code == 200:
        data = response.json()
        res = data["response"]
        history.append({"role": "assistant", "content": f"## Agentic Process\n\n{data['process']}"})
    elif response.status_code == 429:
        res = "Sorry, we are having high traffic at the moment... Try again later!"
        history.append({"role": "assistant", "content": res})
    else:
        res = "Sorry, an internal error occurred. Feel free to report the bug on [GitHub discussions](https://github.com/AstraBert/llama-4-researcher/discussions/)"
        history.append({"role": "assistant", "content": res})
    return history, "# Canvas\n\n---\n\n" + res

with gr.Blocks(theme=gr.themes.Citrus(), title="LlamaResearcher") as frontend:
    title = gr.HTML("<h1 align='center'>LlamaResearcher</h1>\n<h2 align='center'>From topic to essay in seconds!</h2>")
    with gr.Row():
        with gr.Column():
            # Left column: the essay canvas
            canvas = gr.Markdown(label="Canvas", show_label=True, show_copy_button=True, container=True, min_height=700)
        with gr.Column():
            # Right column: the chat interface with its input box
            chatbot = gr.Chatbot(elem_id="chatbot", type="messages", min_height=700, min_width=700, label="LlamaResearcher Chat")
            with gr.Row():
                chat_input = gr.Textbox(
                    interactive=True,
                    placeholder="Enter message...",
                    show_label=False,
                    submit_btn=True,
                    stop_btn=True,
                )

    # On submit: record the user message, call the backend, then re-enable the textbox
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(bot, chatbot, [chatbot, canvas], api_name="bot_response")
    bot_msg.then(lambda: gr.Textbox(interactive=True), None, [chat_input])

# Serve the Gradio frontend on top of the FastAPI app
app = gr.mount_gradio_app(app, frontend, "")
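
# To serve this module locally (file name assumed to be main.py):
#   uvicorn main:app --host 0.0.0.0 --port 80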