Skip to content

Commit 3eac42a

Browse files
committed
Adds LCEL native support
1 parent b61f1a7 commit 3eac42a

File tree

9 files changed

+110
-131
lines changed

9 files changed

+110
-131
lines changed

‎docs/plugins.md‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,8 @@ PLUGINS=plugins.whats_audio_synth.main,
5151

5252
### Using Langchain's LCEL structure
5353

54-
If you want to use Langchain's LCEL structure, you can use the `DialogLcelLLM` setting the variable to the 'dialog.llm.lcel_default.DialogLcelLLM'.
54+
If you want to use Langchain's LCEL structure, you can use the default LCEL implementation available in the file `dialog/llm/agents/lcel.py`. To use it, you need to add the following environment variable to your .env file:
55+
56+
```bash
57+
LLM_CLASS=dialog.llm.agents.lcel.runnable
58+
```

‎src/dialog/llm/__init__.py‎

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33

44
from typing import Type
55

6-
from .default import DialogLLM
76
from dialog_lib.agents.abstract import AbstractLLM
87
from dialog.settings import Settings
98

9+
from langchain.schema.runnable import RunnablePassthrough
10+
from langchain_core.runnables.base import RunnableSequence
1011

11-
def get_llm_class() -> Type[AbstractLLM]:
12+
from .agents.default import *
13+
from .agents.lcel import *
14+
15+
16+
def get_llm_class():
1217
if Settings().LLM_CLASS is None:
1318
return DialogLLM
1419

@@ -20,5 +25,34 @@ def get_llm_class() -> Type[AbstractLLM]:
2025
llm_class = getattr(llm_module, class_name)
2126
except Exception as e:
2227
logging.info(f"Failed to load LLM class {Settings().LLM_CLASS}. Using default LLM. Exception: {e}")
28+
raise
29+
30+
if isinstance(llm_class, AbstractLLM):
31+
return llm_class, "AbstractLLM"
32+
33+
elif 'langchain_core.runnables' in str(type(llm_class)):
34+
return llm_class, "LCELRunnable"
35+
36+
logging.info(f"Type for LLM class is: {type(llm_class)}")
37+
38+
return DialogLLM, "AbstractLLM"
39+
40+
41+
def process_user_message(message, chat_id=None):
42+
LLM, llm_type = get_llm_class()
43+
if llm_type == "AbstractLLM":
44+
llm_instance = LLM(config=Settings().PROJECT_CONFIG, session_id=chat_id)
45+
ai_message = llm_instance.process(message.message)
46+
47+
elif llm_type == "LCELRunnable":
48+
ai_message = LLM.invoke(
49+
{"input": message.message},
50+
{"configurable": {
51+
"session_id": chat_id,
52+
"model": Settings().PROJECT_CONFIG.get("model_name", "gpt-3.5-turbo"),
53+
**Settings().PROJECT_CONFIG,
54+
}}
55+
)
56+
ai_message = {"text": ai_message.content}
2357

24-
return llm_class or DialogLLM
58+
return ai_message

‎src/dialog/llm/agents/__init__.py‎

Whitespace-only changes.
File renamed without changes.

‎src/dialog/llm/agents/lcel.py‎

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
from dialog.db import get_session
3+
from dialog.settings import Settings
4+
from dialog_lib.memory import generate_memory_instance
5+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
6+
from langchain_openai.chat_models import ChatOpenAI
7+
from langchain_core.runnables import ConfigurableField
8+
from langchain_core.runnables.history import RunnableWithMessageHistory
9+
from langchain.output_parsers.json import SimpleJsonOutputParser
10+
11+
# Here we define the model that we will be using as a base for our agent
12+
# as well as the model_name and temperature (some of these parameterers
13+
# are exclusive to OpenAI instances).
14+
chat_model = ChatOpenAI(
15+
model_name="gpt-3.5-turbo",
16+
temperature=0,
17+
openai_api_key=Settings().OPENAI_API_KEY,
18+
).configurable_fields(
19+
model_name=ConfigurableField(
20+
id="model_name",
21+
name="GPT Model",
22+
description="The GPT model to use"
23+
),
24+
temperature=ConfigurableField(
25+
id="temperature",
26+
name="Temperature",
27+
description="The temperature to use"
28+
)
29+
)
30+
31+
prompt = ChatPromptTemplate.from_messages(
32+
[
33+
(
34+
"system",
35+
Settings().PROJECT_CONFIG.get("prompt").get("prompt", "What can I help you with today?")
36+
),
37+
MessagesPlaceholder(variable_name="chat_history"),
38+
("human", "{input}"),
39+
]
40+
)
41+
42+
def get_memory_instance(session_id):
43+
return generate_memory_instance(
44+
session_id=session_id,
45+
dbsession=next(get_session()),
46+
database_url=Settings().DATABASE_URL
47+
)
48+
49+
chain = prompt | chat_model
50+
51+
runnable = RunnableWithMessageHistory(
52+
chain,
53+
get_memory_instance,
54+
input_messages_key='input',
55+
history_messages_key="chat_history"
56+
)

