- Add `db=None` and `embedder=None` parameters to ai_provider_wrapper, text_pipeline, and compute_similarities.
- New conftest.py: `FakeEmbedder`, `mem_db` (in-memory DuckDB), and `fake_embedder` fixtures.
- Rewrite test_ai_provider_wrapper (4 tests), test_rerun_embeddings_retry (2 tests), and test_similarity_compute_filter (1 test) against the real implementations.
- Fix rerun_embeddings tests hanging on `_get_all_windows` by patching it alongside `_clear_embeddings`.
- All 53 tests pass (2 skipped); no `sys.modules` hacks remain in the refactored files.
parent
b7350d8f87
commit
aef7c45074
@ -0,0 +1,116 @@ |
|||||||
|
"""Wrapper around ai_provider to provide retries and smaller-batch fallback. |
||||||
|
|
||||||
|
Returns a list of embedding vectors aligned with inputs. For inputs that |
||||||
|
fail permanently the corresponding list entry will be None and an audit event |
||||||
|
is appended via database.db.append_audit_event. |
||||||
|
""" |
||||||
|
|
||||||
|
from __future__ import annotations |
||||||
|
|
||||||
|
import time |
||||||
|
import random |
||||||
|
from typing import List, Optional |
||||||
|
|
||||||
|
import ai_provider |
||||||
|
from database import db as motion_db |
||||||
|
import logging |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings_with_retry(
    texts: List[str],
    motion_ids: Optional[List[Optional[int]]] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    retries: int = 3,
    db=None,
    embedder=None,
) -> List[Optional[List[float]]]:
    """Return embeddings aligned with `texts`, or None for items that failed.

    Strategy:
    - Try batches of `batch_size` with up to `retries` attempts each.
    - On persistent batch failure, fall back to per-item attempts (batch of 1).
    - Record an audit event for items that permanently fail.

    Args:
        texts: Input strings to embed.
        motion_ids: Optional ids aligned with `texts`; used only when writing
            the `embedding_failed` audit event.
        model: Model identifier forwarded to the embedder.
        batch_size: Number of texts sent per embedding request.
        retries: Attempts per batch before it is declared failed.
        db: Injectable database exposing `append_audit_event`; defaults to the
            module-level `motion_db`.
        embedder: Injectable callable `(texts, model=..., batch_size=...)`;
            defaults to `ai_provider.get_embeddings_batch`.

    Returns:
        A list the same length as `texts`; each entry is an embedding vector
        or None when that item permanently failed.
    """
    if not texts:
        return []

    if motion_ids is None:
        motion_ids = [None for _ in texts]

    results: List[Optional[List[float]]] = [None] * len(texts)

    # Resolve the embedder at call time; prefer the injected one so tests can
    # supply fakes without monkeypatching ai_provider.
    _embedder = embedder if embedder is not None else ai_provider.get_embeddings_batch

    def _attempt_batch(chunk_texts, start_index):
        """Embed one chunk with exponential backoff; return (vectors, last_exc)."""
        backoff = 0.5
        last_exc = None
        for attempt in range(1, retries + 1):
            try:
                emb_chunk = _embedder(
                    chunk_texts, model=model, batch_size=len(chunk_texts)
                )
                return emb_chunk, None
            except Exception as exc:
                last_exc = exc
                if attempt == retries:
                    break
                # Exponential backoff with up to 10% jitter to avoid hammering
                # the provider in lockstep.
                sleep = backoff * (2 ** (attempt - 1))
                sleep = sleep + random.uniform(0, sleep * 0.1)
                _logger.debug(
                    "Batch embedding attempt %d failed, retrying after %.2fs: %s",
                    attempt,
                    sleep,
                    exc,
                )
                time.sleep(sleep)
        # persistent failure
        _logger.warning(
            "Batch embedding failed for texts starting at %d: %s", start_index, last_exc
        )
        return None, last_exc

    # Process in batches; any batch that fails all retries is retried per item.
    i = 0
    n = len(texts)
    while i < n:
        end = min(n, i + batch_size)
        chunk = texts[i:end]
        emb_chunk, _ = _attempt_batch(chunk, i)
        if emb_chunk is not None:
            # Success: assign positionally. zip() bounds the write to the
            # requested slots even if the embedder returns extra vectors.
            for pos, emb in zip(range(i, end), emb_chunk):
                results[pos] = emb
            i = end
            continue

        # Batch failed -> fall back to per-item attempts.
        for j in range(i, end):
            t = texts[j]
            mid = motion_ids[j] if j < len(motion_ids) else None
            single, single_exc = _attempt_batch([t], j)
            if single:
                results[j] = single[0]
                continue

            # Permanent failure for this item: best-effort audit, leave None.
            err_text = repr(single_exc) if single_exc is not None else "unknown"
            try:
                _db = db if db is not None else motion_db
                _db.append_audit_event(
                    actor_id=None,
                    action="embedding_failed",
                    target_type="motion",
                    target_id=str(mid) if mid is not None else None,
                    metadata={"error": err_text},
                )
            except Exception:
                # Auditing must never mask the original embedding failure.
                _logger.exception("Failed to append audit event for embedding failure")
            results[j] = None

        i = end

    return results
