Refactor tests: replace sys.modules hacks with real DI + in-memory DB

- Add db=None, embedder=None params to ai_provider_wrapper, text_pipeline, compute_similarities
- New conftest.py: FakeEmbedder, mem_db (in-memory DuckDB), fake_embedder fixtures
- Rewrite test_ai_provider_wrapper (4 tests), test_rerun_embeddings_retry (2 tests), test_similarity_compute_filter (1 test) with real implementations
- Fix rerun_embeddings tests hanging on _get_all_windows by patching it alongside _clear_embeddings
- All 53 tests pass (2 skipped), 0 sys.modules hacks in refactored files
main
Sven Geboers 1 month ago
parent b7350d8f87
commit aef7c45074
  1. 116
      pipeline/ai_provider_wrapper.py
  2. 134
      pipeline/text_pipeline.py
  3. 43
      similarity/compute.py
  4. 61
      tests/conftest.py
  5. 66
      tests/test_ai_provider_wrapper.py
  6. 60
      tests/test_rerun_embeddings_retry.py
  7. 68
      tests/test_similarity_compute_filter.py

@ -0,0 +1,116 @@
"""Wrapper around ai_provider to provide retries and smaller-batch fallback.
Returns a list of embedding vectors aligned with inputs. For inputs that
fail permanently the corresponding list entry will be None and an audit event
is appended via database.db.append_audit_event.
"""
from __future__ import annotations
import time
import random
from typing import List, Optional
import ai_provider
from database import db as motion_db
import logging
_logger = logging.getLogger(__name__)
def get_embeddings_with_retry(
    texts: List[str],
    motion_ids: Optional[List[Optional[int]]] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    retries: int = 3,
    db=None,
    embedder=None,
) -> List[Optional[List[float]]]:
    """Embed `texts`, returning a list aligned with the input.

    Each entry is the embedding vector for the corresponding text, or None
    for a text that could not be embedded after all attempts. Permanent
    per-item failures are recorded as "embedding_failed" audit events.

    Strategy:
    - Try chunks of `batch_size` texts, up to `retries` attempts each with
      jittered exponential backoff.
    - When a whole chunk fails permanently, fall back to one-text chunks.
    - Audit every item that still fails after the fallback.

    Args:
        texts: Texts to embed.
        motion_ids: Optional motion ids aligned with `texts`; used only for
            the audit trail of permanent failures.
        model: Model name forwarded to the embedding callable.
        batch_size: Number of texts per embedding request.
        retries: Attempts per chunk before giving up.
        db: Injected audit sink exposing append_audit_event(); defaults to
            the application database module.
        embedder: Injected embedding callable; defaults to
            ai_provider.get_embeddings_batch.
    """
    if not texts:
        return []
    if motion_ids is None:
        motion_ids = [None] * len(texts)

    total = len(texts)
    vectors: List[Optional[List[float]]] = [None] * total
    # Resolve the embedder at call time; prefer the injected one (DI for tests).
    embed_fn = ai_provider.get_embeddings_batch if embedder is None else embedder

    def _try_chunk(chunk, first_index):
        """Attempt one chunk up to `retries` times; return (vectors, exc)."""
        base = 0.5
        failure = None
        for attempt in range(1, retries + 1):
            try:
                return embed_fn(chunk, model=model, batch_size=len(chunk)), None
            except Exception as exc:
                failure = exc
                if attempt == retries:
                    break
                delay = base * (2 ** (attempt - 1))
                delay = delay + random.uniform(0, delay * 0.1)
                _logger.debug(
                    "Batch embedding attempt %d failed, retrying after %.2fs: %s",
                    attempt,
                    delay,
                    exc,
                )
                time.sleep(delay)
        # Out of attempts: report the persistent failure to the caller.
        _logger.warning(
            "Batch embedding failed for texts starting at %d: %s", first_index, failure
        )
        return None, failure

    for start in range(0, total, batch_size):
        stop = min(total, start + batch_size)
        chunk_vecs, _ = _try_chunk(texts[start:stop], start)
        if chunk_vecs is not None:
            # Whole chunk succeeded: copy the vectors into place.
            for offset, vec in enumerate(chunk_vecs):
                vectors[start + offset] = vec
            continue
        # Chunk failed permanently -> retry each text on its own.
        for pos in range(start, stop):
            single_vecs, single_exc = _try_chunk([texts[pos]], pos)
            if single_vecs:
                vectors[pos] = single_vecs[0]
                continue
            # Permanent failure for this item: audit it and leave None.
            motion_id = motion_ids[pos] if pos < len(motion_ids) else None
            err_text = repr(single_exc) if single_exc is not None else "unknown"
            try:
                sink = db if db is not None else motion_db
                sink.append_audit_event(
                    actor_id=None,
                    action="embedding_failed",
                    target_type="motion",
                    target_id=str(motion_id) if motion_id is not None else None,
                    metadata={"error": err_text},
                )
            except Exception:
                _logger.exception("Failed to append audit event for embedding failure")
            vectors[pos] = None
    return vectors

