You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
1.9 KiB
66 lines
1.9 KiB
"""Tests for pipeline.ai_provider_wrapper — no monkeypatching, no mocks."""
|
|
|
|
import pipeline.ai_provider_wrapper as w
|
|
from tests.conftest import FakeEmbedder
|
|
|
|
|
|
def test_empty_input_returns_empty():
    """An empty text list yields an empty list without touching the embedder.

    A FakeEmbedder is passed explicitly so that the "no embedder call needed"
    part of the contract is asserted, not merely claimed.
    """
    embedder = FakeEmbedder()

    result = w.get_embeddings_with_retry([], embedder=embedder)

    assert result == []
    # The empty-input short-circuit must happen before any embedding request.
    assert embedder.call_count == 0
|
|
|
|
|
|
def test_successful_embeddings(mem_db):
    """A healthy embedder yields one non-None result per input text.

    Uses the in-process FakeEmbedder from tests.conftest (a real object, not a
    mock) together with the ``mem_db`` fixture.
    """
    texts = ["motion one", "motion two"]
    embedder = FakeEmbedder()

    result = w.get_embeddings_with_retry(
        texts,
        motion_ids=[1, 2],
        embedder=embedder,
        db=mem_db,
    )

    # One result slot per input text.
    assert len(result) == len(texts)
    # Report any missing vectors by index. Identity checks avoid relying on
    # ``==`` / ``in``, which would be ambiguous if vectors were array-like.
    missing = [index for index, vector in enumerate(result) if vector is None]
    assert missing == []
    assert embedder.call_count >= 1
|
|
|
|
|
|
def test_transient_failure_retries(mem_db):
    """One failed embedding call followed by a success is recovered via retry."""

    class FailsOnceEmbedder:
        """Embedder stub that raises on its first call and succeeds afterwards."""

        def __init__(self):
            self.call_count = 0

        def __call__(self, texts, model=None, batch_size=50):
            self.call_count += 1
            first_attempt = self.call_count == 1
            if first_attempt:
                raise RuntimeError("Transient network error")
            # Fresh 8-dimensional constant vector for every input text.
            vector = [0.5] * 8
            return [list(vector) for _ in texts]

    flaky = FailsOnceEmbedder()

    outcome = w.get_embeddings_with_retry(
        ["motion text"],
        motion_ids=[42],
        embedder=flaky,
        db=mem_db,
        retries=3,
    )

    # The retry must have produced a vector for the single input...
    assert outcome[0] is not None
    # ...and needed at least one call beyond the failing first attempt.
    assert flaky.call_count >= 2
|
|
|
|
|
|
def test_permanent_failure_returns_none_sentinel(mem_db):
    """An item the embedder can never embed comes back as a None placeholder."""
    # NOTE(review): fail_indices={0} presumably makes FakeEmbedder fail on the
    # first text on every attempt — confirm against tests.conftest.
    failing_embedder = FakeEmbedder(fail_indices={0})

    call_kwargs = dict(
        motion_ids=[99],
        embedder=failing_embedder,
        db=mem_db,
        retries=2,
    )
    result = w.get_embeddings_with_retry(["failing motion"], **call_kwargs)

    # Once retries are exhausted, the single slot holds the None sentinel.
    assert result == [None]
|
|
|