||||||
@ -0,0 +1,66 @@ |
|||||||
|
"""Tests for pipeline.ai_provider_wrapper — no monkeypatching, no mocks.""" |
||||||
|
|
||||||
|
import pipeline.ai_provider_wrapper as w |
||||||
|
from tests.conftest import FakeEmbedder |
||||||
|
|
||||||
|
|
||||||
|
def test_empty_input_returns_empty():
    """An empty text list yields an empty result without touching the embedder."""
    assert w.get_embeddings_with_retry([]) == []
||||||
|
|
||||||
|
|
||||||
|
def test_successful_embeddings(mem_db):
    """A healthy embedder produces one non-None vector per input text."""
    fake = FakeEmbedder()
    texts = ["motion one", "motion two"]
    vectors = w.get_embeddings_with_retry(
        texts,
        motion_ids=[1, 2],
        embedder=fake,
        db=mem_db,
    )
    assert len(vectors) == len(texts)
    assert all(v is not None for v in vectors)
    assert fake.call_count >= 1
||||||
|
|
||||||
|
|
||||||
|
def test_transient_failure_retries(mem_db):
    """An embedder that fails once and then recovers is retried to success."""

    class FlakyOnce:
        """Raises on the first invocation, returns fixed 8-dim vectors after."""

        def __init__(self):
            self.call_count = 0

        def __call__(self, texts, model=None, batch_size=50):
            self.call_count += 1
            if self.call_count == 1:
                raise RuntimeError("Transient network error")
            return [[0.5] * 8 for _ in texts]

    flaky = FlakyOnce()
    vectors = w.get_embeddings_with_retry(
        ["motion text"],
        motion_ids=[42],
        embedder=flaky,
        db=mem_db,
        retries=3,
    )
    # The second attempt succeeds, so the result slot must be populated.
    assert vectors[0] is not None
    assert flaky.call_count >= 2
||||||
|
|
||||||
|
|
||||||
|
def test_permanent_failure_returns_none_sentinel(mem_db):
    """An embedder that never succeeds leaves None at the failed position."""
    broken = FakeEmbedder(fail_indices={0})

    vectors = w.get_embeddings_with_retry(
        ["failing motion"],
        motion_ids=[99],
        embedder=broken,
        db=mem_db,
        retries=2,
    )
    # Permanent failures surface as a None placeholder, not an exception.
    assert vectors == [None]
