You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_rerun_embeddings_retry.py

60 lines
2.1 KiB

"""Tests for scripts.rerun_embeddings retry orchestration.
No sys.modules tricks needed — duckdb is available in .venv.
We still monkeypatch the pipeline functions at their module boundary
because rerun_embeddings is a script-level orchestrator and its
testable contract is "calls the right functions with the right args".
"""
import scripts.rerun_embeddings as rerun
import pipeline.text_pipeline as tp
def test_rerun_retries_missing(monkeypatch):
    """When ensure_text_embeddings returns failed_ids, retry helper is called.

    The first-pass stub reports two failed IDs (101 and 102); the test then
    checks that ``ensure_text_embeddings_for_ids`` was invoked and received
    exactly those IDs.
    """
    # Replace the rerun module's private helpers with trivial stubs.
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def first_call(db_path=None, model=None, batch_size=50, **kwargs):
        # Final tuple element is the failed-ID list this test is built around.
        return (1, 0, 0, 1, [101, 102])

    # Mutable record so the nested stub can report back to the test body.
    called = {"retried": False, "ids": None}

    def retry_call(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        called["retried"] = True
        called["ids"] = ids
        return (1, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", first_call)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", retry_call)

    # The return value is not asserted on here, so it is deliberately not bound.
    rerun.rerun_embeddings(
        "data/motions.db", model="test-model", retry_missing=True
    )

    assert called["retried"] is True
    # Compare as a set: the order in which IDs are forwarded is not checked.
    assert set(called["ids"]) == {101, 102}
def test_rerun_no_retry_when_no_failures(monkeypatch):
    """A first pass with an empty failed-ID list must not invoke the retry helper."""
    # Flag flipped by the retry stub; it must still be False at the end.
    retry_flag = {"v": False}

    def clean_first_pass(db_path=None, model=None, batch_size=50, **kwargs):
        # Empty final element: nothing failed, so nothing to retry.
        return (5, 0, 0, 0, [])

    def forbidden_retry(**kwargs):
        retry_flag["v"] = True
        return (0, 0, 0, 0, [])

    # Stub the rerun module's private helpers, in the same order as before.
    for helper_name, stub in (
        ("_clear_embeddings", lambda db_path: 0),
        ("_get_all_windows", lambda db_path: []),
    ):
        monkeypatch.setattr(rerun, helper_name, stub)

    monkeypatch.setattr(tp, "ensure_text_embeddings", clean_first_pass)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", forbidden_retry)

    rerun.rerun_embeddings(
        "data/motions.db",
        model="test-model",
        retry_missing=True,
    )

    assert retry_flag["v"] is False