Skip to content

Commit bd59cdc

Browse files
committed
Adds session injection function for better session management in stand-alone library
1 parent eb9e41b commit bd59cdc

File tree

11 files changed

+29
-15
lines changed

11 files changed

+29
-15
lines changed

‎.gitignore‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,5 @@ poetry.toml
174174
pyrightconfig.json
175175

176176
# End of https://www.toptal.com/developers/gitignore/api/python
177-
n
177+
n
178+
.envrc

‎dialog_lib/agents/abstract.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from langchain_core.runnables.history import RunnableWithMessageHistory
1313
from langchain.chains.conversation.memory import ConversationBufferMemory
1414

15+
from dialog_lib.db import get_session
1516
from dialog_lib.db.memory import CustomPostgresChatMessageHistory, get_memory_instance
1617
from dialog_lib.embeddings.retrievers import DialogRetriever
1718

@@ -24,7 +25,7 @@ def __init__(
2425
parent_session_id=None,
2526
dataset=None,
2627
llm_api_key=None,
27-
dbsession=None,
28+
dbsession=get_session(),
2829
):
2930
"""
3031
:param config: Configuration dictionary

‎dialog_lib/db/__init__.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
add_user_message_to_message_history,
55
get_messages,
66
)
7+
from .session import get_session

‎dialog_lib/db/memory.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import psycopg
2-
2+
from .session import get_session
33
from langchain_postgres import PostgresChatMessageHistory
44
from langchain.schema.messages import BaseMessage, _message_to_dict
55

@@ -15,7 +15,7 @@ def __init__(
1515
self,
1616
*args,
1717
parent_session_id=None,
18-
dbsession=None,
18+
dbsession=get_session(),
1919
chats_model=Chat,
2020
chat_messages_model=ChatMessages,
2121
ssl_mode=None,
@@ -67,7 +67,7 @@ def add_message(self, message: BaseMessage) -> None:
6767
def generate_memory_instance(
6868
session_id,
6969
parent_session_id=None,
70-
dbsession=None,
70+
dbsession=get_session(),
7171
database_url=None,
7272
chats_model=Chat,
7373
chat_messages_model=ChatMessages,
@@ -88,7 +88,7 @@ def generate_memory_instance(
8888

8989

9090
def add_user_message_to_message_history(
91-
session_id, message, memory=None, dbsession=None, database_url=None
91+
session_id, message, memory=None, dbsession=get_session(), database_url=None
9292
):
9393
"""
9494
Add a user message to the message history and returns the updated
@@ -103,7 +103,7 @@ def add_user_message_to_message_history(
103103
return memory
104104

105105

106-
def get_messages(session_id, dbsession=None, database_url=None):
106+
def get_messages(session_id, dbsession=get_session(), database_url=None):
107107
"""
108108
Get all messages for a given session_id
109109
"""

‎dialog_lib/db/session.py‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import os
2+
3+
import sqlalchemy as sa
4+
from sqlalchemy.orm import Session
5+
6+
def get_session():
7+
engine = sa.create_engine(os.environ.get("DATABASE_URL"))
8+
session = Session(engine)
9+
yield session
10+
session.close()

‎dialog_lib/db/utils.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import uuid
2-
2+
from .session import get_session
33
from .models import Chat
44

55

6-
def create_chat_session(identifier=None, dbsession=None, model=Chat):
6+
def create_chat_session(identifier=None, dbsession=get_session(), model=Chat):
77
if identifier is None:
88
identifier = uuid.uuid4().hex
99

‎dialog_lib/loaders/csv.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dialog_lib.db import get_session
12
from dialog_lib.db.models import CompanyContent
23
from dialog_lib.embeddings.generate import generate_embedding
34

@@ -6,9 +7,10 @@
67

78

89
def load_csv(
9-
file_path, dbsession, embeddings_model_instance=None,
10+
file_path, dbsession=get_session(), embeddings_model_instance=None,
1011
embedding_llm_model=None, embedding_llm_api_key=None, company_id=None
1112
):
13+
1214
loader = CSVLoader(file_path=file_path)
1315
contents = loader.load()
1416

‎dialog_lib/loaders/web.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from dialog_lib.db.models import CompanyContent
22
from dialog_lib.embeddings.generate import generate_embedding
3-
3+
from dialog_lib.db import get_session
44
from langchain_community.document_loaders import WebBaseLoader
55

66

7-
def load_webpage(url, dbsession, embeddings_model_instance, company_id=None):
7+
def load_webpage(url, embeddings_model_instance, dbsession=get_session(), company_id=None):
88
loader = WebBaseLoader(url)
99
contents = loader.load()
1010

‎dialog_lib/manage.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def anthropic(model, temperature, llm_api_key, prompt, debug):
7676
@click.option("--llm-api-key", default=get_llm_key(), help="The LLM API key", required=True)
7777
@click.option("--file", help="The CSV file to load the data from", required=True)
7878
def load_csv(database_url, llm_api_key, file):
79-
breakpoint()
8079
engine = create_engine(database_url)
8180
dbsession = Session(engine.connect())
8281
csv_loader(
@@ -86,6 +85,7 @@ def load_csv(database_url, llm_api_key, file):
8685
embedding_llm_api_key=llm_api_key
8786
)
8887
click.echo("## Loaded the CSV file to the database")
88+
dbsession.close()
8989

9090

9191
def main():

‎dialog_lib/tests/agents/test_abstract_agents.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def test_abstract_agent_with_valid_config():
2020
assert agent.dataset is None
2121
assert agent.llm_api_key is None
2222
assert agent.parent_session_id is None
23-
assert agent.dbsession is None
2423

2524
def test_abstract_agent_get_prompt():
2625
config = {

0 commit comments

Comments
 (0)