||||||
@ -0,0 +1,60 @@ |
|||||||
|
"""Tests for scripts.rerun_embeddings retry orchestration. |
||||||
|
|
||||||
|
No sys.modules tricks needed — duckdb is available in .venv. |
||||||
|
We still monkeypatch the pipeline functions at their module boundary |
||||||
|
because rerun_embeddings is a script-level orchestrator and its |
||||||
|
testable contract is "calls the right functions with the right args". |
||||||
|
""" |
||||||
|
|
||||||
|
import scripts.rerun_embeddings as rerun |
||||||
|
import pipeline.text_pipeline as tp |
||||||
|
|
||||||
|
|
||||||
|
def test_rerun_retries_missing(monkeypatch):
    """When ensure_text_embeddings reports failed ids, the retry helper runs on them."""
    # Stub DB-touching helpers so the script never opens a real database
    # (unpatched _get_all_windows previously made these tests hang).
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def first_call(db_path=None, model=None, batch_size=50, **kwargs):
        # Simulate one processed window with two permanently failed motion ids.
        return (1, 0, 0, 1, [101, 102])

    called = {"retried": False, "ids": None}

    def retry_call(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        called["retried"] = True
        called["ids"] = ids
        return (1, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", first_call)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", retry_call)

    # Return value intentionally ignored; this test only checks retry wiring.
    rerun.rerun_embeddings(
        "data/motions.db", model="test-model", retry_missing=True
    )

    assert called["retried"] is True
    assert set(called["ids"]) == {101, 102}
||||||
|
|
||||||
|
|
||||||
|
def test_rerun_no_retry_when_no_failures(monkeypatch):
    """With zero failed ids, the retry helper must never be invoked."""
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def all_ok(db_path=None, model=None, batch_size=50, **kwargs):
        return (5, 0, 0, 0, [])

    seen = {"v": False}

    def forbidden_retry(**kwargs):
        seen["v"] = True
        return (0, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", all_ok)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", forbidden_retry)

    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert seen["v"] is False
||||||
@ -0,0 +1,68 @@ |
|||||||
|
"""Tests for similarity filter in compute_similarities — real DB, real code, no mocks.""" |
||||||
|
|
||||||
|
import json |
||||||
|
import duckdb |
||||||
|
from database import MotionDatabase |
||||||
|
import similarity.compute as sc |
||||||
|
|
||||||
|
|
||||||
|
def test_filter_skips_identical_short_title_pairs(tmp_path):
    """Pairs with identical short titles and perfect cosine similarity are filtered out."""
    db_path = str(tmp_path / "test.db")

    # 1. Initialize schema.
    db = MotionDatabase(db_path)

    # 2. Insert two motions sharing the same (short) title.
    motion1 = {
        "title": "Aangenomen.",
        "description": "desc1",
        "date": "2020-01-01",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.5,
        "url": "u1",
    }
    motion2 = {
        "title": "Aangenomen.",
        "description": "desc2",
        "date": "2020-01-02",
        "policy_area": "",
        "voting_results": {},
        "winning_margin": 0.6,
        "url": "u2",
    }

    assert db.insert_motion(motion1) is True
    assert db.insert_motion(motion2) is True

    # Fetch the generated ids. Close the connection even when an assertion
    # fails, so the DuckDB file is not left locked for the rest of the run.
    conn = duckdb.connect(db_path)
    try:
        id1 = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (motion1["url"],)
        ).fetchone()[0]
        id2 = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (motion2["url"],)
        ).fetchone()[0]

        assert id1 is not None and id2 is not None and id1 != id2

        # 3. Insert identical unit vectors into fused_embeddings via the
        # public store_fused_embedding API.
        vec = [1.0] + [0.0] * 7  # 8-dim unit vector

        # Schema requires a NOT NULL window id; compute_similarities will read
        # all fused embeddings when window_id=None.
        window_id = "w"
        assert db.store_fused_embedding(id1, window_id, vec, svd_dims=0, text_dims=0) > 0
        assert db.store_fused_embedding(id2, window_id, vec, svd_dims=0, text_dims=0) > 0
    finally:
        conn.close()

    # 4. Run compute_similarities over the fused vectors.
    inserted = sc.compute_similarities(
        vector_type="fused",
        window_id=None,
        db_path=db_path,
    )

    # 5. The pair (id1, id2) has perfect similarity and identical short titles;
    # the filter should remove it -> 0 rows inserted into similarity_cache.
    assert inserted == 0, f"Expected 0 pairs after filter, got {inserted}"
||||||
Loading…
Reference in new issue