Skip to content

Commit f6147b1

Browse files
committed
Makes CustomPostgresConnection async-able
1 parent d6aa495 commit f6147b1

File tree

7 files changed

+1181
-922
lines changed

7 files changed

+1181
-922
lines changed

‎Dockerfile‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ COPY poetry.lock pyproject.toml README.md /app/
1919
COPY pytest.ini /app/dialog_lib/
2020

2121
USER root
22-
RUN apt update -y && apt upgrade -y && apt install gcc libpq-dev -y
22+
RUN apt update -y && apt upgrade -y && apt install gcc libpq-dev postgresql-client -y
2323
RUN pip install -U pip poetry
2424

2525
COPY /etc /app/etc

‎Makefile‎

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
.PHONY: test
1+
.PHONY: test bump-beta bump-major bump-minor
22

33
test:
4-
poetry run pytest --cov=dialog_lib dialog_lib/tests/
4+
poetry run pytest --cov=dialog_lib dialog_lib/tests/
5+
6+
bump-prepatch:
7+
poetry version --next-phase prepatch
8+
9+
bump-preminor:
10+
poetry version --next-phase preminor
11+
12+
bump-premajor:
13+
poetry version --next-phase premajor

‎dialog_lib/db/memory.py‎

Lines changed: 124 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import psycopg
2-
from .session import get_session
3-
from langchain_postgres import PostgresChatMessageHistory
4-
from langchain.schema.messages import BaseMessage, _message_to_dict
2+
3+
from typing import List
4+
from psycopg import sql
55

66
from .models import Chat, ChatMessages
7+
from .session import get_session, get_async_session, get_async_psycopg_connection
8+
9+
from langchain_postgres import PostgresChatMessageHistory
10+
from langchain.schema.messages import BaseMessage, _message_to_dict
711

812

913
class CustomPostgresChatMessageHistory(PostgresChatMessageHistory):
@@ -16,45 +20,127 @@ def __init__(
1620
*args,
1721
parent_session_id=None,
1822
dbsession=get_session,
23+
async_dbsession=get_async_session,
1924
chats_model=Chat,
2025
chat_messages_model=ChatMessages,
2126
ssl_mode=None,
2227
**kwargs,
2328
):
2429
self.parent_session_id = parent_session_id
2530
self.dbsession = dbsession
31+
self.async_dbsession = async_dbsession
2632
self.chats_model = chats_model
2733
self.chat_messages_model = chat_messages_model
2834
self._connection = psycopg.connect(
2935
kwargs.pop("connection_string"), sslmode=ssl_mode
3036
)
37+
self._async_connection = None # Will be initialized when needed
3138
self._session_id = kwargs.pop("session_id")
32-
self._table_name = kwargs.pop("table_name")
33-
39+
self._table_name = kwargs.pop("table_name", chat_messages_model.__tablename__)
40+
41+
self.cursor = self._connection.cursor()
42+
43+
async def _initialize_async_connection(self):
44+
if self._async_connection is None:
45+
self._async_connection = await get_async_psycopg_connection()
46+
return self._async_connection
47+
48+
def _create_tables_queries(self, table_name):
49+
index_name = f"idx_{table_name}_session_id"
50+
return [
51+
sql.SQL(
52+
"""
53+
CREATE TABLE IF NOT EXISTS {table_name} (
54+
id SERIAL PRIMARY KEY,
55+
session_id TEXT NOT NULL,
56+
message JSONB NOT NULL,
57+
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
58+
);"""
59+
).format(table_name=sql.Identifier(table_name)),
60+
sql.SQL(
61+
"""
62+
CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
63+
"""
64+
).format(
65+
index_name=sql.Identifier(index_name),
66+
table_name=sql.Identifier(table_name)
67+
)
68+
]
69+
70+
def _get_messages_query(self, table_name):
71+
return [
72+
sql.SQL(
73+
"""
74+
SELECT message FROM {table_name} WHERE session_id = {session_id};
75+
"""
76+
).format(
77+
table_name=sql.Identifier(table_name),
78+
session_id=sql.Literal(self._session_id)
79+
)
80+
]
3481

