- Add db=None, embedder=None params to ai_provider_wrapper, text_pipeline, compute_similarities
- New conftest.py: FakeEmbedder, mem_db (in-memory DuckDB), fake_embedder fixtures
- Rewrite test_ai_provider_wrapper (4 tests), test_rerun_embeddings_retry (2 tests), test_similarity_compute_filter (1 test) with real implementations
- Fix rerun_embeddings tests hanging on _get_all_windows by patching it alongside _clear_embeddings
- All 53 tests pass (2 skipped), 0 sys.modules hacks in refactored files (branch: main)
parent
b7350d8f87
commit
aef7c45074
@ -0,0 +1,116 @@ |
||||
"""Wrapper around ai_provider to provide retries and smaller-batch fallback. |
||||
|
||||
Returns a list of embedding vectors aligned with inputs. For inputs that |
||||
fail permanently the corresponding list entry will be None and an audit event |
||||
is appended via database.db.append_audit_event. |
||||
""" |
||||
|
||||
from __future__ import annotations |
||||
|
||||
import time |
||||
import random |
||||
from typing import List, Optional |
||||
|
||||
import ai_provider |
||||
from database import db as motion_db |
||||
import logging |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def get_embeddings_with_retry(
    texts: List[str],
    motion_ids: Optional[List[Optional[int]]] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    retries: int = 3,
    db=None,
    embedder=None,
) -> List[Optional[List[float]]]:
    """Return embeddings aligned with ``texts``; entries are None for failures.

    Strategy:
    - Try batches of ``batch_size`` with up to ``retries`` attempts each,
      sleeping with exponential backoff plus jitter between attempts.
    - On persistent batch failure, fall back to per-item attempts.
    - Record an audit event (``db.append_audit_event``) for items that
      permanently fail; audit errors are logged, never raised.

    Args:
        texts: Input strings to embed.
        motion_ids: Optional ids aligned with ``texts``, used in audit events.
        model: Optional model name forwarded to the embedder.
        batch_size: Number of texts per embedder call.
        retries: Attempts per batch; values < 1 are clamped to 1 so the
            embedder is always called at least once.
        db: Injected database (defaults to ``database.db``).
        embedder: Injected callable ``(texts, model=..., batch_size=...)``
            (defaults to ``ai_provider.get_embeddings_batch``).

    Returns:
        A list the same length as ``texts``; failed items are ``None``.
    """
    if not texts:
        return []

    if motion_ids is None:
        motion_ids = [None] * len(texts)

    results: List[Optional[List[float]]] = [None] * len(texts)

    # Resolve collaborators at call time; prefer injected implementations so
    # tests can supply fakes without touching module state.
    _embedder = embedder if embedder is not None else ai_provider.get_embeddings_batch
    _db = db if db is not None else motion_db
    logger = logging.getLogger(__name__)

    # Guard: retries < 1 previously meant "never call the embedder" and every
    # item silently failed with no exception recorded. Always attempt once.
    attempts = max(1, retries)

    def _attempt_batch(chunk_texts, start_index):
        """Try one chunk up to ``attempts`` times; return (embeddings, exc)."""
        backoff = 0.5
        last_exc = None
        for attempt in range(1, attempts + 1):
            try:
                emb_chunk = _embedder(
                    chunk_texts, model=model, batch_size=len(chunk_texts)
                )
                return emb_chunk, None
            except Exception as exc:
                last_exc = exc
                if attempt == attempts:
                    break
                # Exponential backoff with up to 10% jitter to avoid
                # thundering-herd retries against the provider.
                sleep = backoff * (2 ** (attempt - 1))
                sleep += random.uniform(0, sleep * 0.1)
                logger.debug(
                    "Batch embedding attempt %d failed, retrying after %.2fs: %s",
                    attempt,
                    sleep,
                    exc,
                )
                time.sleep(sleep)
        # persistent failure
        logger.warning(
            "Batch embedding failed for texts starting at %d: %s",
            start_index,
            last_exc,
        )
        return None, last_exc

    # Process in contiguous batches.
    i = 0
    n = len(texts)
    while i < n:
        end = min(n, i + batch_size)
        chunk = texts[i:end]
        emb_chunk, _ = _attempt_batch(chunk, i)
        if emb_chunk is not None:
            # Success: assign aligned entries. zip() bounds the assignment so
            # an embedder returning a misaligned (over-long) chunk cannot
            # IndexError past the results slots for this batch.
            for j, emb in zip(range(i, end), emb_chunk):
                results[j] = emb
            i = end
            continue

        # Batch failed -> fall back to per-item attempts.
        for j in range(i, end):
            mid = motion_ids[j] if j < len(motion_ids) else None
            single, single_exc = _attempt_batch([texts[j]], j)
            if single:
                results[j] = single[0]
                continue

            # Permanent failure for this item: leave the None sentinel and
            # record an audit event; auditing must never raise.
            err_text = repr(single_exc) if single_exc is not None else "unknown"
            try:
                _db.append_audit_event(
                    actor_id=None,
                    action="embedding_failed",
                    target_type="motion",
                    target_id=str(mid) if mid is not None else None,
                    metadata={"error": err_text},
                )
            except Exception:
                logger.exception("Failed to append audit event for embedding failure")

        i = end

    return results
||||
@ -0,0 +1,66 @@ |
||||
"""Tests for pipeline.ai_provider_wrapper — no monkeypatching, no mocks.""" |
||||
|
||||
import pipeline.ai_provider_wrapper as w |
||||
from tests.conftest import FakeEmbedder |
||||
|
||||
|
||||
def test_empty_input_returns_empty():
    """An empty text list short-circuits to [] without touching the embedder."""
    assert w.get_embeddings_with_retry([]) == []
