Skip to content

Commit 15bad73

Browse files
committed
Adds an LCEL class and a sample class for OpenAI
1 parent 6b5b431 commit 15bad73

File tree

5 files changed

+149
-15
lines changed

5 files changed

+149
-15
lines changed

‎dialog_lib/agents/abstract.py‎

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
1+
import warnings
2+
from operator import itemgetter
3+
4+
from langchain.schema import format_document
25
from langchain.memory import ConversationBufferWindowMemory
6+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
37
from langchain.chains.llm import LLMChain
8+
from langchain.prompts.prompt import PromptTemplate
49
from langchain.memory.chat_memory import BaseChatMemory
5-
from langchain.chains.conversation.memory import ConversationBufferMemory
10+
from langchain_core.runnables import RunnablePassthrough
611
from langchain_core.runnables.history import RunnableWithMessageHistory
12+
from langchain.chains.conversation.memory import ConversationBufferMemory
13+
714
from dialog_lib.db.memory import CustomPostgresChatMessageHistory, get_memory_instance
15+
from dialog_lib.embeddings.retrievers import DialogRetriever
816

917

1018
class AbstractLLM:
@@ -104,45 +112,81 @@ def messages(self):
104112
return self.memory.messages
105113

106114

107-
class AbstractLCEL(AbstractLLM):
    """
    Base class for agents built with LangChain Expression Language (LCEL).

    Subclasses provide the chat ``model`` and a ``retriever``; this class
    assembles them, together with the prompt and the Postgres-backed chat
    history, into a runnable chain.
    """

    @property
    def document_prompt(self):
        # Per-document prompt used when formatting retrieved documents:
        # each document is rendered as its raw page content only.
        return PromptTemplate.from_template(template="{page_content}")

    @property
    def model(self):
        """
        builds and returns the chat model for the LCEL
        """
        raise NotImplementedError("Chat model must be implemented")

    @property
    def retriever(self):
        """
        builds and returns the retriever for the LCEL
        """
        raise NotImplementedError("Retriever must be implemented")

    def documents_formatter(self, docs, document_separator="\n\n"):
        """
        This is the default combine_documents function that returns the documents as is.
        We use the default format_documents function from Langchain.
        """
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    @property
    def context_dict(self):
        """
        builds and returns the context dictionary for the LCEL
        """
        context_dict = {
            "input": RunnablePassthrough(),
            "chat_history": itemgetter("chat_history"),
        }

        # NOTE(review): on this base class ``self.retriever`` raises
        # NotImplementedError, so the truthiness check below only works for
        # subclasses that override the property — confirm this is intended.
        if self.retriever:
            context_dict["context"] = itemgetter("input") | self.retriever | self.documents_formatter

        return context_dict

    @property
    def chain(self):
        """
        builds and returns the chain for the LCEL
        """
        return (
            self.context_dict | self.prompt | self.model
        )

    @property
    def memory(self):
        # Postgres-backed conversation memory scoped to this session.
        return get_memory_instance(
            session_id=self.session_id,
            sqlalchemy_session=self.dbsession,
            database_url=self.config.get("database_url")
        )

    def get_session_history(self, something):
        # History factory handed to RunnableWithMessageHistory. The
        # positional argument (presumably the session id supplied by the
        # runnable) is ignored; the history is keyed by ``self.session_id``
        # instead — TODO confirm this is intended.
        return CustomPostgresChatMessageHistory(
            connection_string=self.config.get("database_url"),
            session_id=self.session_id,
            parent_session_id=self.parent_session_id,
            table_name="chat_messages",
            dbsession=self.dbsession,
        )

    @property
    def runnable(self):
        # Wraps the LCEL chain so chat history is injected under
        # "chat_history" and persisted through get_session_history.
        return RunnableWithMessageHistory(
            self.chain,
            self.get_session_history,
            input_messages_key='input',
            history_messages_key="chat_history"
        )
@@ -201,6 +245,13 @@ def process(self, input: str):
201245

202246
class AbstractDialog(AbstractLLM):
203247
def __init__(self, *args, **kwargs):
248+
warnings.filterwarnings("default", category=DeprecationWarning)
249+
warnings.warn(
250+
(
251+
"AbstractDialog will be deprecated in release 0.2 due to the creation of Langchain's LCEL. ",
252+
"Please use AbstractLCELDialog instead."
253+
), DeprecationWarning, stacklevel=3
254+
)
204255
kwargs["config"] = kwargs.get("config", {})
205256