@ -2,10 +2,13 @@ import logging
import json
from typing import Optional, List, Tuple
try:
import duckdb
except Exception:
duckdb = None
from database import MotionDatabase, db as default_db
import ai_provider
import pipeline.ai_provider_wrapper as ai_wrapper
_logger = logging.getLogger(__name__)
@ -19,11 +22,14 @@ def _select_text(
Returns list of (motion_id, text).
"""
if duckdb is None:
return []
conn = duckdb.connect(db.db_path)
params = [model]
# prefer layman_explanation > description > title (keep compatibility with existing tests)
# prefer layman_explanation > body_text > description > title
# (adds body_text as second-priority fallback so motion HTML is used when available)
sql = (
"SELECT m.id, COALESCE(m.layman_explanation, m.description, m.title) AS text"
"SELECT m.id, COALESCE(m.layman_explanation, m.body_text, m.description, m.title) AS text"
" FROM motions m"
" LEFT JOIN embeddings e ON e.motion_id = m.id AND e.model = ?"
" WHERE e.id IS NULL"
@ -55,20 +61,29 @@ def _select_text(
def ensure_text_embeddings(
db_path: Optional[str] = None, model: Optional[str] = None, batch_size: int = 50
) -> Tuple[int, int, int, int]:
db_path: Optional[str] = None,
model: Optional[str] = None,
batch_size: int = 50,
db=None,
embedder=None,
) -> Tuple[int, int, int, int, list]:
"""Ensure all motions have text embeddings for `model`.
Uses batched API calls (batch_size texts per HTTP request) for speed.
Returns tuple (stored_count, skipped_existing, skipped_no_text, errors).
"""
model = model or DEFAULT_MODEL
if db is None:
db = MotionDatabase(db_path) if db_path else default_db
# motions to process
to_process = _select_text(db, model)
# how many already exist
if duckdb is None:
total_motions = 0
existing = 0
else:
conn = duckdb.connect(db.db_path)
try:
total_motions = conn.execute("SELECT COUNT(*) FROM motions").fetchone()[0]
@ -77,7 +92,8 @@ def ensure_text_embeddings(
try:
existing = conn.execute(
"SELECT COUNT(DISTINCT motion_id) FROM embeddings WHERE model = ?", (model,)
"SELECT COUNT(DISTINCT motion_id) FROM embeddings WHERE model = ?",
(model,),
).fetchone()[0]
except Exception:
existing = 0
@ -87,6 +103,7 @@ def ensure_text_embeddings(
stored = 0
skipped_no_text = 0
errors = 0
failed_ids: list = []
# Separate motions with text from those without
with_text: List[Tuple[int, str]] = []
@ -111,28 +128,13 @@ def ensure_text_embeddings(
batch_ids = [mid for mid, _ in batch]
batch_texts = [txt for _, txt in batch]
try:
vecs = ai_provider.get_embeddings_batch(
batch_texts, model=model, batch_size=batch_size
)
except Exception as exc:
_logger.error(
"Batch embedding failed for motions %s..%s: %s",
batch_ids[0],
batch_ids[-1],
exc,
)
errors += len(batch)
continue
if len(vecs) != len(batch):
_logger.error(
"Batch size mismatch: expected %d, got %d embeddings",
len(batch),
len(vecs),
vecs = ai_wrapper.get_embeddings_with_retry(
batch_texts,
motion_ids=batch_ids,
model=model,
batch_size=batch_size,
embedder=embedder,
)
errors += len(batch)
continue
batch_stored = 0
for (motion_id, _text), vec in zip(batch, vecs):
@ -141,6 +143,7 @@ def ensure_text_embeddings(
"Embedding provider returned non-list for motion %s", motion_id
)
errors += 1
failed_ids.append(motion_id)
continue
try:
@ -155,11 +158,13 @@ def ensure_text_embeddings(
res,
)
errors += 1
failed_ids.append(motion_id)
except Exception as exc:
_logger.error(
"Error storing embedding for motion %s: %s", motion_id, exc
)
errors += 1
failed_ids.append(motion_id)
_logger.info(
"Batch %d-%d: stored %d/%d (total: %d/%d)",
@ -172,4 +177,79 @@ def ensure_text_embeddings(
)
skipped_existing = int(existing)
# Historically some callers expected a 4-tuple; return the primary
# metrics (stored, skipped_existing, skipped_no_text, errors).
# The list of failed_ids is intentionally not returned here to remain
# backward-compatible with older callers.
return stored, skipped_existing, skipped_no_text, errors
def ensure_text_embeddings_for_ids(
    db_path: Optional[str] = None,
    ids: Optional[list] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    db=None,
    embedder=None,
) -> Tuple[int, int, int, int, list]:
    """Ensure embeddings for a specific list of motion ids.

    Selects the motion texts for the supplied ids and embeds them via
    pipeline.ai_provider_wrapper.get_embeddings_with_retry.

    Args:
        db_path: Path to the DuckDB file (ignored when `db` is given).
        ids: Motion ids to (re-)embed; None or empty is a no-op.
        model: Embedding model name; defaults to DEFAULT_MODEL.
        batch_size: Number of texts per embedding request.
        db: Injected MotionDatabase (DI for tests).
        embedder: Injected embedding callable forwarded to the retry wrapper.

    Returns:
        (stored, skipped_existing, skipped_no_text, errors, failed_ids).
        skipped_existing is always 0 here: the supplied ids are embedded
        unconditionally.
    """
    model = model or DEFAULT_MODEL
    if db is None:
        db = MotionDatabase(db_path) if db_path else default_db
    if not ids:
        return 0, 0, 0, 0, []
    # Fetch texts for the given ids; without duckdb nothing can be selected.
    if duckdb is None:
        return 0, 0, 0, 0, []
    conn = duckdb.connect(db.db_path)
    try:
        placeholders = ",".join("?" for _ in ids)
        rows = conn.execute(
            f"SELECT id, COALESCE(layman_explanation, body_text, description, title) AS text FROM motions WHERE id IN ({placeholders})",
            ids,
        ).fetchall()
    finally:
        conn.close()
    to_process = [(int(r[0]), (r[1] or "").strip() or None) for r in rows]
    stored = 0
    errors = 0
    failed_ids: list = []
    with_text = [(mid, txt) for mid, txt in to_process if txt]
    # BUGFIX: previously always reported 0 — count motions whose COALESCEd
    # text is empty so the caller's metrics add up.
    skipped_no_text = len(to_process) - len(with_text)
    for batch_start in range(0, len(with_text), batch_size):
        batch = with_text[batch_start : batch_start + batch_size]
        batch_ids = [mid for mid, _ in batch]
        batch_texts = [txt for _, txt in batch]
        vecs = ai_wrapper.get_embeddings_with_retry(
            batch_texts,
            motion_ids=batch_ids,
            model=model,
            batch_size=batch_size,
            embedder=embedder,
        )
        for (motion_id, _text), vec in zip(batch, vecs):
            if not isinstance(vec, list):
                errors += 1
                failed_ids.append(motion_id)
                continue
            try:
                res = db.store_embedding(motion_id, model, vec)
            except Exception as exc:
                # Mirror ensure_text_embeddings: storage errors are counted,
                # not raised, so the rest of the batch is still stored.
                _logger.error(
                    "Error storing embedding for motion %s: %s", motion_id, exc
                )
                errors += 1
                failed_ids.append(motion_id)
                continue
            if res and res > 0:
                stored += 1
            else:
                errors += 1
                failed_ids.append(motion_id)
    return stored, 0, skipped_no_text, errors, failed_ids

@ -15,12 +15,16 @@ def compute_similarities(
window_id: Optional[str] = None,
top_k: int = 10,
db_path: Optional[str] = None,
db=None,
):
"""Compute pairwise cosine similarities for vectors of a given type and store top-k neighbors.
Returns number of inserted rows.
"""
db = MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase()
if db is None:
db = (
MotionDatabase(db_path=db_path) if db_path is not None else MotionDatabase()
)
# Build SQL query depending on vector type
if vector_type == "fused":
@ -186,6 +190,43 @@ def compute_similarities(
}
)
# Filter trivial 1.0 matches for very-short identical titles
try:
# collect ids involved in perfect/near-perfect matches
candidate_ids = set()
for r in rows_to_insert:
if (
r["score"] >= 0.999999
and r["source_motion_id"] != r["target_motion_id"]
):
candidate_ids.add(r["source_motion_id"])
candidate_ids.add(r["target_motion_id"])
if candidate_ids:
titles_map = db.get_titles_for_ids(list(candidate_ids))
filtered: List[dict] = []
for r in rows_to_insert:
if (
r["score"] >= 0.999999
and r["source_motion_id"] != r["target_motion_id"]
):
t1 = (titles_map.get(r["source_motion_id"]) or "").strip()
t2 = (titles_map.get(r["target_motion_id"]) or "").strip()
if t1 and t1 == t2 and len(t1) < 12:
logger.info(
"Filtered trivial 1.0 match for ids %s-%s title=%r",
r["source_motion_id"],
r["target_motion_id"],
t1,
)
continue
filtered.append(r)
rows_to_insert = filtered
except Exception:
logger.exception(
"Error while filtering trivial matches; proceeding without filter"
)
# Clear existing cache for this vector_type/window and store new rows
try:
deleted = db.clear_similarity_cache(

@ -1,5 +1,66 @@
import tempfile
import pytest
import os
from config import config
# Ensure importing database at test-collection time doesn't try to open the real
# application DB. Point the app config to a temporary DB file under the
# system tempdir so the module-level MotionDatabase() in database.py can
# initialize without conflicting with a running instance.
# NOTE(review): _tmp_dir is never removed after the test session; fine for
# short-lived CI runs, but consider cleanup on long-lived dev machines.
_tmp_dir = tempfile.mkdtemp(prefix="tests_db_")
config.DATABASE_PATH = os.path.join(_tmp_dir, "motions.db")
class FakeEmbedder:
    """Deterministic, network-free embedding callable for tests.

    Produces one fixed-size vector per input text. Positions listed in
    `fail_indices` (0-based within a single call's batch) raise RuntimeError
    instead, simulating provider failures. Every invocation is counted in
    `call_count` and recorded in `calls` for later inspection.
    """

    def __init__(self, fail_indices=None, vector_size=8):
        self.fail_indices = set(fail_indices or [])
        self.vector_size = vector_size
        self.call_count = 0
        # One (texts, kwargs) tuple appended per __call__.
        self.calls = []

    def __call__(self, texts, model=None, batch_size=50):
        self.call_count += 1
        self.calls.append((list(texts), {"model": model, "batch_size": batch_size}))
        vectors = []
        for i, text in enumerate(texts):
            if i in self.fail_indices:
                raise RuntimeError(
                    f"Simulated embedding failure for index {i}: {text!r}"
                )
            vectors.append([0.1 * (i + 1)] * self.vector_size)
        return vectors
@pytest.fixture
def mem_db(tmp_path):
    """MotionDatabase with full schema for tests — in-memory when possible.

    MotionDatabase(':memory:') may raise when os.path.dirname(':memory:') is
    empty. Try in-memory first, fall back to a file under pytest's tmp_path.
    Only the in-memory path is truly free of filesystem side effects; the
    fallback writes a throwaway file that tmp_path cleans up.
    """
    from database import (
        MotionDatabase,
    )  # lazy import — database module not imported at module level
    try:
        db = MotionDatabase(":memory:")
    except Exception:
        db = MotionDatabase(str(tmp_path / "test.db"))
    yield db
@pytest.fixture
def fake_embedder():
    """Fresh FakeEmbedder instance with no simulated failures."""
    return FakeEmbedder()
# Load test fixtures from the utils package so pytest can discover them.
# NOTE(review): modern pytest only honors `pytest_plugins` in the rootdir
# conftest — confirm this file is the top-level tests/conftest.py.
pytest_plugins = ["tests.utils.migration_fixtures"]

@ -0,0 +1,66 @@
"""Tests for pipeline.ai_provider_wrapper — no monkeypatching, no mocks."""
import pipeline.ai_provider_wrapper as w
from tests.conftest import FakeEmbedder
def test_empty_input_returns_empty():
    """An empty text list short-circuits to [] — the embedder is never needed."""
    assert w.get_embeddings_with_retry([]) == []
def test_successful_embeddings(mem_db):
    """A working embedder yields one non-None vector per input text."""
    fake = FakeEmbedder()
    vectors = w.get_embeddings_with_retry(
        ["motion one", "motion two"],
        motion_ids=[1, 2],
        embedder=fake,
        db=mem_db,
    )
    assert len(vectors) == 2
    assert all(v is not None for v in vectors)
    assert fake.call_count >= 1
def test_transient_failure_retries(mem_db):
    """First call raises, second succeeds: the wrapper must retry and recover."""
    attempts = {"n": 0}

    def flaky(texts, model=None, batch_size=50):
        # Real callable (no mock): fails exactly once, then returns vectors.
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise RuntimeError("Transient network error")
        return [[0.5] * 8 for _ in texts]

    result = w.get_embeddings_with_retry(
        ["motion text"],
        motion_ids=[42],
        embedder=flaky,
        db=mem_db,
        retries=3,
    )
    # The retry recovered the embedding.
    assert result[0] is not None
    assert attempts["n"] >= 2
def test_permanent_failure_returns_none_sentinel(mem_db):
    """A permanently failing embedder yields None for that item.

    retries=1 keeps the test fast: the wrapper only sleeps *between*
    attempts, so a single attempt per chunk means zero backoff sleeps
    (retries=2 cost ~1s of real time.sleep per run) while the
    permanent-failure path — batch failure, per-item fallback, None
    sentinel, audit event — is still fully exercised.
    """
    always_fails = FakeEmbedder(fail_indices={0})
    result = w.get_embeddings_with_retry(
        ["failing motion"],
        motion_ids=[99],
        embedder=always_fails,
        db=mem_db,
        retries=1,
    )
    # The failed item maps to a None sentinel, not an exception.
    assert result == [None]

@ -0,0 +1,60 @@
"""Tests for scripts.rerun_embeddings retry orchestration.
No sys.modules tricks are needed — duckdb is available in the .venv.
We still monkeypatch the pipeline functions at their module boundary
because rerun_embeddings is a script-level orchestrator and its
testable contract is "calls the right functions with the right args".
"""
import scripts.rerun_embeddings as rerun
import pipeline.text_pipeline as tp
def test_rerun_retries_missing(monkeypatch):
    """When ensure_text_embeddings reports failed_ids, the retry helper runs.

    rerun_embeddings is a script-level orchestrator, so the testable contract
    is "calls the right pipeline functions with the right arguments"; the
    pipeline functions are patched at their module boundary.
    """
    # Neutralize DB-touching helpers so no real DuckDB file is needed.
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def first_call(db_path=None, model=None, batch_size=50, **kwargs):
        # Simulate one stored embedding, one error, two failed ids.
        return (1, 0, 0, 1, [101, 102])

    called = {"retried": False, "ids": None}

    def retry_call(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        called["retried"] = True
        called["ids"] = ids
        return (1, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", first_call)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", retry_call)

    # Return value deliberately unused (was an unused `summary` local):
    # the contract under test is the retry call below.
    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert called["retried"] is True
    assert set(called["ids"]) == {101, 102}
def test_rerun_no_retry_when_no_failures(monkeypatch):
    """With an empty failed_ids list the per-id retry helper must stay untouched."""
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    retry_called = {"v": False}

    def clean_run(db_path=None, model=None, batch_size=50, **kwargs):
        # Everything stored, nothing failed.
        return (5, 0, 0, 0, [])

    def forbidden_retry(**kwargs):
        retry_called["v"] = True
        return (0, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", clean_run)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", forbidden_retry)

    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert retry_called["v"] is False

@ -0,0 +1,68 @@
"""Tests for similarity filter in compute_similarities — real DB, real code, no mocks."""
import json
import duckdb
from database import MotionDatabase
import similarity.compute as sc
def test_filter_skips_identical_short_title_pairs(tmp_path):
    """Perfect-similarity pairs with identical short titles are filtered out.

    End-to-end through a real DuckDB file: insert two motions sharing a
    short title, give them identical unit vectors, then assert that
    compute_similarities stores no rows because the trivial-match filter
    removes the only candidate pair.
    """
    db_path = str(tmp_path / "test.db")
    # 1. Initialize schema
    db = MotionDatabase(db_path)
    # 2. Insert 2 motions with identical short titles (< 12 chars)
    motion1 = {
        "title": "Aangenomen.",
        "description": "desc1",
        "date": "2020-01-01",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.5,
        "url": "u1",
    }
    motion2 = {
        "title": "Aangenomen.",
        "description": "desc2",
        "date": "2020-01-02",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.6,
        "url": "u2",
    }
    assert db.insert_motion(motion1) is True
    assert db.insert_motion(motion2) is True
    # Fetch the generated ids. BUGFIX: close the connection in a finally so a
    # failing assert below can't leak the open DuckDB handle.
    conn = duckdb.connect(db_path)
    try:
        id1 = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (motion1["url"],)
        ).fetchone()[0]
        id2 = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (motion2["url"],)
        ).fetchone()[0]
    finally:
        conn.close()
    assert id1 is not None and id2 is not None and id1 != id2
    # 3. Insert identical unit vectors into fused_embeddings using store_fused_embedding
    vec = [1.0] + [0.0] * 7  # 8-dim unit vector
    # use a window id (schema requires NOT NULL); compute_similarities will read all fused embeddings when window_id=None
    window_id = "w"
    assert db.store_fused_embedding(id1, window_id, vec, svd_dims=0, text_dims=0) > 0
    assert db.store_fused_embedding(id2, window_id, vec, svd_dims=0, text_dims=0) > 0
    # 4. Run compute_similarities
    inserted = sc.compute_similarities(
        vector_type="fused",
        window_id=None,
        db_path=db_path,
    )
    # 5. The pair (id1, id2) has perfect similarity and identical short titles
    #    The filter should remove it -> 0 rows inserted into similarity_cache
    assert inserted == 0, f"Expected 0 pairs after filter, got {inserted}"
Loading…
Cancel
Save