1- from langchain .prompts import ChatPromptTemplate , MessagesPlaceholder
1+ import warnings
2+ from operator import itemgetter
3+
4+ from langchain .schema import format_document
25from langchain .memory import ConversationBufferWindowMemory
6+ from langchain .prompts import ChatPromptTemplate , MessagesPlaceholder
37from langchain .chains .llm import LLMChain
8+ from langchain .prompts .prompt import PromptTemplate
49from langchain .memory .chat_memory import BaseChatMemory
5- from langchain . chains . conversation . memory import ConversationBufferMemory
10+ from langchain_core . runnables import RunnablePassthrough
611from langchain_core .runnables .history import RunnableWithMessageHistory
12+ from langchain .chains .conversation .memory import ConversationBufferMemory
13+
714from dialog_lib .db .memory import CustomPostgresChatMessageHistory , get_memory_instance
15+ from dialog_lib .embeddings .retrievers import DialogRetriever
816
917
1018class AbstractLLM :
@@ -104,45 +112,81 @@ def messages(self):
104112 return self .memory .messages
105113
106114
107- class AbstractLcelClass (AbstractLLM ):
115+ class AbstractLCEL (AbstractLLM ):
108116
117+ @property
118+ def document_prompt (self ):
119+ return PromptTemplate .from_template (template = "{page_content}" )
109120
110121 @property
111- def chat_model (self ):
122+ def model (self ):
112123 """
113124 builds and returns the chat model for the LCEL
114125 """
115126 raise NotImplementedError ("Chat model must be implemented" )
116127
128+ @property
129+ def retriever (self ):
130+ """
131+ builds and returns the retriever for the LCEL
132+ """
133+ raise NotImplementedError ("Retriever must be implemented" )
134+
135+ def documents_formatter (self , docs , document_separator = "\n \n " ):
136+ """
137+ This is the default combine_documents function that returns the documents as is.
138+ We use the default format_documents function from Langchain.
139+ """
140+ doc_strings = [format_document (doc , self .document_prompt ) for doc in docs ]
141+ return document_separator .join (doc_strings )
142+
117143 @property
118144 def context_dict (self ):
119145 """
120146 builds and returns the context dictionary for the LCEL
121147 """
122- raise NotImplementedError ("Context dictionary must be implemented" )
148+
149+ context_dict = {
150+ "input" : RunnablePassthrough (),
151+ "chat_history" : itemgetter ("chat_history" ),
152+ }
153+
154+ if self .retriever :
155+ context_dict ["context" ] = itemgetter ("input" ) | self .retriever | self .documents_formatter
156+
157+ return context_dict
123158
124159 @property
125160 def chain (self ):
126161 """
127162 builds and returns the chain for the LCEL
128163 """
129164 return (
130- self .context_dict | self .prompt | self .chat_model
165+ self .context_dict | self .prompt | self .model
131166 )
132167
133168 @property
134- def get_memory_instance (self ):
169+ def memory (self ):
135170 return get_memory_instance (
136171 session_id = self .session_id ,
137172 sqlalchemy_session = self .dbsession ,
138173 database_url = self .config .get ("database_url" )
139174 )
140175
176+ def get_session_history (self , something ):
177+ return CustomPostgresChatMessageHistory (
178+ connection_string = self .config .get ("database_url" ),
179+ session_id = self .session_id ,
180+ parent_session_id = self .parent_session_id ,
181+ table_name = "chat_messages" ,
182+ dbsession = self .dbsession ,
183+ )
184+
141185 @property
142186 def runnable (self ):
143- RunnableWithMessageHistory (
187+ return RunnableWithMessageHistory (
144188 self .chain ,
145- self .get_memory_instance ,
189+ self .get_session_history ,
146190 input_messages_key = 'input' ,
147191 history_messages_key = "chat_history"
148192 )
@@ -201,6 +245,13 @@ def process(self, input: str):
201245
202246class AbstractDialog (AbstractLLM ):
203247 def __init__ (self , * args , ** kwargs ):
248+ warnings .filterwarnings ("default" , category = DeprecationWarning )
249+ warnings .warn (
250+ (
251+ "AbstractDialog will be deprecated in release 0.2 due to the creation of Langchain's LCEL. " ,
252+ "Please use AbstractLCELDialog instead."
253+ ), DeprecationWarning , stacklevel = 3
254+ )
204255 kwargs ["config" ] = kwargs .get ("config" , {})
205256
206257 self .memory_instance = kwargs .pop ("memory" , None )
@@ -244,4 +295,41 @@ def llm(self):
244295 )
245296 return LLMChain (
246297 ** chain_settings
247- )
298+ )
299+
300+
301+ class AbstractLCELDialog (AbstractLCEL ):
302+ def __init__ (self , * args , ** kwargs ):
303+ kwargs ["config" ] = kwargs .get ("config" , {})
304+
305+ self .memory_instance = kwargs .pop ("memory" , None )
306+ self .llm_api_key = kwargs
307+ self .prompt_content = kwargs .pop ("prompt" , None )
308+ self .chat_model = kwargs .pop ("model_class" )
309+ self .embedding_llm = kwargs .pop ("embedding_llm" )
310+ super ().__init__ (* args , ** kwargs )
311+
312+ @property
313+ def retriever (self ):
314+ return DialogRetriever (
315+ session = self .dbsession ,
316+ embedding_llm = self .embedding_llm ,
317+ )
318+
319+ @property
320+ def model (self ):
321+ return self .chat_model
322+
323+ def generate_prompt (self , input_text ):
324+ self .prompt = ChatPromptTemplate .from_messages (
325+ [
326+ ("system" , "What can I help you with today?" ),
327+ MessagesPlaceholder (variable_name = "chat_history" ),
328+ ("system" , "Here is some context for the user request: {context}" ),
329+ ("human" , "{input}" ),
330+ ]
331+ )
332+ return input_text
333+
334+ def postprocess (self , output ):
335+ return output .content
0 commit comments