11import psycopg
2- from . session import get_session
3- from langchain_postgres import PostgresChatMessageHistory
4- from langchain . schema . messages import BaseMessage , _message_to_dict
2+
3+ from typing import List
4+ from psycopg import sql
55
66from .models import Chat , ChatMessages
7+ from .session import get_session , get_async_session , get_async_psycopg_connection
8+
9+ from langchain_postgres import PostgresChatMessageHistory
10+ from langchain .schema .messages import BaseMessage , _message_to_dict
711
812
913class CustomPostgresChatMessageHistory (PostgresChatMessageHistory ):
@@ -16,45 +20,127 @@ def __init__(
1620 * args ,
1721 parent_session_id = None ,
1822 dbsession = get_session ,
23+ async_dbsession = get_async_session ,
1924 chats_model = Chat ,
2025 chat_messages_model = ChatMessages ,
2126 ssl_mode = None ,
2227 ** kwargs ,
2328 ):
2429 self .parent_session_id = parent_session_id
2530 self .dbsession = dbsession
31+ self .async_dbsession = async_dbsession
2632 self .chats_model = chats_model
2733 self .chat_messages_model = chat_messages_model
2834 self ._connection = psycopg .connect (
2935 kwargs .pop ("connection_string" ), sslmode = ssl_mode
3036 )
37+ self ._async_connection = None # Will be initialized when needed
3138 self ._session_id = kwargs .pop ("session_id" )
32- self ._table_name = kwargs .pop ("table_name" )
33-
39+ self ._table_name = kwargs .pop ("table_name" , chat_messages_model .__tablename__ )
40+
41+ self .cursor = self ._connection .cursor ()
42+
43+ async def _initialize_async_connection (self ):
44+ if self ._async_connection is None :
45+ self ._async_connection = await get_async_psycopg_connection ()
46+ return self ._async_connection
47+
48+ def _create_tables_queries (self , table_name ):
49+ index_name = f"idx_{ table_name } _session_id"
50+ return [
51+ sql .SQL (
52+ """
53+ CREATE TABLE IF NOT EXISTS {table_name} (
54+ id SERIAL PRIMARY KEY,
55+ session_id TEXT NOT NULL,
56+ message JSONB NOT NULL,
57+ timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
58+ );"""
59+ ).format (table_name = sql .Identifier (table_name )),
60+ sql .SQL (
61+ """
62+ CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
63+ """
64+ ).format (
65+ index_name = sql .Identifier (index_name ),
66+ table_name = sql .Identifier (table_name )
67+ )
68+ ]
69+
70+ def _get_messages_query (self , table_name ):
71+ return [
72+ sql .SQL (
73+ """
74+ SELECT message FROM {table_name} WHERE session_id = {session_id};
75+ """
76+ ).format (
77+ table_name = sql .Identifier (table_name ),
78+ session_id = sql .Literal (self ._session_id )
79+ )
80+ ]
3481
3582 def create_tables (self ) -> None :
3683 """
37- create table if it does not exist
38- add a new column for timestamp
84+ Create table if it does not exist
85+ Add a new column for timestamp
86+ """
87+ create_table_queries = self ._create_tables_queries (self ._table_name )
88+ for query in create_table_queries :
89+ self .cursor .execute (query )
90+ self ._connection .commit ()
91+
92+ async def acreate_tables (self ) -> None :
93+ """
94+ Asynchronously create tables.
95+ """
96+ create_table_queries = self ._create_tables_queries (self ._table_name )
97+ async_conn = await self ._initialize_async_connection ()
98+ async with async_conn .cursor () as cursor :
99+ for query in create_table_queries :
100+ await cursor .execute (query )
101+ await async_conn .commit ()
102+
103+ def get_messages (self ):
104+ """
105+ Retrieve messages synchronously.
106+ """
107+ get_messages_query = self ._get_messages_query (self ._table_name )
108+ for query in get_messages_query :
109+ self .cursor .execute (query )
110+ return self .cursor .fetchall ()
111+
112+ async def aget_messages (self ):
113+ """
114+ Retrieve messages asynchronously.
39115 """
40- create_table_query = f"""CREATE TABLE IF NOT EXISTS { self .table_name } (
41- id SERIAL PRIMARY KEY,
42- session_id TEXT NOT NULL,
43- message JSONB NOT NULL,
44- timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
45- );"""
46- self .cursor .execute (create_table_query )
47- self .connection .commit ()
116+ get_messages_query = self ._get_messages_query (self ._table_name )
117+ async_conn = await self ._initialize_async_connection ()
118+ async with async_conn .cursor () as cursor :
119+ for query in get_messages_query :
120+ await cursor .execute (query )
121+ return await cursor .fetchall ()
48122
49123 def add_tags (self , tags : str ) -> None :
50- """Add tags for a given session_id/uuid on chats table"""
124+ """
125+ Add tags for a given session_id/uuid on chats table.
126+ """
51127 with self .dbsession () as session :
52128 session .query (self .chats_model ).where (
53129 self .chats_model .session_id == self ._session_id
54130 ).update ({getattr (self .chats_model , "tags" ): tags })
131+ session .commit ()
132+
133+ def add_messages (self , messages : List [BaseMessage ]) -> None :
134+ """
135+ Add messages to the record in PostgreSQL.
136+ """
137+ for message in messages :
138+ self .add_message (message )
55139
56140 def add_message (self , message : BaseMessage ) -> None :
57- """Append the message to the record in PostgreSQL"""
141+ """
142+ Append the message to the record in PostgreSQL.
143+ """
58144 message = self .chat_messages_model (
59145 session_id = self ._session_id , message = _message_to_dict (message )
60146 )
@@ -63,6 +149,27 @@ def add_message(self, message: BaseMessage) -> None:
63149 self .dbsession .add (message )
64150 self .dbsession .commit ()
65151
152+ async def aadd_messages (self , messages : List [BaseMessage ]) -> None :
153+ """
154+ Asynchronously add messages to the record in PostgreSQL.
155+ """
156+ for message in messages :
157+ await self .aadd_message (message )
158+
159+ async def aadd_message (self , message : BaseMessage ) -> None :
160+ """
161+ Asynchronously append the message to the record in PostgreSQL.
162+ """
163+ async_conn = await self ._initialize_async_connection ()
164+ async with async_conn .cursor () as cursor :
165+ await cursor .execute (
166+ sql .SQL ("INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)" ).format (
167+ table_name = sql .Identifier (self ._table_name )
168+ ),
169+ (self ._session_id , _message_to_dict (message ))
170+ )
171+ await async_conn .commit ()
172+
66173
67174def generate_memory_instance (
68175 session_id ,
0 commit comments