Skip to content

Commit 5611a87

Browse files
committed
Fixes comments from @lgabs
1 parent 7257e2c commit 5611a87

File tree

3 files changed

+39
-58
lines changed

3 files changed

+39
-58
lines changed

dialog_lib/agents/abstract.py

Lines changed: 36 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -114,29 +114,38 @@ def messages(self):
114114

115115

116116
class AbstractLCEL(AbstractLLM):
117+
def __init__(self, *args, **kwargs):
118+
kwargs["config"] = kwargs.get("config", {})
119+
self.memory_instance = kwargs.pop("memory", None)
120+
self.llm_api_key = kwargs
121+
self.prompt_content = kwargs.pop("prompt", None)
122+
self.chat_model = kwargs.pop("model_class")
123+
self.embedding_llm = kwargs.pop("embedding_llm")
124+
self.cosine_similarity_threshold = kwargs.pop("cosine_similarity_threshold", 0.3)
125+
self.top_k = kwargs.pop("top_k", 3)
126+
super().__init__(*args, **kwargs)
117127

118128
@property
119129
def document_prompt(self):
120130
return PromptTemplate.from_template(template="{page_content}")
121131

122132
@property
123-
def model(self):
124-
"""
125-
builds and returns the chat model for the LCEL
126-
"""
127-
raise NotImplementedError("Chat model must be implemented")
133+
def retriever(self):
134+
return DialogRetriever(
135+
session=self.dbsession,
136+
embedding_llm=self.embedding_llm,
137+
threshold=self.cosine_similarity_threshold,
138+
top_k=self.top_k
139+
)
128140

129141
@property
130-
def retriever(self):
131-
"""
132-
builds and returns the retriever for the LCEL
133-
"""
134-
raise NotImplementedError("Retriever must be implemented")
142+
def model(self):
143+
return self.chat_model
135144

136-
def documents_formatter(self, docs, document_separator="\n\n"):
145+
def combine_docs(self, docs, document_separator="\n\n"):
137146
"""
138147
This is the default combine_documents function that returns the documents as is.
139-
We use the default format_documents function from Langchain.
148+
We use the default combine_docs function from Langchain.
140149
"""
141150
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
142151
return document_separator.join(doc_strings)
@@ -161,9 +170,6 @@ def fallback_chain(self):
161170
]
162171
)
163172

164-
def parse_fallback(ai_message):
165-
return ai_message.content
166-
167173
return (
168174
fallback_prompt | RunnableLambda(lambda x: x.messages[-1])
169175
)
@@ -176,7 +182,7 @@ def answer_chain(self):
176182
return (
177183
RunnableParallel(
178184
{
179-
"context": itemgetter("relevant_contents") | RunnableLambda(self.documents_formatter),
185+
"context": itemgetter("relevant_contents") | RunnableLambda(self.combine_docs),
180186
"input": itemgetter("input"),
181187
"chat_history": itemgetter("chat_history"),
182188
}
@@ -253,6 +259,19 @@ def invoke(self, input: dict):
253259
"""
254260
return self.process(input)
255261

262+
def generate_prompt(self, input_text):
263+
self.prompt = ChatPromptTemplate.from_messages(
264+
[
265+
("system", "What can I help you with today?"),
266+
MessagesPlaceholder(variable_name="chat_history"),
267+
("system", "Here is some context for the user request: {context}"),
268+
("human", input_text),
269+
]
270+
)
271+
272+
def postprocess(self, output):
273+
return output.content
274+
256275

257276
class AbstractRAG(AbstractLLM):
258277
relevant_contents = []
@@ -287,7 +306,7 @@ def __init__(self, *args, **kwargs):
287306
warnings.warn(
288307
(
289308
"AbstractDialog will be deprecated in release 0.2 due to the creation of Langchain's LCEL. ",
290-
"Please use AbstractLCELDialog instead."
309+
"Please use AbstractLCEL instead."
291310
), DeprecationWarning, stacklevel=3
292311
)
293312
kwargs["config"] = kwargs.get("config", {})
@@ -335,41 +354,3 @@ def llm(self):
335354
**chain_settings
336355
)
337356

338-
339-
class AbstractLCELDialog(AbstractLCEL):
340-
def __init__(self, *args, **kwargs):
341-
kwargs["config"] = kwargs.get("config", {})
342-
343-
self.memory_instance = kwargs.pop("memory", None)
344-
self.llm_api_key = kwargs
345-
self.prompt_content = kwargs.pop("prompt", None)
346-
self.chat_model = kwargs.pop("model_class")
347-
self.embedding_llm = kwargs.pop("embedding_llm")
348-
super().__init__(*args, **kwargs)
349-
350-
@property
351-
def retriever(self):
352-
return DialogRetriever(
353-
session=self.dbsession,
354-
embedding_llm=self.embedding_llm,
355-
threshold=0.3,
356-
top_k=3
357-
)
358-
359-
@property
360-
def model(self):
361-
return self.chat_model
362-
363-
def generate_prompt(self, input_text):
364-
self.prompt = ChatPromptTemplate.from_messages(
365-
[
366-
("system", "What can I help you with today?"),
367-
MessagesPlaceholder(variable_name="chat_history"),
368-
("system", "Here is some context for the user request: {context}"),
369-
("human", "{input}"),
370-
]
371-
)
372-
return input_text
373-
374-
def postprocess(self, output):
375-
return output.content

dialog_lib/agents/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from .abstract import AbstractDialog, AbstractLCELDialog
2+
from .abstract import AbstractDialog, AbstractLCEL
33
from langchain_openai import OpenAIEmbeddings
44
from langchain_openai.chat_models.base import ChatOpenAI
55
from dialog_lib.embeddings.retrievers import DialogRetriever
@@ -20,7 +20,7 @@ def postprocess(self, output):
2020
return output.get("text")
2121

2222

23-
class DialogLCELOpenAI(AbstractLCELDialog):
23+
class DialogLCELOpenAI(AbstractLCEL):
2424
def __init__(self, *args, **kwargs):
2525
self.openai_api_key = kwargs.get("llm_api_key") or os.environ.get("OPENAI_API_KEY")
2626
kwargs["model_class"] = ChatOpenAI(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "dialog-lib"
3-
version = "0.0.2.0.alpha"
3+
version = "0.0.2.0.alpha1"
44
description = ""
55
authors = ["Talkd.AI <foss@talkd.ai>"]
66
license = "MIT"

0 commit comments

Comments (0)