You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_rerun_embeddings.py

84 lines
3.2 KiB

"""Tests for scripts/rerun_embeddings.py.
Monkeypatches pipeline functions directly on their bound module references
inside rerun_embeddings. Import at module level so the real 'database' module
is in sys.modules before any test-local sys.modules.setdefault calls run.
"""
from unittest.mock import MagicMock
import scripts.rerun_embeddings as rer
def test_rerun_embeddings_calls_pipeline_steps(monkeypatch, tmp_path):
    """Check that rerun_embeddings drives each stubbed pipeline step.

    The duckdb module and the three pipeline entry points referenced by
    ``rer`` are replaced with stand-ins that record how they were called.
    The test then verifies that the embedding step ran, that fusion and
    similarity were invoked once per window in order, and that the returned
    summary reflects the stubbed results.
    """
    windows = ["2022-Q3", "2023-Q1", "2024-Q2"]
    db_file = str(tmp_path / "motions.db")

    # Stand-in for duckdb.connect as used by _clear_embeddings and
    # _get_all_windows: every execute() yields the same result object.
    conn = MagicMock()
    result = conn.execute.return_value
    result.rowcount = 0
    result.fetchall.return_value = [(window,) for window in windows]
    duckdb_stub = MagicMock()
    duckdb_stub.connect.return_value = conn
    monkeypatch.setattr(rer, "duckdb", duckdb_stub)

    ensure_calls = []
    fused_windows = []
    similarity_windows = []

    def stub_ensure(db_path=None, model=None, batch_size=50, **kwargs):
        # ensure_text_embeddings returns a 5-tuple:
        # (stored, skipped_existing, skipped_no_text, errors, failed_ids)
        ensure_calls.append(db_path)
        return (5, 0, 2, 0, [])

    def stub_fuse(window_id, db_path=None):
        fused_windows.append(window_id)
        return {
            "inserted": 1,
            "skipped_missing_text": 0,
            "skipped_missing_svd": 0,
            "errors": 0,
        }

    def stub_similarity(
        vector_type="fused", window_id=None, db_path=None, top_k=10, **kwargs
    ):
        similarity_windows.append(window_id)
        return 10

    replacements = (
        (rer.text_pipeline, "ensure_text_embeddings", stub_ensure),
        (rer.fusion_pipeline, "fuse_for_window", stub_fuse),
        (rer.similarity_compute, "compute_similarities", stub_similarity),
    )
    for target, attribute, stub in replacements:
        monkeypatch.setattr(target, attribute, stub)

    summary = rer.rerun_embeddings(db_file)

    # The embedding step must have been invoked at least once.
    assert len(ensure_calls) >= 1
    # Fusion and similarity each see every window, in the order returned.
    assert fused_windows == windows
    assert similarity_windows == windows
    assert summary["windows_processed"] == len(windows)
    assert summary["embeddings_stored"] == 5
    assert summary["embeddings_skipped_no_text"] == 2
    assert summary["embeddings_failed_ids"] == []
def test_rerun_retries_when_retry_missing_and_failed_ids(monkeypatch, tmp_path):
    """A first pass that reports failed ids triggers a retry when retry_missing=True.

    The clearing and window-listing helpers are stubbed out so that only the
    embedding step and its retry are exercised.
    """
    db_file = str(tmp_path / "motions.db")

    def stub_clear(db_path):
        return 0

    def stub_windows(db_path):
        return []

    monkeypatch.setattr(rer, "_clear_embeddings", stub_clear)
    monkeypatch.setattr(rer, "_get_all_windows", stub_windows)

    # One entry per retry call, holding the ids argument it received.
    retry_ids_seen = []

    def stub_first_pass(db_path=None, model=None, batch_size=50, **kwargs):
        # Three stored, two errors, and the two ids that failed.
        return (3, 0, 0, 2, [201, 202])

    def stub_retry(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        retry_ids_seen.append(ids)
        return (2, 0, 0, 0, [])

    pipeline = rer.text_pipeline
    monkeypatch.setattr(pipeline, "ensure_text_embeddings", stub_first_pass)
    monkeypatch.setattr(pipeline, "ensure_text_embeddings_for_ids", stub_retry)

    summary = rer.rerun_embeddings(db_file, retry_missing=True)

    # Only the ids from the most recent retry call matter here.
    retried_ids = retry_ids_seen[-1] if retry_ids_seen else None
    assert retried_ids is not None, "retry was not called"
    assert set(retried_ids) == {201, 202}
    assert summary["embeddings_failed_ids"] == [201, 202]