You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
84 lines
3.2 KiB
84 lines
3.2 KiB
"""Tests for scripts/rerun_embeddings.py.

Pipeline functions are monkeypatched directly on the module references bound
inside rerun_embeddings. The module is imported at module level so that the
real 'database' module is already in sys.modules before any test-local
sys.modules.setdefault calls run.
"""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import scripts.rerun_embeddings as rer
|
|
|
|
|
|
def test_rerun_embeddings_calls_pipeline_steps(monkeypatch, tmp_path):
    """rerun_embeddings runs text embedding, then fusion and similarity per window.

    duckdb is replaced with a MagicMock whose connection reports the windows in
    ``fake_windows``. The pipeline entry points are replaced with fakes that
    record their calls. The test checks three things:

    * every window reaches both the fusion and the similarity step, in order;
    * the summary reflects the counts returned by the fake
      ``ensure_text_embeddings``;
    * no retry pass runs when the first pass reports no failed ids.
    """
    db_file = str(tmp_path / "motions.db")
    fake_windows = ["2022-Q3", "2023-Q1", "2024-Q2"]

    called = {"ensure": False, "retry": False, "fuse_windows": [], "sim_windows": []}

    # Patch duckdb.connect used in _clear_embeddings and _get_all_windows.
    fake_conn = MagicMock()
    # MagicMock.__enter__ returns a *new* mock by default. Hand back the same
    # fake so the stubbed rowcount/fetchall still apply if the connection is
    # used as `with duckdb.connect(...) as conn:`.
    fake_conn.__enter__.return_value = fake_conn
    fake_conn.execute.return_value.rowcount = 0
    fake_conn.execute.return_value.fetchall.return_value = [(w,) for w in fake_windows]
    fake_duckdb = MagicMock()
    fake_duckdb.connect.return_value = fake_conn
    monkeypatch.setattr(rer, "duckdb", fake_duckdb)

    # ensure_text_embeddings now returns a 5-tuple:
    # (stored, skipped_existing, skipped_no_text, errors, failed_ids)
    def fake_ensure(db_path=None, model=None, batch_size=50, **kwargs):
        called["ensure"] = True
        return (5, 0, 2, 0, [])

    # Stub the retry helper too, so an unexpected retry is recorded (and
    # asserted against below) instead of reaching the real database.
    def fake_retry(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        called["retry"] = True
        return (0, 0, 0, 0, [])

    def fake_fuse(window_id, db_path=None):
        called["fuse_windows"].append(window_id)
        return {
            "inserted": 1,
            "skipped_missing_text": 0,
            "skipped_missing_svd": 0,
            "errors": 0,
        }

    def fake_sim(vector_type="fused", window_id=None, db_path=None, top_k=10, **kwargs):
        called["sim_windows"].append(window_id)
        return 10

    monkeypatch.setattr(rer.text_pipeline, "ensure_text_embeddings", fake_ensure)
    monkeypatch.setattr(rer.text_pipeline, "ensure_text_embeddings_for_ids", fake_retry)
    monkeypatch.setattr(rer.fusion_pipeline, "fuse_for_window", fake_fuse)
    monkeypatch.setattr(rer.similarity_compute, "compute_similarities", fake_sim)

    summary = rer.rerun_embeddings(db_file)

    assert called["ensure"] is True
    # The first pass reported no failed ids, so there is nothing to retry.
    assert called["retry"] is False, "retry should not run when no ids failed"
    assert called["fuse_windows"] == fake_windows
    assert called["sim_windows"] == fake_windows
    assert summary["windows_processed"] == len(fake_windows)
    assert summary["embeddings_stored"] == 5
    assert summary["embeddings_skipped_no_text"] == 2
    assert summary["embeddings_failed_ids"] == []
|
|
|
|
|
|
def test_rerun_retries_when_retry_missing_and_failed_ids(monkeypatch, tmp_path):
    """A first pass that reports failed ids triggers a retry when retry_missing=True.

    The DB helpers are stubbed out so that no windows exist. Only the text
    embedding stage is exercised. The fake first pass fails ids 201 and 202. The
    test checks that exactly those ids are handed to the retry helper.
    """
    target_db = str(tmp_path / "motions.db")

    # Nothing to clear and no windows to process: isolates the embedding stage.
    monkeypatch.setattr(rer, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rer, "_get_all_windows", lambda db_path: [])

    # Stays None unless the retry helper is invoked.
    retried_ids = None

    def fake_ensure(db_path=None, model=None, batch_size=50, **kwargs):
        # (stored, skipped_existing, skipped_no_text, errors, failed_ids)
        return (3, 0, 0, 2, [201, 202])

    def fake_retry(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        nonlocal retried_ids
        retried_ids = ids
        return (2, 0, 0, 0, [])

    pipeline = rer.text_pipeline
    monkeypatch.setattr(pipeline, "ensure_text_embeddings", fake_ensure)
    monkeypatch.setattr(pipeline, "ensure_text_embeddings_for_ids", fake_retry)

    summary = rer.rerun_embeddings(target_db, retry_missing=True)

    assert retried_ids is not None, "retry was not called"
    assert set(retried_ids) == {201, 202}
    # NOTE(review): the fake retry reports no remaining failures, yet the summary
    # is expected to list the first-pass failures. Confirm this is the intended
    # meaning of "embeddings_failed_ids".
    assert summary["embeddings_failed_ids"] == [201, 202]
|
|
|