206257
self.memory_instance = kwargs.pop("memory", None)
@@ -244,4 +295,41 @@ def llm(self):
244295
)
245296
return LLMChain(
246297
**chain_settings
247-
)
298+
)
299+
300+
301+
class AbstractLCELDialog(AbstractLCEL):
    """
    LCEL agent wired to the Dialog retriever.

    Expects ``model_class`` (an instantiated chat model) and
    ``embedding_llm`` in ``kwargs``; ``memory``, ``prompt``, ``llm_api_key``
    and ``config`` are optional and are also consumed here before the rest
    of ``kwargs`` is forwarded to the parent constructor.
    """

    def __init__(self, *args, **kwargs):
        kwargs["config"] = kwargs.get("config", {})

        self.memory_instance = kwargs.pop("memory", None)
        # BUGFIX: this previously assigned the entire kwargs dict
        # (``self.llm_api_key = kwargs``); store the actual API key instead,
        # mirroring how DialogLCELOpenAI reads it.
        self.llm_api_key = kwargs.get("llm_api_key")
        self.prompt_content = kwargs.pop("prompt", None)
        self.chat_model = kwargs.pop("model_class")
        self.embedding_llm = kwargs.pop("embedding_llm")
        super().__init__(*args, **kwargs)

    @property
    def retriever(self):
        # Vector-store retriever backed by the shared SQLAlchemy session.
        return DialogRetriever(
            session=self.dbsession,
            embedding_llm=self.embedding_llm,
        )

    @property
    def model(self):
        # The chat model instance supplied via the ``model_class`` kwarg.
        return self.chat_model

    def generate_prompt(self, input_text):
        """
        Build the chat prompt (system message + history + retrieved context
        + user input), store it on ``self.prompt``, and return ``input_text``
        unchanged.
        """
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "What can I help you with today?"),
                MessagesPlaceholder(variable_name="chat_history"),
                ("system", "Here is some context for the user request: {context}"),
                ("human", "{input}"),
            ]
        )
        return input_text

    def postprocess(self, output):
        # LCEL chat models return a message object; expose only its text.
        return output.content

‎dialog_lib/agents/openai.py‎

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from .abstract import AbstractDialog
1+
from .abstract import AbstractDialog, AbstractLCELDialog
2+
from langchain_openai import OpenAIEmbeddings
23
from langchain_openai.chat_models.base import ChatOpenAI
4+
from dialog_lib.embeddings.retrievers import DialogRetriever
35

46

57
class DialogOpenAI(AbstractDialog):
@@ -14,4 +16,16 @@ def __init__(self, *args, **kwargs):
1416
super().__init__(*args, **kwargs)
1517

1618
def postprocess(self, output):
    """Pull the final text answer out of the LLMChain result mapping."""
    text = output.get("text")
    return text
20+
21+
22+
class DialogLCELOpenAI(AbstractLCELDialog):
    """
    LCEL dialog agent backed by OpenAI chat and embedding models.

    Requires ``model`` and ``temperature`` kwargs. The API key is taken
    from the ``llm_api_key`` kwarg, falling back to the OPENAI_API_KEY
    environment variable.
    """

    def __init__(self, *args, **kwargs):
        # BUGFIX: ``os`` was referenced without being imported anywhere in
        # this module; import it locally so the env-var fallback works.
        import os

        self.openai_api_key = kwargs.get("llm_api_key") or os.environ.get("OPENAI_API_KEY")
        kwargs["model_class"] = ChatOpenAI(
            model=kwargs.pop("model"),
            temperature=kwargs.pop("temperature"),
            openai_api_key=self.openai_api_key,
        )
        kwargs["embedding_llm"] = OpenAIEmbeddings(openai_api_key=self.openai_api_key)
        super().__init__(*args, **kwargs)

‎poetry.lock‎

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "dialog-lib"
3-
version = "0.0.1.19"
3+
version = "0.0.1.20"
44
description = ""
55
authors = ["Talkd.AI <foss@talkd.ai>"]
66
license = "MIT"
@@ -14,7 +14,7 @@ click = "^8.1.7"
1414
pgvector = "^0.2.5"
1515
langchain-openai = "^0.1.8"
1616
psycopg2-binary = "^2.9.9"
17-
langchain-postgres = "^0.0.7"
17+
langchain-postgres = "0.0.7"
1818
langchain-community = "^0.2.5"
1919
langchain-anthropic = "^0.1.11"
2020
bs4 = "^0.0.2"

‎samples/openai/lcel/main.py‎

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
import logging
3+
from uuid import uuid4
4+
from sqlalchemy import create_engine
5+
from sqlalchemy.orm import Session
6+
from dialog_lib.agents.openai import DialogLCELOpenAI
7+
8+
logging.getLogger().setLevel(logging.ERROR)
9+
10+
database_url = "postgresql://talkdai:talkdai@db:5432/test_talkdai"
11+
12+
engine = create_engine(database_url)
13+
14+
dbsession = Session(engine)
15+
16+
17+
agent = DialogLCELOpenAI(
18+
model="gpt-4o",
19+
temperature=0.1,
20+
llm_api_key=os.environ.get("OPENAI_API_KEY"),
21+
prompt="You are a bot called Sara. Be nice to other human beings.",
22+
dbsession=dbsession,
23+
config={
24+
"database_url": database_url,
25+
},
26+
session_id=str(uuid4())
27+
)
28+
29+
while True:
30+
input_text = input("You: ")
31+
output_text = agent.process(input_text)
32+
print(f"Sara: {output_text}")

0 commit comments

Comments
 (0)