‎src/dialog/llm/lcel_default.py‎

Lines changed: 0 additions & 113 deletions
This file was deleted.

‎src/dialog/routers/dialog.py‎

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from pydantic import parse_obj_as
55

6-
from dialog.llm import get_llm_class
6+
from dialog.llm import process_user_message
77
from dialog_lib.db.memory import get_messages
88
from dialog_lib.db.models import Chat as ChatEntity
99
from dialog.schemas import ChatModel, SessionModel, SessionsModel
@@ -41,9 +41,7 @@ async def post_message(chat_id: str, message: ChatModel, session: Session = Depe
4141
detail="Chat ID not found",
4242
)
4343
start_time = datetime.datetime.now()
44-
LLM = get_llm_class()
45-
llm_instance = LLM(config=Settings().PROJECT_CONFIG, session_id=chat_id)
46-
ai_message = llm_instance.process(message.message)
44+
ai_message = process_user_message(message, chat_id)
4745
duration = datetime.datetime.now() - start_time
4846
logging.info(f"Request processing time for chat_id {chat_id}: {duration}")
4947
return {"message": ai_message["text"]}
@@ -56,9 +54,7 @@ async def ask_question_to_llm(message: ChatModel, session: Session = Depends(get
5654
using memory.
5755
"""
5856
start_time = datetime.datetime.now()
59-
LLM = get_llm_class()
60-
llm_instance = LLM(config=Settings().PROJECT_CONFIG)
61-
ai_message = llm_instance.process(message.message)
57+
ai_message = process_user_message(message, chat_id=None)
6258
duration = datetime.datetime.now() - start_time
6359
logging.info(f"Request processing time: {duration}")
6460
return {"message": ai_message["text"]}

‎src/tests/conftest.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ def chat_session(dbsession):
4343

4444
@pytest.fixture
4545
def llm_mock(mocker):
46-
llm_mock = mocker.patch('dialog.routers.dialog.get_llm_class')
46+
llm_mock = mocker.patch('dialog.routers.dialog.process_user_message')
4747
llm_mock.process.return_value = {"text": "Hello"}
4848
return llm_mock

‎src/tests/llm_tests/test_abstract_llms.py‎

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from dialog.llm import get_llm_class
55
from dialog_lib.agents.abstract import AbstractLLM
6-
from dialog.llm.default import DialogLLM
6+
from dialog.llm.agents.default import DialogLLM
7+
from dialog.llm import DialogLLM
78

89

910
def test_abstract_llm_for_invalid_config():
@@ -38,10 +39,11 @@ def test_get_llm_class_get_default_class():
3839

3940
def test_get_llm_class_get_custom_class():
4041
os.environ["LLM_CLASS"] = "dialog_lib.agents.abstract.AbstractLLM"
41-
llm_class = get_llm_class()
42-
assert llm_class == AbstractLLM
42+
llm_class, llm_type = get_llm_class()
43+
assert llm_class == DialogLLM
44+
assert llm_type == "AbstractLLM"
4345

4446
def test_get_llm_class_with_invalid_class():
45-
os.environ["LLM_CLASS"] = "dialogl.llm.invalid_llm.InvalidLLM"
46-
llm_class = get_llm_class()
47-
assert llm_class == DialogLLM
47+
os.environ["LLM_CLASS"] = "dialog.llm.invalid_llm.InvalidLLM"
48+
with pytest.raises(ModuleNotFoundError):
49+
llm_class, llm_type = get_llm_class()

0 commit comments

Comments
 (0)