Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
ccurme committed Dec 10, 2024
commit f3d2aa23c90fe1027d37c69c30cb9118adf0b21e
147 changes: 76 additions & 71 deletions tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test PGVector functionality."""
import contextlib
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

import pytest
from langchain_core.documents import Document
Expand All @@ -25,9 +25,12 @@
ADA_TOKEN_COUNT = 1536


class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)
def _compare_documents(left: Sequence[Document], right: Sequence[Document]) -> None:
"""Compare lists of documents, irrespective of IDs."""
assert len(left) == len(right)
for left_doc, right_doc in zip(left, right):
assert left_doc.page_content == right_doc.page_content
assert left_doc.metadata == right_doc.metadata


class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
Expand Down Expand Up @@ -55,7 +58,7 @@ def test_pgvector() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


@pytest.mark.asyncio
Expand All @@ -70,7 +73,7 @@ async def test_async_pgvector() -> None:
pre_delete_collection=True,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


def test_pgvector_embeddings() -> None:
Expand All @@ -86,7 +89,7 @@ def test_pgvector_embeddings() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


@pytest.mark.asyncio
Expand All @@ -103,7 +106,7 @@ async def test_async_pgvector_embeddings() -> None:
pre_delete_collection=True,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


def test_pgvector_with_metadatas() -> None:
Expand All @@ -119,7 +122,7 @@ def test_pgvector_with_metadatas() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"}, id=AnyStr())]
_compare_documents(output, [Document(page_content="foo", metadata={"page": "0"})])


@pytest.mark.asyncio
Expand All @@ -136,7 +139,7 @@ async def test_async_pgvector_with_metadatas() -> None:
pre_delete_collection=True,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"}, id=AnyStr())]
_compare_documents(output, [Document(page_content="foo", metadata={"page": "0"})])


def test_pgvector_with_metadatas_with_scores() -> None:
Expand All @@ -152,9 +155,9 @@ def test_pgvector_with_metadatas_with_scores() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=1)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 0.0)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
assert scores == (0.0,)


@pytest.mark.asyncio
Expand All @@ -171,9 +174,9 @@ async def test_async_pgvector_with_metadatas_with_scores() -> None:
pre_delete_collection=True,
)
output = await docsearch.asimilarity_search_with_score("foo", k=1)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 0.0)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
assert scores == (0.0,)


def test_pgvector_with_filter_match() -> None:
Expand All @@ -189,9 +192,9 @@ def test_pgvector_with_filter_match() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 0.0)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
assert scores == (0.0,)


@pytest.mark.asyncio
Expand All @@ -210,9 +213,9 @@ async def test_async_pgvector_with_filter_match() -> None:
output = await docsearch.asimilarity_search_with_score(
"foo", k=1, filter={"page": "0"}
)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 0.0)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
assert scores == (0.0,)


def test_pgvector_with_filter_distant_match() -> None:
Expand All @@ -228,12 +231,9 @@ def test_pgvector_with_filter_distant_match() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
assert output == [
(
Document(page_content="baz", metadata={"page": "2"}, id=AnyStr()),
0.0013003906671379406,
)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="baz", metadata={"page": "2"})])
assert scores == (0.0013003906671379406,)


@pytest.mark.asyncio
Expand All @@ -252,12 +252,9 @@ async def test_async_pgvector_with_filter_distant_match() -> None:
output = await docsearch.asimilarity_search_with_score(
"foo", k=1, filter={"page": "2"}
)
assert output == [
(
Document(page_content="baz", metadata={"page": "2"}, id=AnyStr()),
0.0013003906671379406,
)
]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="baz", metadata={"page": "2"})])
assert scores == (0.0013003906671379406,)


def test_pgvector_with_filter_no_match() -> None:
Expand Down Expand Up @@ -593,17 +590,16 @@ def test_pgvector_relevance_score() -> None:
)

output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 1.0),
(
Document(page_content="bar", metadata={"page": "1"}, id=AnyStr()),
0.9996744261675065,
),
(
Document(page_content="baz", metadata={"page": "2"}, id=AnyStr()),
0.9986996093328621,
),
]
docs, scores = zip(*output)
_compare_documents(
docs,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Document(page_content="baz", metadata={"page": "2"}),
],
)
assert scores == (1.0, 0.9996744261675065, 0.9986996093328621)


@pytest.mark.asyncio
Expand All @@ -621,17 +617,16 @@ async def test_async_pgvector_relevance_score() -> None:
)

output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()), 1.0),
(
Document(page_content="bar", metadata={"page": "1"}, id=AnyStr()),
0.9996744261675065,
),
(
Document(page_content="baz", metadata={"page": "2"}, id=AnyStr()),
0.9986996093328621,
),
]
docs, scores = zip(*output)
_compare_documents(
docs,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Document(page_content="baz", metadata={"page": "2"}),
],
)
assert scores == (1.0, 0.9996744261675065, 0.9986996093328621)


def test_pgvector_retriever_search_threshold() -> None:
Expand All @@ -652,10 +647,13 @@ def test_pgvector_retriever_search_threshold() -> None:
search_kwargs={"k": 3, "score_threshold": 0.999},
)
output = retriever.get_relevant_documents("summer")
assert output == [
Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()),
Document(page_content="bar", metadata={"page": "1"}, id=AnyStr()),
]
_compare_documents(
output,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
],
)


@pytest.mark.asyncio
Expand All @@ -677,10 +675,13 @@ async def test_async_pgvector_retriever_search_threshold() -> None:
search_kwargs={"k": 3, "score_threshold": 0.999},
)
output = await retriever.aget_relevant_documents("summer")
assert output == [
Document(page_content="foo", metadata={"page": "0"}, id=AnyStr()),
Document(page_content="bar", metadata={"page": "1"}, id=AnyStr()),
]
_compare_documents(
output,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
],
)


def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
Expand Down Expand Up @@ -741,7 +742,7 @@ def test_pgvector_max_marginal_relevance_search() -> None:
pre_delete_collection=True,
)
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


@pytest.mark.asyncio
Expand All @@ -756,7 +757,7 @@ async def test_async_pgvector_max_marginal_relevance_search() -> None:
pre_delete_collection=True,
)
output = await docsearch.amax_marginal_relevance_search("foo", k=1, fetch_k=3)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


def test_pgvector_max_marginal_relevance_search_with_score() -> None:
Expand All @@ -770,7 +771,9 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None:
pre_delete_collection=True,
)
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
assert output == [(Document(page_content="foo", id=AnyStr()), 0.0)]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo")])
assert scores == (0.0,)


@pytest.mark.asyncio
Expand All @@ -787,7 +790,9 @@ async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None
output = await docsearch.amax_marginal_relevance_search_with_score(
"foo", k=1, fetch_k=3
)
assert output == [(Document(page_content="foo", id=AnyStr()), 0.0)]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo")])
assert scores == (0.0,)


def test_pgvector_with_custom_connection() -> None:
Expand All @@ -801,7 +806,7 @@ def test_pgvector_with_custom_connection() -> None:
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


@pytest.mark.asyncio
Expand All @@ -816,7 +821,7 @@ async def test_async_pgvector_with_custom_connection() -> None:
pre_delete_collection=True,
)
output = await docsearch.asimilarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


def test_pgvector_with_custom_engine_args() -> None:
Expand All @@ -839,7 +844,7 @@ def test_pgvector_with_custom_engine_args() -> None:
engine_args=engine_args,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())]
_compare_documents(output, [Document(page_content="foo")])


# We should reuse this test-case across other integrations
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_vectorstore_standard_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore

@property
def has_async(self) -> bool:
return False
return False # Skip async tests for sync vector store


class TestAsync(VectorStoreIntegrationTests):
Expand All @@ -34,4 +34,4 @@ async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignor

@property
def has_sync(self) -> bool:
return False
return False # Skip sync tests for async vector store
Loading