@@ -15,7 +15,7 @@ def __init__(
1515 self ,
1616 * args ,
1717 parent_session_id = None ,
18- dbsession = get_session () ,
18+ dbsession = get_session ,
1919 chats_model = Chat ,
2020 chat_messages_model = ChatMessages ,
2121 ssl_mode = None ,
@@ -48,10 +48,10 @@ def create_tables(self) -> None:
4848
4949 def add_tags (self , tags : str ) -> None :
5050 """Add tags for a given session_id/uuid on chats table"""
51- self .dbsession . query ( self . chats_model ). where (
52- self .chats_model . session_id == self . _session_id
53- ). update ({ getattr ( self .chats_model , "tags" ): tags })
54- self . dbsession . commit ( )
51+ with self .dbsession () as session :
52+ session . query ( self .chats_model ). where (
53+ self .chats_model . session_id == self . _session_id
54+ ). update ({ getattr ( self . chats_model , "tags" ): tags } )
5555
5656 def add_message (self , message : BaseMessage ) -> None :
5757 """Append the message to the record in PostgreSQL"""
@@ -61,13 +61,12 @@ def add_message(self, message: BaseMessage) -> None:
6161 if self .parent_session_id :
6262 message .parent = self .parent_session_id
6363 self .dbsession .add (message )
64- self .dbsession .commit ()
6564
6665
6766def generate_memory_instance (
6867 session_id ,
6968 parent_session_id = None ,
70- dbsession = get_session () ,
69+ dbsession = get_session ,
7170 database_url = None ,
7271 chats_model = Chat ,
7372 chat_messages_model = ChatMessages ,
@@ -88,29 +87,31 @@ def generate_memory_instance(
8887
8988
9089def add_user_message_to_message_history (
91- session_id , message , memory = None , dbsession = get_session () , database_url = None
90+ session_id , message , memory = None , dbsession = get_session , database_url = None
9291):
9392 """
9493 Add a user message to the message history and returns the updated
9594 memory instance
9695 """
97- if not memory :
98- memory = generate_memory_instance (
99- session_id , dbsession = dbsession , database_url = database_url
100- )
96+ with dbsession () as session :
97+ if not memory :
98+ memory = generate_memory_instance (
99+ session_id , dbsession = session , database_url = database_url
100+ )
101101
102- memory .add_user_message (message )
103- return memory
102+ memory .add_user_message (message )
103+ return memory
104104
105105
106- def get_messages (session_id , dbsession = get_session () , database_url = None ):
106+ def get_messages (session_id , dbsession = get_session , database_url = None ):
107107 """
108108 Get all messages for a given session_id
109109 """
110- memory = generate_memory_instance (
111- session_id , dbsession = dbsession , database_url = database_url
112- )
113- return memory .messages
110+ with dbsession () as session :
111+ memory = generate_memory_instance (
112+ session_id , dbsession = session , database_url = database_url
113+ )
114+ return memory .messages
114115
115116def get_memory_instance (session_id , sqlalchemy_session , database_url ):
116117 return generate_memory_instance (
0 commit comments