Skip to content

Commit 27198e9

Browse files
committed
Adds OpenAI endpoints
1 parent e2a8581 commit 27198e9

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

‎src/dialog/routers/__init__.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .dialog import *
1+
from .dialog import *
2+
from .openai import open_ai_api_router

‎src/main.py‎

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi import FastAPI
88
from fastapi.staticfiles import StaticFiles
99
from fastapi.middleware.cors import CORSMiddleware
10-
from dialog.routers import api_router
10+
from dialog.routers import api_router, open_ai_api_router
1111

1212
logging.basicConfig(
1313
level=Settings().LOGGING_LEVEL,
@@ -31,9 +31,18 @@ def get_application() -> FastAPI:
3131
allow_headers=Settings().CORS_ALLOW_HEADERS,
3232
)
3333

34+
app.add_middleware(
35+
CustomHeaderMiddleware
36+
)
37+
3438
app.include_router(
3539
api_router, prefix="",
3640
)
41+
42+
app.include_router(
43+
open_ai_api_router, prefix="/openai"
44+
)
45+
3746
app.mount("/static", StaticFiles(directory=Settings().STATIC_FILE_LOCATION), name="static")
3847
plugins = entry_points(group="dialog")
3948
for plugin in plugins:

‎src/tests/conftest.py‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,10 @@ def chat_session(dbsession):
4545
def llm_mock(mocker):
4646
llm_mock = mocker.patch('dialog.routers.dialog.process_user_message')
4747
llm_mock.process.return_value = {"text": "Hello"}
48+
return llm_mock
49+
50+
@pytest.fixture
51+
def llm_mock_openai_router(mocker):
52+
llm_mock = mocker.patch('dialog.routers.openai.process_user_message')
53+
llm_mock.process.return_value = {"text": "Hello"}
4854
return llm_mock

‎src/tests/test_views.py‎

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pytest
23

34
from dialog_lib.db.models import ChatMessages, Chat
@@ -47,4 +48,21 @@ def test_invalid_database_connection(client, mocker):
4748
with pytest.raises(Exception):
4849
response = client.get("/health")
4950
assert response.status_code == 500
50-
assert response.json() == {"message": "Failed to execute simple SQL"}
51+
assert response.json() == {"message": "Failed to execute simple SQL"}
52+
53+
# test openai router
54+
55+
def test_customized_openai_models_response(client):
56+
response = client.get("/openai/models")
57+
assert response.status_code == 200
58+
for i in ["id", "object", "created", "owned_by"]:
59+
assert i in response.json()[0]
60+
61+
def test_customized_openai_chat_completion_response(client, llm_mock_openai_router):
62+
os.environ["LLM_CLASS"] = "dialog.llm.agents.default.DialogLLM"
63+
response = client.post("/openai/chat/completions", json={"message": "Hello"})
64+
assert response.status_code == 200
65+
for i in ["choices", "created", "id", "model", "object", "usage"]:
66+
assert i in response.json()
67+
assert llm_mock_openai_router.called
68+
assert response.json()["choices"][0]["message"]["role"] == "assistant"

0 commit comments

Comments
 (0)