Skip to content
Prev Previous commit
Next Next commit
x
  • Loading branch information
eyurtsev committed Jul 12, 2024
commit 3429a84c6ec3ae738551ab7d35db34fe5220a0c7
8 changes: 5 additions & 3 deletions tests/unit_tests/test_vectorstore_standard_tests.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import AsyncGenerator, Generator

import pytest
from langchain_core.vectorstores import VectorStore
from langchain_standard_tests.integration_tests.vectorstores import (
AsyncReadWriteTestSuite,
ReadWriteTestSuite,
)

from tests.unit_tests.test_vectorstore import get_vectorstore, aget_vectorstore
from tests.unit_tests.test_vectorstore import aget_vectorstore, get_vectorstore


class TestSync(ReadWriteTestSuite):
@pytest.fixture()
def vectorstore(self) -> VectorStore:
def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
with get_vectorstore(embedding=self.get_embeddings()) as vstore:
vstore.drop_tables()
Expand All @@ -21,7 +23,7 @@ def vectorstore(self) -> VectorStore:

class TestAsync(AsyncReadWriteTestSuite):
@pytest.fixture()
async def vectorstore(self) -> VectorStore:
async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
async with aget_vectorstore(embedding=self.get_embeddings()) as vstore:
await vstore.adrop_tables()
Expand Down