Skip to content

Commit 5272109

Browse files
vmeseltjbck authored and committed
Fully support Open-WebUI
Co-authored-by: Tim Baek <tim@openwebui.com>
1 parent 6d1dbad commit 5272109

File tree

5 files changed

+129
-43
lines changed

5 files changed

+129
-43
lines changed

docker-compose-open-webui.yml

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,58 @@
1+
version: '3.3'
2+
services:
3+
db:
4+
image: pgvector/pgvector:pg15
5+
restart: always
6+
ports:
7+
- '5432:5432'
8+
environment:
9+
POSTGRES_USER: talkdai
10+
POSTGRES_PASSWORD: talkdai
11+
POSTGRES_DB: talkdai
12+
volumes:
13+
- ./data/db:/var/lib/postgresql/data
14+
- ./psql/db-ext-vector.sql:/docker-entrypoint-initdb.d/db-ext-vector.sql
15+
healthcheck:
16+
test: ["CMD", "pg_isready", "-d", "talkdai", "-U", "talkdai"]
17+
interval: 10s
18+
timeout: 5s
19+
retries: 5
20+
dialog:
21+
build:
22+
context: .
23+
dockerfile: Dockerfile
24+
stdin_open: true
25+
tty: true
26+
volumes:
27+
- ./:/app
28+
- ./.empty:/app/data/db
29+
- ./static:/app/static
30+
- ./sample_data:/app/src/sample_data
31+
ports:
32+
- '8000:8000'
33+
depends_on:
34+
db:
35+
condition: service_healthy
36+
environment:
37+
- DATABASE_URL=postgresql://talkdai:talkdai@db:5432/talkdai
38+
- STATIC_FILE_LOCATION=/app/static
39+
env_file:
40+
- .env
41+
openwebui:
42+
image: ghcr.io/open-webui/open-webui:main
43+
ports:
44+
- '3000:8080'
45+
environment:
46+
- OPENAI_API_KEYS=FAKE-KEY;
47+
- OPENAI_API_BASE_URLS=http://dialog:8000/openai;
48+
- ENABLE_OPENAI_API=true
49+
volumes:
50+
- open-webui:/app/backend/data
51+
depends_on:
52+
db:
53+
condition: service_healthy
54+
dialog:
55+
condition: service_started
56+
57+
volumes:
58+
open-webui:

docker-compose.yml

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -38,21 +38,6 @@ services:
3838
- STATIC_FILE_LOCATION=/app/static
3939
env_file:
4040
- .env
41-
openwebui:
42-
image: ghcr.io/open-webui/open-webui:main
43-
ports:
44-
- '3000:8080'
45-
environment:
46-
- OPENAI_API_KEYS=FAKE-KEY;
47-
- OPENAI_API_BASE_URLS=http://dialog:8000/openai;
48-
- ENABLE_OPENAI_API=true
49-
volumes:
50-
- open-webui:/app/backend/data
51-
depends_on:
52-
db:
53-
condition: service_healthy
54-
dialog:
55-
condition: service_started
5641

5742
volumes:
5843
open-webui:

src/dialog/routers/openai.py

Lines changed: 49 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,16 @@
55

66
from dialog.db import engine, get_session
77
from dialog_lib.db.models import Chat as ChatEntity, ChatMessages
8-
from dialog.schemas import OpenAIChat, OpenAIChatCompletion, OpenAIModel
8+
from dialog.schemas import (
9+
OpenAIChat, OpenAIChatCompletion, OpenAIModel, OpenAIMessage,
10+
OpenAIStreamChoice, OpenAIStreamSchema
11+
)
912
from dialog.llm import process_user_message
1013

1114
from sqlalchemy.orm import Session
1215

1316
from fastapi import APIRouter, Depends
17+
from fastapi.responses import StreamingResponse
1418

1519
open_ai_api_router = APIRouter()
1620

