11from __future__ import annotations
22
3- import contextlib
43import enum
54import logging
65import uuid
76from typing import (
87 Any ,
98 Callable ,
109 Dict ,
11- Generator ,
1210 Iterable ,
1311 List ,
1412 Optional ,
2119import sqlalchemy
2220from sqlalchemy import SQLColumnExpression , cast , delete , func
2321from sqlalchemy .dialects .postgresql import JSON , JSONB , JSONPATH , UUID , insert
24- from sqlalchemy .orm import Session , relationship
22+ from sqlalchemy .orm import Session , relationship , sessionmaker
2523
2624try :
2725 from sqlalchemy .orm import declarative_base
@@ -288,15 +286,19 @@ def __init__(
288286 self .override_relevance_score_fn = relevance_score_fn
289287
290288 if isinstance (connection , str ):
291- self ._bind = sqlalchemy .create_engine (url = connection , ** (engine_args or {}))
289+ self ._engine = sqlalchemy .create_engine (
290+ url = connection , ** (engine_args or {})
291+ )
292292 elif isinstance (connection , sqlalchemy .engine .Engine ):
293- self ._bind = connection
293+ self ._engine = connection
294294 else :
295295 raise ValueError (
296296 "connection should be a connection string or an instance of "
297297 "sqlalchemy.engine.Engine"
298298 )
299299
300+ self ._session_maker = sessionmaker (bind = self ._engine )
301+
300302 self .use_jsonb = use_jsonb
301303 self .create_extension = create_extension
302304
@@ -321,16 +323,16 @@ def __post_init__(
321323 self .create_collection ()
322324
323325 def __del__ (self ) -> None :
324- if isinstance (self ._bind , sqlalchemy .engine .Connection ):
325- self ._bind .close ()
326+ if isinstance (self ._engine , sqlalchemy .engine .Connection ):
327+ self ._engine .close ()
326328
327329 @property
328330 def embeddings (self ) -> Embeddings :
329331 return self .embedding_function
330332
331333 def create_vector_extension (self ) -> None :
332334 try :
333- with Session ( self ._bind ) as session : # type: ignore[arg-type]
335+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
334336 # The advisor lock fixes issue arising from concurrent
335337 # creation of the vector extension.
336338 # https://github.com/langchain-ai/langchain/issues/12933
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
348350 raise Exception (f"Failed to create vector extension: { e } " ) from e
349351
350352 def create_tables_if_not_exists (self ) -> None :
351- with Session ( self ._bind ) as session , session . begin (): # type: ignore[arg-type]
353+ with self ._session_maker ( ) as session :
352354 Base .metadata .create_all (session .get_bind ())
353355
354356 def drop_tables (self ) -> None :
355- with Session ( self ._bind ) as session , session . begin (): # type: ignore[arg-type]
357+ with self ._session_maker ( ) as session :
356358 Base .metadata .drop_all (session .get_bind ())
357359
358360 def create_collection (self ) -> None :
359361 if self .pre_delete_collection :
360362 self .delete_collection ()
361- with Session ( self ._bind ) as session : # type: ignore[arg-type]
363+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
362364 self .CollectionStore .get_or_create (
363365 session , self .collection_name , cmetadata = self .collection_metadata
364366 )
365367
366368 def delete_collection (self ) -> None :
367369 self .logger .debug ("Trying to delete collection" )
368- with Session ( self ._bind ) as session : # type: ignore[arg-type]
370+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
369371 collection = self .get_collection (session )
370372 if not collection :
371373 self .logger .warning ("Collection not found" )
372374 return
373375 session .delete (collection )
374376 session .commit ()
375377
376- @contextlib .contextmanager
377- def _make_session (self ) -> Generator [Session , None , None ]:
378- """Create a context manager for the session, bind to _conn string."""
379- yield Session (self ._bind ) # type: ignore[arg-type]
380-
381378 def delete (
382379 self ,
383380 ids : Optional [List [str ]] = None ,
@@ -390,7 +387,7 @@ def delete(
390387 ids: List of ids to delete.
391388 collection_only: Only delete ids in the collection.
392389 """
393- with Session ( self ._bind ) as session : # type: ignore[arg-type]
390+ with self ._session_maker ( ) as session :
394391 if ids is not None :
395392 self .logger .debug (
396393 "Trying to delete vectors by ids (represented by the model "
@@ -476,7 +473,7 @@ def add_embeddings(
476473 if not metadatas :
477474 metadatas = [{} for _ in texts ]
478475
479- with Session ( self ._bind ) as session : # type: ignore[arg-type]
476+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
480477 collection = self .get_collection (session )
481478 if not collection :
482479 raise ValueError ("Collection not found" )
@@ -901,7 +898,7 @@ def __query_collection(
901898 filter : Optional [Dict [str , str ]] = None ,
902899 ) -> List [Any ]:
903900 """Query the collection."""
904- with Session ( self ._bind ) as session : # type: ignore[arg-type]
901+ with self ._session_maker ( ) as session : # type: ignore[arg-type]
905902 collection = self .get_collection (session )
906903 if not collection :
907904 raise ValueError ("Collection not found" )
@@ -1066,6 +1063,7 @@ def from_existing_index(
10661063 embeddings = embedding ,
10671064 distance_strategy = distance_strategy ,
10681065 pre_delete_collection = pre_delete_collection ,
1066+ ** kwargs ,
10691067 )
10701068
10711069 return store
0 commit comments