"""Tests for scripts/rerun_embeddings.py.

Pipeline functions are replaced by monkeypatching them on the module objects
that rerun_embeddings itself references (``rer.text_pipeline``,
``rer.fusion_pipeline``, ``rer.similarity_compute``), so the script calls the
fakes without any further import tricks.

The script module is imported at module level on purpose: this guarantees the
real 'database' module is already in sys.modules before any test-local
sys.modules.setdefault calls run.
"""
from unittest.mock import MagicMock

import scripts.rerun_embeddings as rer


def test_rerun_embeddings_calls_pipeline_steps(monkeypatch, tmp_path):
    """Check that one rerun embeds once, then fuses and compares every window."""
    db_file = str(tmp_path / "motions.db")
    windows = ["2022-Q3", "2023-Q1", "2024-Q2"]

    ensure_hits = []
    fused_windows = []
    compared_windows = []

    # Stand-in for duckdb.connect(), used by _clear_embeddings and
    # _get_all_windows: every execute() yields one shared result object whose
    # rowcount is 0 and whose fetchall() lists the windows as 1-column rows.
    query_result = MagicMock()
    query_result.rowcount = 0
    query_result.fetchall.return_value = [(window,) for window in windows]
    conn_stub = MagicMock()
    conn_stub.execute.return_value = query_result
    duckdb_stub = MagicMock()
    duckdb_stub.connect.return_value = conn_stub
    monkeypatch.setattr(rer, "duckdb", duckdb_stub)

    def fake_ensure(db_path=None, model=None, batch_size=50, **kwargs):
        # ensure_text_embeddings yields a 5-tuple:
        # (stored, skipped_existing, skipped_no_text, errors, failed_ids)
        ensure_hits.append(True)
        return (5, 0, 2, 0, [])

    def fake_fuse(window_id, db_path=None):
        fused_windows.append(window_id)
        return {
            "inserted": 1,
            "skipped_missing_text": 0,
            "skipped_missing_svd": 0,
            "errors": 0,
        }

    def fake_sim(vector_type="fused", window_id=None, db_path=None, top_k=10, **kwargs):
        compared_windows.append(window_id)
        return 10

    for target, attr, replacement in (
        (rer.text_pipeline, "ensure_text_embeddings", fake_ensure),
        (rer.fusion_pipeline, "fuse_for_window", fake_fuse),
        (rer.similarity_compute, "compute_similarities", fake_sim),
    ):
        monkeypatch.setattr(target, attr, replacement)

    summary = rer.rerun_embeddings(db_file)

    # Every pipeline stage ran, and the per-window stages ran in window order.
    assert ensure_hits
    assert fused_windows == windows
    assert compared_windows == windows

    # The summary mirrors the window count and the fake embedding tuple.
    assert summary["windows_processed"] == len(windows)
    assert summary["embeddings_stored"] == 5
    assert summary["embeddings_skipped_no_text"] == 2
    assert summary["embeddings_failed_ids"] == []


def test_rerun_retries_when_retry_missing_and_failed_ids(monkeypatch, tmp_path):
    """Check that retry_missing=True re-embeds the IDs the first pass failed on."""
    db_file = str(tmp_path / "motions.db")

    # Bypass the DB helpers entirely. With no windows reported, this test does
    # not stub the fusion/similarity steps.
    monkeypatch.setattr(rer, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rer, "_get_all_windows", lambda db_path: [])

    received = {}

    def fake_ensure(db_path=None, model=None, batch_size=50, **kwargs):
        # First pass: 3 stored, 2 errors, with the failing IDs reported.
        return (3, 0, 0, 2, [201, 202])

    def fake_retry(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        # Retry pass: remember which IDs were requested; everything succeeds.
        received["ids"] = ids
        return (2, 0, 0, 0, [])

    monkeypatch.setattr(rer.text_pipeline, "ensure_text_embeddings", fake_ensure)
    monkeypatch.setattr(rer.text_pipeline, "ensure_text_embeddings_for_ids", fake_retry)

    summary = rer.rerun_embeddings(db_file, retry_missing=True)

    assert received.get("ids") is not None, "retry was not called"
    assert set(received["ids"]) == {201, 202}
    assert summary["embeddings_failed_ids"] == [201, 202]