@@ -49,27 +53,48 @@ async def ask_question_to_llm(message: OpenAIChat, session: Session = Depends(ge
4953

5054
duration = datetime.datetime.now() - start_time
5155
logging.info(f"Request processing time: {duration}")
52-
chat_completion = OpenAIChatCompletion(
53-
choices=[
54-
{
55-
"finish_reason": "stop",
56-
"index": 0,
57-
"message": {
58-
"content": ai_message["text"],
59-
"role": "assistant"
60-
},
61-
"logprobs": None
56+
generated_message = ai_message["text"]
57+
if not message.stream:
58+
chat_completion = OpenAIChatCompletion(
59+
choices=[
60+
{
61+
"finish_reason": "stop",
62+
"index": 0,
63+
"message": OpenAIMessage(**{
64+
"content": generated_message,
65+
"role": "assistant"
66+
}),
67+
"logprobs": None
68+
}
69+
],
70+
created=int(datetime.datetime.now().timestamp()),
71+
id=f"talkdai-{str(uuid4())}",
72+
model="talkd-ai",
73+
object="chat.completion",
74+
usage={
75+
"completion_tokens": None,
76+
"prompt_tokens": None,
77+
"total_tokens": None
6278
}
63-
],
64-
created=int(datetime.datetime.now().timestamp()),
65-
id=f"talkdai-{str(uuid4())}",
66-
model="talkd-ai",
67-
object="chat.completion",
68-
usage={
69-
"completion_tokens": None,
70-
"prompt_tokens": None,
71-
"total_tokens": None
72-
}
73-
)
74-
logging.info(f"Chat completion: {chat_completion}")
75-
return chat_completion
79+
)
80+
logging.info(f"Chat completion: {chat_completion}")
81+
return chat_completion
82+
83+
def gen():
84+
for word in f"{generated_message} +END".split():
85+
# Yield Streaming Response on each word
86+
message_part = OpenAIStreamChoice(
87+
index=0,
88+
delta={
89+
"content": f"{word} "
90+
} if word != "+END" else {}
91+
)
92+
93+
message_stream = OpenAIStreamSchema(
94+
id=f"talkdai-{str(uuid4())}",
95+
choices=[message_part]
96+
)
97+
logging.info(f"data: {message_stream.model_dump_json()}")
98+
yield f"data: {message_stream.model_dump_json()}\n\n"
99+
100+
return StreamingResponse(gen(), media_type='text/event-stream')

src/dialog/schemas/openai_schemas.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from typing import List, Optional
23
from pydantic import BaseModel, ConfigDict
34

@@ -9,6 +10,7 @@ class OpenAIMessage(BaseModel):
910
class OpenAIChat(BaseModel):
1011
model: str
1112
messages: List[OpenAIMessage]
13+
stream: bool = False
1214

1315

1416
class OpenAIChoices(BaseModel):
@@ -28,7 +30,7 @@ class OpenAIChatCompletion(BaseModel):
2830
choices: List[OpenAIChoices]
2931
created: float
3032
id: str
31-
model: str = "talkdai"
33+
model: str = "talkd-ai"
3234
object: str = "chat.completion"
3335
usage: OpenAIUsageDict
3436

@@ -37,4 +39,19 @@ class OpenAIModel(BaseModel):
3739
id: str
3840
object: str
3941
created: int
40-
owned_by: str
42+
owned_by: str
43+
44+
45+
class OpenAIStreamChoice(BaseModel):
46+
index: int
47+
delta: dict
48+
logprobs: Optional[str] = None
49+
finish_reason: Optional[str] = None
50+
51+
class OpenAIStreamSchema(BaseModel):
52+
id: str
53+
object: str = "chat.completion.chunk"
54+
created: int = int(datetime.datetime.now().timestamp())
55+
model: str = "talkd-ai"
56+
system_fingerprint: str = None
57+
choices: List[OpenAIStreamChoice]

src/tests/test_views.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def test_customized_openai_models_response(client):
5858
for i in ["id", "object", "created", "owned_by"]:
5959
assert i in response.json()[0]
6060

61-
def test_customized_openai_chat_completion_response(client, llm_mock_openai_router):
61+
def test_customized_openai_chat_completion_response_stream_false(client, llm_mock_openai_router):
6262
os.environ["LLM_CLASS"] = "dialog.llm.agents.default.DialogLLM"
6363
response = client.post("/openai/chat/completions", json={
6464
"model": "talkdai",
@@ -67,7 +67,8 @@ def test_customized_openai_chat_completion_response(client, llm_mock_openai_rout
6767
"role": "user",
6868
"content": "Hello"
6969
}
70-
]
70+
],
71+
"stream": False
7172
})
7273
assert response.status_code == 200
7374
for i in ["choices", "created", "id", "model", "object", "usage"]:

0 commit comments

Comments (0)