You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
4.1 KiB
122 lines
4.1 KiB
"""Tests for pipeline/text_pipeline.py retry behaviour.

Uses monkeypatching to stub get_embeddings_with_retry and store_embedding
so no real DB or network is needed.
"""
|
|
|
|
import pipeline.text_pipeline as tp
|
|
import pipeline.ai_provider_wrapper as ai_wrapper
|
|
|
|
|
|
def _make_fake_db(store_results=None):
|
|
"""Return a minimal fake db object for text_pipeline tests."""
|
|
store_results = store_results or {}
|
|
call_log = {"stored": []}
|
|
|
|
class FakeDB:
|
|
db_path = ":memory:"
|
|
|
|
def store_embedding(self, motion_id, model, vec):
|
|
call_log["stored"].append(motion_id)
|
|
return store_results.get(motion_id, 1)
|
|
|
|
return FakeDB(), call_log
|
|
|
|
|
|
def _stub_select_text(monkeypatch, rows):
    """Replace ``tp._select_text`` with a stub that always yields *rows*.

    *rows* is the predetermined sequence of ``(motion_id, text)`` pairs the
    stub returns, regardless of the ``db`` and ``model`` it is called with.
    """

    def _fake_select_text(db, model):
        # Both arguments are accepted but ignored; the canned rows win.
        return rows

    monkeypatch.setattr(tp, "_select_text", _fake_select_text)
|
|
|
|
|
|
def _stub_counts(monkeypatch, total=10, existing=0):
    """Patch the duckdb connection used for count queries.

    Replaces the ``duckdb`` attribute of the text pipeline module with a mock
    whose ``connect()`` returns a fake connection. On that connection,
    successive ``execute(...).fetchone()`` calls return ``(total,)`` and then
    ``(existing,)``.

    Args:
        monkeypatch: Object providing ``setattr`` (pytest's fixture).
        total: Value reported by the first count query.
        existing: Value reported by the second count query.
    """
    from unittest.mock import MagicMock

    fake_conn = MagicMock()
    # fetchone()[0] is used twice: total_motions and existing count.
    # Only two results are queued, so a third fetchone() raises StopIteration.
    fake_conn.execute.return_value.fetchone.side_effect = [(total,), (existing,)]
    fake_duckdb = MagicMock()
    fake_duckdb.connect.return_value = fake_conn
    monkeypatch.setattr(tp, "duckdb", fake_duckdb)
|
|
|
|
|
|
def test_all_embeddings_stored(monkeypatch):
    """Every text gets an embedding, so the stored count equals the row count."""
    motion_rows = [(1, "tekst een"), (2, "tekst twee"), (3, "tekst drie")]
    _stub_select_text(monkeypatch, motion_rows)
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, call_log = _make_fake_db()

    def always_embed(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        # One fixed vector per input text: no item ever fails.
        return [[0.1, 0.2, 0.3] for _ in texts]

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", always_embed)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _skipped_existing, skipped_no_text, errors, failed_ids = outcome

    assert stored == 3
    assert errors == 0
    assert failed_ids == []
    assert skipped_no_text == 0
    # Order of storage is not checked, only which ids were written.
    assert set(call_log["stored"]) == {1, 2, 3}
|
|
|
|
|
|
def test_partial_failure_populates_failed_ids(monkeypatch):
    """Ids whose embedding comes back as None must be reported in failed_ids."""
    motion_rows = [(10, "text a"), (11, "text b"), (12, "text c")]
    _stub_select_text(monkeypatch, motion_rows)
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, call_log = _make_fake_db()

    def middle_item_fails(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        if len(texts) == 3:
            # Full batch: embedding for first, None for second, embedding for third.
            return [
                [0.1, 0.2],
                None,  # motion_id=11 fails
                [0.3, 0.4],
            ]
        # Any other batch size succeeds for every item.
        return [[0.1] for _ in texts]

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", middle_item_fails)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _skipped_existing, _skipped_no_text, errors, failed_ids = outcome

    assert stored == 2
    assert errors == 1
    assert 11 in failed_ids
    assert 10 not in failed_ids
    assert 12 not in failed_ids
|
|
|
|
|
|
def test_no_text_motions_skipped(monkeypatch):
    """Empty or missing text counts as skipped_no_text and never reaches the wrapper."""
    motion_rows = [(20, "has text"), (21, ""), (22, None)]
    _stub_select_text(monkeypatch, motion_rows)
    _stub_counts(monkeypatch, total=3, existing=0)
    fake_db, call_log = _make_fake_db()

    texts_received = 0

    def counting_wrapper(texts, motion_ids=None, model=None, batch_size=50, **kwargs):
        # Tally how many texts the pipeline forwards across all calls.
        nonlocal texts_received
        texts_received += len(texts)
        return [[0.1] for _ in texts]

    monkeypatch.setattr(ai_wrapper, "get_embeddings_with_retry", counting_wrapper)

    outcome = tp.ensure_text_embeddings(db=fake_db, model="test-model")
    stored, _, skipped_no_text, errors, failed_ids = outcome

    # Motions 21 and 22 have no text.
    assert skipped_no_text == 2
    # Only motion 20 was stored.
    assert stored == 1
    # The wrapper only received 1 text.
    assert texts_received == 1
|
|
|