3582
def create_tables(self) -> None:
3683
"""
37-
create table if it does not exist
38-
add a new column for timestamp
84+
Create table if it does not exist
85+
Add a new column for timestamp
86+
"""
87+
create_table_queries = self._create_tables_queries(self._table_name)
88+
for query in create_table_queries:
89+
self.cursor.execute(query)
90+
self._connection.commit()
91+
92+
async def acreate_tables(self) -> None:
93+
"""
94+
Asynchronously create tables.
95+
"""
96+
create_table_queries = self._create_tables_queries(self._table_name)
97+
async_conn = await self._initialize_async_connection()
98+
async with async_conn.cursor() as cursor:
99+
for query in create_table_queries:
100+
await cursor.execute(query)
101+
await async_conn.commit()
102+
103+
def get_messages(self):
104+
"""
105+
Retrieve messages synchronously.
106+
"""
107+
get_messages_query = self._get_messages_query(self._table_name)
108+
for query in get_messages_query:
109+
self.cursor.execute(query)
110+
return self.cursor.fetchall()
111+
112+
async def aget_messages(self):
113+
"""
114+
Retrieve messages asynchronously.
39115
"""
40-
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
41-
id SERIAL PRIMARY KEY,
42-
session_id TEXT NOT NULL,
43-
message JSONB NOT NULL,
44-
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
45-
);"""
46-
self.cursor.execute(create_table_query)
47-
self.connection.commit()
116+
get_messages_query = self._get_messages_query(self._table_name)
117+
async_conn = await self._initialize_async_connection()
118+
async with async_conn.cursor() as cursor:
119+
for query in get_messages_query:
120+
await cursor.execute(query)
121+
return await cursor.fetchall()
48122

49123
def add_tags(self, tags: str) -> None:
50-
"""Add tags for a given session_id/uuid on chats table"""
124+
"""
125+
Add tags for a given session_id/uuid on chats table.
126+
"""
51127
with self.dbsession() as session:
52128
session.query(self.chats_model).where(
53129
self.chats_model.session_id == self._session_id
54130
).update({getattr(self.chats_model, "tags"): tags})
131+
session.commit()
132+
133+
def add_messages(self, messages: List[BaseMessage]) -> None:
134+
"""
135+
Add messages to the record in PostgreSQL.
136+
"""
137+
for message in messages:
138+
self.add_message(message)
55139

56140
def add_message(self, message: BaseMessage) -> None:
57-
"""Append the message to the record in PostgreSQL"""
141+
"""
142+
Append the message to the record in PostgreSQL.
143+
"""
58144
message = self.chat_messages_model(
59145
session_id=self._session_id, message=_message_to_dict(message)
60146
)
@@ -63,6 +149,27 @@ def add_message(self, message: BaseMessage) -> None:
63149
self.dbsession.add(message)
64150
self.dbsession.commit()
65151

152+
async def aadd_messages(self, messages: List[BaseMessage]) -> None:
153+
"""
154+
Asynchronously add messages to the record in PostgreSQL.
155+
"""
156+
for message in messages:
157+
await self.aadd_message(message)
158+
159+
async def aadd_message(self, message: BaseMessage) -> None:
160+
"""
161+
Asynchronously append the message to the record in PostgreSQL.
162+
"""
163+
async_conn = await self._initialize_async_connection()
164+
async with async_conn.cursor() as cursor:
165+
await cursor.execute(
166+
sql.SQL("INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)").format(
167+
table_name=sql.Identifier(self._table_name)
168+
),
169+
(self._session_id, _message_to_dict(message))
170+
)
171+
await async_conn.commit()
172+
66173

67174
def generate_memory_instance(
68175
session_id,

‎dialog_lib/db/session.py‎

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import os
2+
from functools import lru_cache
23

34
import sqlalchemy as sa
45
from sqlalchemy.orm import Session, sessionmaker
56

6-
from contextlib import contextmanager
7+
from contextlib import contextmanager, asynccontextmanager
8+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
9+
from psycopg_pool import AsyncConnectionPool
710

8-
from functools import cache
9-
10-
@cache
11-
def get_engine():
11+
@lru_cache()
12+
def get_sync_engine():
1213
return sa.create_engine(os.environ.get("DATABASE_URL"))
1314

1415
@contextmanager
15-
def session_scope():
16-
with Session(bind=get_engine()) as session:
16+
def sync_session_scope():
17+
with Session(bind=get_sync_engine()) as session:
1718
try:
1819
yield session
1920
session.commit()
@@ -24,5 +25,47 @@ def session_scope():
2425
session.close()
2526

2627
def get_session():
27-
with session_scope() as session:
28+
with sync_session_scope() as session:
29+
return session
30+
31+
@lru_cache()
32+
def get_async_engine():
33+
return create_async_engine(os.environ.get("DATABASE_URL"))
34+
35+
@asynccontextmanager
36+
async def async_session_scope():
37+
async_session = sessionmaker(
38+
get_async_engine(), class_=AsyncSession, expire_on_commit=False
39+
)
40+
async with async_session() as session:
41+
try:
42+
yield session
43+
await session.commit()
44+
except Exception as exc:
45+
await session.rollback()
46+
raise exc
47+
finally:
48+
await session.close()
49+
50+
async def get_async_session():
51+
async with async_session_scope() as session:
2852
return session
53+
54+
@lru_cache()
55+
def create_async_psycopg_pool():
56+
return AsyncConnectionPool(os.environ.get("DATABASE_URL"))
57+
58+
@asynccontextmanager
59+
async def async_psycopg_connection():
60+
pool = create_async_psycopg_pool()
61+
async with pool.connection() as conn:
62+
try:
63+
yield conn
64+
await conn.commit()
65+
except Exception:
66+
await conn.rollback()
67+
raise
68+
69+
async def get_async_psycopg_connection():
70+
async with async_psycopg_connection() as conn:
71+
return conn

‎docker-compose.yml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ services:
2020
build:
2121
context: .
2222
dockerfile: Dockerfile
23-
entrypoint: pytest -vvv
23+
command: pytest -vvv
2424
stdin_open: true
2525
tty: true
2626
depends_on:

0 commit comments

Comments
 (0)