You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_text_pipeline_retry.py

122 lines
4.1 KiB

"""Tests for pipeline/text_pipeline.py retry behaviour.
Uses monkeypatching to stub get_embeddings_with_retry and store_embedding
so no real DB or network is needed.
"""
import pipeline.text_pipeline as tp
import pipeline.ai_provider_wrapper as ai_wrapper
def _make_fake_db(store_results=None):
    """Build a minimal fake db object for text_pipeline tests.

    Args:
        store_results: Optional mapping of motion_id to the value that
            ``store_embedding`` should return for that id. Ids absent from
            the mapping yield ``1``.

    Returns:
        A ``(fake_db, call_log)`` tuple. ``call_log["stored"]`` lists every
        motion_id passed to ``store_embedding``, in call order.
    """
    # Compare against None explicitly so a caller-supplied mapping (even an
    # empty one) is kept by reference instead of being swapped for a new dict.
    if store_results is None:
        store_results = {}
    call_log = {"stored": []}

    class FakeDB:
        db_path = ":memory:"

        def store_embedding(self, motion_id, model, vec):
            # Record the call, then answer with the configured result.
            call_log["stored"].append(motion_id)
            return store_results.get(motion_id, 1)

    return FakeDB(), call_log
def _stub_select_text(monkeypatch, rows):
    """Make ``tp._select_text`` hand back the given (motion_id, text) rows."""

    def _canned_rows(db, model):
        # Both arguments are ignored; the stub always yields the same rows.
        return rows

    monkeypatch.setattr(tp, "_select_text", _canned_rows)
def _stub_counts(monkeypatch, total=10, existing=0):
    """Replace ``tp.duckdb`` with a mock whose connection serves canned counts.

    Args:
        monkeypatch: Object providing ``setattr``; presumably pytest's
            ``monkeypatch`` fixture.
        total: Value wrapped in the tuple returned by the first
            ``fetchone()`` call.
        existing: Value wrapped in the tuple returned by the second
            ``fetchone()`` call.
    """
    from unittest.mock import MagicMock

    fake_conn = MagicMock()
    # fetchone()[0] is used twice: total_motions and existing count
    fake_conn.execute.return_value.fetchone.side_effect = [(total,), (existing,)]
    fake_duckdb = MagicMock()
    fake_duckdb.connect.return_value = fake_conn
    monkeypatch.setattr(tp, "duckdb", fake_duckdb)
def test_all_embeddings_stored(monkeypatch):
    """Every text gets an embedding from the wrapper, so all rows are stored."""
    _stub_select_text(
        monkeypatch, [(1, "tekst een"), (2, "tekst twee"), (3, "tekst drie")]
    )
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, call_log = _make_fake_db()

    def fake_wrapper(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        # One fixed vector per input text.
        vectors = []
        for _ in texts:
            vectors.append([0.1, 0.2, 0.3])
        return vectors

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", fake_wrapper)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _skipped_existing, skipped_no_text, errors, failed_ids = outcome

    assert stored == 3
    assert errors == 0
    assert failed_ids == []
    assert skipped_no_text == 0
    # Order of storage is not checked, only which ids were stored.
    assert {1, 2, 3} == set(call_log["stored"])
def test_partial_failure_populates_failed_ids(monkeypatch):
    """A None entry from the wrapper puts the matching motion id in failed_ids."""
    _stub_select_text(monkeypatch, [(10, "text a"), (11, "text b"), (12, "text c")])
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, _call_log = _make_fake_db()

    def fake_wrapper(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        if len(texts) == 3:
            # Embedding for the first and third text; the second
            # (motion_id=11) fails.
            return [[0.1, 0.2], None, [0.3, 0.4]]
        # Any other batch size: every text gets an embedding.
        return [[0.1] for _ in texts]

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", fake_wrapper)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _skipped_existing, _skipped_no_text, errors, failed_ids = outcome

    assert stored == 2
    assert errors == 1
    assert 11 in failed_ids
    for ok_id in (10, 12):
        assert ok_id not in failed_ids
def test_no_text_motions_skipped(monkeypatch):
    """Empty or missing text counts as skipped_no_text and never reaches the wrapper."""
    _stub_select_text(monkeypatch, [(20, "has text"), (21, ""), (22, None)])
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, _call_log = _make_fake_db()
    texts_received = 0

    def fake_wrapper(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        # Tally how many texts the pipeline hands to the wrapper in total.
        nonlocal texts_received
        texts_received = texts_received + len(texts)
        return [[0.1] for _ in texts]

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", fake_wrapper)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _skipped_existing, skipped_no_text, _errors, _failed_ids = outcome

    assert skipped_no_text == 2  # motions 21 and 22 have no text
    assert stored == 1  # only motion 20 was stored
    assert texts_received == 1  # wrapper only received 1 text