Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
homepage = "https://github.com/talkdai/dialog"

[tool.poetry.dependencies]
python = ">=3.11,<3.12"
python = ">=3.11,<3.13"
fastapi = "0.109.2"
sqlalchemy = "^2.0.23"
langchain = "0.1.7"
Expand All @@ -25,6 +25,7 @@ scikit-learn = "^1.3.2"
alembic = "^1.12.1"
langchain-community = "^0.0.20"
importlib-metadata = "^7.0.1"
langchain-openai = "^0.0.6"

[tool.poetry.group.dev.dependencies]
ipdb = "^0.13.13"
Expand Down
6 changes: 4 additions & 2 deletions src/dialog/llm/abstract_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langchain.prompts import ChatPromptTemplate

class AbstractLLM:
def __init__(self, config, session_id=None, parent_session_id=None):
def __init__(self, config, session_id=None, parent_session_id=None, dataset=None, llm_key=None):
"""
:param config: Configuration dictionary

Expand All @@ -19,7 +19,9 @@ def __init__(self, config, session_id=None, parent_session_id=None):

self.config = config
self.prompt = None
self.session_id = session_id
self.session_id = session_id if dataset is None else f"{dataset}_{session_id}"
self.dataset = dataset
self.llm_key = llm_key
self.parent_session_id = parent_session_id

def get_prompt(self, input) -> ChatPromptTemplate:
Expand Down
4 changes: 2 additions & 2 deletions src/dialog/llm/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def memory(self) -> BaseChatMemory:
return None

def generate_prompt(self, input):
relevant_contents = get_most_relevant_contents_from_message(input, top=1)
relevant_contents = get_most_relevant_contents_from_message(input, top=1, dataset=self.dataset)

if len(relevant_contents) == 0:
prompt_templating = [
Expand Down Expand Up @@ -65,7 +65,7 @@ def llm(self) -> LLMChain:
conversation_options ={
"llm": ChatOpenAI(
**llm_config,
openai_api_key=OPENAI_API_KEY
openai_api_key=self.llm_key or OPENAI_API_KEY
),
"prompt": self.prompt,
"verbose": self.config.get("verbose", False)
Expand Down
2 changes: 1 addition & 1 deletion src/dialog/llm/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from langchain_community.embeddings import OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings
from sqlalchemy import select

from dialog.models import CompanyContent
Expand Down
1 change: 1 addition & 0 deletions src/dialog/migrations/versions/b3ca30115351_.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
op.create_table(
"contents",
sa.Column("id", sa.Integer(), nullable=False),
Expand Down