||||
|
||||
|
||||
def test_successful_embeddings(mem_db):
    """A real fake embedder yields one vector per input text, in order."""
    fake = FakeEmbedder()
    vectors = w.get_embeddings_with_retry(
        ["motion one", "motion two"],
        motion_ids=[1, 2],
        embedder=fake,
        db=mem_db,
    )
    assert len(vectors) == 2
    assert all(v is not None for v in vectors)
    assert fake.call_count >= 1
||||
|
||||
|
||||
def test_transient_failure_retries(mem_db):
    """One failed call followed by a success exercises the retry path."""

    class FlakyOnce:
        """Fails exactly on the first call, succeeds afterwards."""

        def __init__(self):
            self.call_count = 0

        def __call__(self, texts, model=None, batch_size=50):
            self.call_count += 1
            if self.call_count == 1:
                raise RuntimeError("Transient network error")
            return [[0.5] * 8 for _ in texts]

    flaky = FlakyOnce()
    vectors = w.get_embeddings_with_retry(
        ["motion text"],
        motion_ids=[42],
        embedder=flaky,
        db=mem_db,
        retries=3,
    )
    # The retry must have recovered the embedding on a later attempt.
    assert vectors[0] is not None
    assert flaky.call_count >= 2
||||
|
||||
|
||||
def test_permanent_failure_returns_none_sentinel(mem_db):
    """An embedder that always fails yields the None sentinel for that text."""
    broken = FakeEmbedder(fail_indices={0})

    out = w.get_embeddings_with_retry(
        ["failing motion"],
        motion_ids=[99],
        embedder=broken,
        db=mem_db,
        retries=2,
    )
    # The single failed item maps to None in the aligned result list.
    assert out == [None]
||||
@ -0,0 +1,60 @@ |
||||
"""Tests for scripts.rerun_embeddings retry orchestration. |
||||
|
||||
No sys.modules tricks needed — duckdb is available in .venv. |
||||
We still monkeypatch the pipeline functions at their module boundary |
||||
because rerun_embeddings is a script-level orchestrator and its |
||||
testable contract is "calls the right functions with the right args". |
||||
""" |
||||
|
||||
import scripts.rerun_embeddings as rerun |
||||
import pipeline.text_pipeline as tp |
||||
|
||||
|
||||
def test_rerun_retries_missing(monkeypatch):
    """failed_ids from ensure_text_embeddings trigger the per-id retry helper."""
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    observed = {"retried": False, "ids": None}

    def fake_initial(db_path=None, model=None, batch_size=50, **kwargs):
        # Simulate one processed row with two permanently failed ids.
        return (1, 0, 0, 1, [101, 102])

    def fake_retry(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        observed["retried"] = True
        observed["ids"] = ids
        return (1, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", fake_initial)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", fake_retry)

    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert observed["retried"] is True
    assert set(observed["ids"]) == {101, 102}
||||
|
||||
|
||||
def test_rerun_no_retry_when_no_failures(monkeypatch):
    """Without failed_ids, the per-id retry helper must never run."""
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    flag = {"v": False}

    def clean_run(db_path=None, model=None, batch_size=50, **kwargs):
        # Five rows processed, zero failures.
        return (5, 0, 0, 0, [])

    def forbidden_retry(**kwargs):
        flag["v"] = True
        return (0, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", clean_run)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", forbidden_retry)

    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert flag["v"] is False
||||
@ -0,0 +1,68 @@ |
||||
"""Tests for similarity filter in compute_similarities — real DB, real code, no mocks.""" |
||||
|
||||
import json |
||||
import duckdb |
||||
from database import MotionDatabase |
||||
import similarity.compute as sc |
||||
|
||||
|
||||
def test_filter_skips_identical_short_title_pairs(tmp_path):
    """Identical short titles with perfect cosine similarity produce no cached pairs."""
    db_path = str(tmp_path / "test.db")

    # Initialize the schema through the real database layer.
    db = MotionDatabase(db_path)

    # Two motions that share the same (short, boilerplate) title.
    motion1 = {
        "url": "u1",
        "title": "Aangenomen.",
        "description": "desc1",
        "date": "2020-01-01",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.5,
    }
    motion2 = {
        "url": "u2",
        "title": "Aangenomen.",
        "description": "desc2",
        "date": "2020-01-02",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.6,
    }

    assert db.insert_motion(motion1) is True
    assert db.insert_motion(motion2) is True

    # Look up the generated ids directly via duckdb.
    conn = duckdb.connect(db_path)
    id1 = conn.execute(
        "SELECT id FROM motions WHERE url = ?", (motion1["url"],)
    ).fetchone()[0]
    id2 = conn.execute(
        "SELECT id FROM motions WHERE url = ?", (motion2["url"],)
    ).fetchone()[0]
    assert id1 is not None and id2 is not None and id1 != id2

    # Store identical 8-dim unit vectors so cosine similarity is exactly 1.0.
    vec = [1.0] + [0.0] * 7
    # The schema requires a non-NULL window id; compute_similarities reads all
    # fused embeddings when called with window_id=None.
    window_id = "w"
    assert db.store_fused_embedding(id1, window_id, vec, svd_dims=0, text_dims=0) > 0
    assert db.store_fused_embedding(id2, window_id, vec, svd_dims=0, text_dims=0) > 0
    conn.close()

    inserted = sc.compute_similarities(
        vector_type="fused",
        window_id=None,
        db_path=db_path,
    )

    # The (id1, id2) pair has perfect similarity and identical short titles, so
    # the filter must drop it: nothing lands in similarity_cache.
    assert inserted == 0, f"Expected 0 pairs after filter, got {inserted}"
||||
Loading…
Reference in new issue