Skip to content
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ VERBOSE_LLM=True
DIALOG_DATA_PATH=./know.csv
PROJECT_CONFIG=./prompt.toml
DATABASE_URL=postgresql://talkdai:talkdai@db:5432/talkdai
DEBUG=false
7 changes: 6 additions & 1 deletion etc/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@

alembic upgrade head
python load_csv.py --path ${DIALOG_DATA_PATH}
uvicorn main:app --host 0.0.0.0 --port ${PORT}

if "${DEBUG}"; then
uvicorn main:app --host 0.0.0.0 --port ${PORT} --reload
else
uvicorn main:app --host 0.0.0.0 --port ${PORT}
fi
5 changes: 2 additions & 3 deletions src/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def get_most_relevant_contents_from_message(message, top=5):
message_embedding = generate_embedding(message)
possible_contents = session.scalars(
select(CompanyContent)
.filter(CompanyContent.embedding.l2_distance(message_embedding) < 5)
.filter(CompanyContent.embedding.l2_distance(message_embedding) < 1)
.order_by(CompanyContent.embedding.l2_distance(message_embedding).asc())
.limit(top)
).all()
return possible_contents


def process_user_intent(session_id, message):
async def process_user_intent(session_id, message):
"""
Process user intent using memory and embeddings
"""
Expand Down Expand Up @@ -126,5 +126,4 @@ def process_user_intent(session_id, message):

# categorize conversation history in background
asyncio.create_task(categorize_conversation_history(chat_memory))

return ai_message
20 changes: 12 additions & 8 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
# *-* coding: utf-8 *-*
import datetime
import logging
import uuid

from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware

from llm import process_user_intent
from llm.memory import get_messages

from models import Chat as ChatEntity
from models.db import engine, session

from pydantic import BaseModel
from sqlalchemy import text

import datetime
import logging
from settings import LOGGING_LEVEL
from models.helpers import create_session as db_create_session
from webhooks.router import router


logging.basicConfig(
level=LOGGING_LEVEL,
format="%(asctime)s - %(levelname)s - %(message)s"
Expand All @@ -36,6 +43,7 @@
allow_headers=["*"],
)

app.include_router(router, prefix="/webhooks", tags=["webhooks"])

class Chat(BaseModel):
message: str
Expand All @@ -61,7 +69,7 @@ async def post_message(chat_id: str, message: Chat):
detail="Chat ID not found",
)
start_time = datetime.datetime.now()
ai_message = process_user_intent(chat_id, message.message)
ai_message = await process_user_intent(chat_id, message.message)
duration = datetime.datetime.now() - start_time
logging.info(f"Request processing time for chat_id {chat_id}: {duration}")
return {"message": ai_message["text"]}
Expand All @@ -83,8 +91,4 @@ async def get_chat_content(chat_id):

@app.post("/session")
async def create_session():
session_uuid = uuid.uuid4().hex
chat = ChatEntity(uuid=session_uuid)
session.add(chat)
session.commit()
return {"chat_id": session_uuid}
return db_create_session()
30 changes: 30 additions & 0 deletions src/models/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import uuid

from sqlalchemy import select
from sqlalchemy.exc import NoResultFound

from models import Chat as ChatEntity
from models.db import session
from psycopg2.errors import UniqueViolation


def create_session(identifier = None):
if identifier is None:
session_uuid = uuid.uuid4().hex
else:
session_uuid = identifier

try:
instance = session.query(ChatEntity).filter_by(uuid=session_uuid).one()
except NoResultFound:
instance = None
except UniqueViolation:
return {"chat_id": session_uuid}

if instance is not None:
return {"chat_id": instance.uuid}

chat = ChatEntity(uuid=session_uuid)
session.add(chat)
session.commit()
return {"chat_id": session_uuid}
7 changes: 5 additions & 2 deletions src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
default={},
)

LLM_CONFIG = {"temperature": 0.2, "model_name": "gpt-3.5-turbo"}
LLM_CONFIG = {"temperature": 0.2, "model_name": "gpt-3.5-turbo"}
LLM_CONFIG.update(PROJECT_CONFIG.get("llm", {}))
MODEL_NAME = LLM_CONFIG.get("model_name")

PROMPT = PROJECT_CONFIG.get("prompt", {})

WHATSAPP_VERIFY_TOKEN = config("WHATSAPP_VERIFY_TOKEN", "1234567890")
WHATSAPP_API_TOKEN = config("WHATSAPP_API_TOKEN", "1234567890")
WHATSAPP_ACCOUNT_NUMBER = config("WHATSAPP_ACCOUNT_NUMBER", "")
Empty file added src/webhooks/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions src/webhooks/responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import requests
import logging

from fastapi import HTTPException

from llm import process_user_intent
from settings import WHATSAPP_VERIFY_TOKEN, WHATSAPP_API_TOKEN, WHATSAPP_ACCOUNT_NUMBER

from models.helpers import create_session

from webhooks.serializers import *

logger = logging.getLogger(__name__)


async def whatsapp_get_response(request):
"""
Returns the challenge response for WhatsApp if verify token matches
the one available in settings, else returns None
"""
if request.query_params.get("hub.verify_token") == WHATSAPP_VERIFY_TOKEN:
return int(request.query_params.get("hub.challenge"))

raise HTTPException(status_code=404)


async def whatsapp_post_response(request, body):
value = body["entry"][0]["changes"][0]["value"]
try:
message = value["messages"][0]["text"]["body"]
except KeyError:
raise HTTPException(status_code=200)

phone_number_id = value["metadata"]["phone_number_id"]
from_number = value["messages"][0]["from"]
logger.info(value)
headers = {
"Authorization": f"Bearer {WHATSAPP_API_TOKEN}",
"Content-Type": "application/json",
}
url = f"https://graph.facebook.com/v17.0/{WHATSAPP_ACCOUNT_NUMBER}/messages"

create_session(identifier=from_number)

processed_message = await process_user_intent(from_number, message)
processed_message = processed_message["text"]
logger.info("Processed message: %s", processed_message)

data = {
"messaging_product": "whatsapp",
"to": from_number,
"type": "text",
"text": {"body": processed_message},
}
response = requests.post(url, json=data, headers=headers)
if response.status_code not in [200, 201]:
logger.info(f"Failed request: {response.text}")

response.raise_for_status()
return {"status": "success"}
46 changes: 46 additions & 0 deletions src/webhooks/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import logging

from typing import Any
from webhooks.responses import *
from urllib.parse import urlparse

from fastapi import APIRouter, Request, Body
from fastapi import HTTPException

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

POSSIBLE_ORIGINS = {
"facebook": whatsapp_serializer,
}

GET_RESPONSES = {
"facebook": whatsapp_get_response,
}

POST_RESPONSES = {
"facebook": whatsapp_post_response
}

@router.get("/{origin}")
async def webhook_get(origin: str, request: Request):
is_existing_origin = origin in POSSIBLE_ORIGINS.keys()
if not is_existing_origin:
raise HTTPException(status_code=404)

response_function = GET_RESPONSES.get(origin, None)

content = await response_function(request)
return content


@router.post("/{origin}")
async def webhook_post(origin: str, request: Request, payload: Any = Body(None)):
serializer = POST_RESPONSES.get(origin, None)
if serializer is None:
raise HTTPException(status_code=404)

content = await serializer(request, payload)
return content
4 changes: 4 additions & 0 deletions src/webhooks/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@


def whatsapp_serializer(request):
pass