You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
2.1 KiB
60 lines
2.1 KiB
"""Tests for scripts.rerun_embeddings retry orchestration.

No sys.modules tricks needed — duckdb is available in .venv.
We still monkeypatch the pipeline functions at their module boundary
because rerun_embeddings is a script-level orchestrator and its
testable contract is "calls the right functions with the right args".
"""
|
|
|
|
import scripts.rerun_embeddings as rerun
|
|
import pipeline.text_pipeline as tp
|
|
|
|
|
|
def test_rerun_retries_missing(monkeypatch):
    """When ensure_text_embeddings returns failed_ids, retry helper is called.

    The first-pass stub reports ids 101 and 102 in the last element of its
    result tuple; the test then checks that the retry stub is invoked and
    receives exactly those ids.
    """
    # Replace the orchestrator's private helpers with trivial stubs
    # (presumably they touch the database at db_path -- not visible from here).
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def first_call(db_path=None, model=None, batch_size=50, **kwargs):
        # Last tuple element carries the ids expected to be retried.
        return (1, 0, 0, 1, [101, 102])

    # Mutable record so the nested stub can report back to the test body.
    called = {"retried": False, "ids": None}

    def retry_call(db_path=None, ids=None, model=None, batch_size=10, **kwargs):
        called["retried"] = True
        called["ids"] = ids
        return (1, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", first_call)
    monkeypatch.setattr(tp, "ensure_text_embeddings_for_ids", retry_call)

    # The return value is not part of this test's contract, so it is not bound.
    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert called["retried"] is True
    # Compare as a set: the order in which ids are handed to the retry is not asserted.
    assert set(called["ids"]) == {101, 102}
|
|
|
|
|
|
def test_rerun_no_retry_when_no_failures(monkeypatch):
    """When ensure_text_embeddings returns no failed_ids, retry is NOT called.

    The first-pass stub returns an empty list as the last element of its
    result tuple; the test then checks that the retry stub was never invoked.
    """
    # Replace the orchestrator's private helpers with trivial stubs
    # (presumably they touch the database at db_path -- not visible from here).
    monkeypatch.setattr(rerun, "_clear_embeddings", lambda db_path: 0)
    monkeypatch.setattr(rerun, "_get_all_windows", lambda db_path: [])

    def no_failures(db_path=None, model=None, batch_size=50, **kwargs):
        # Empty last element: nothing to retry.
        return (5, 0, 0, 0, [])

    # Mutable record so the nested stub can report back to the test body.
    retry_called = {"v": False}

    def retry_should_not_be_called(*args, **kwargs):
        # Accept positional arguments too, so that any call -- however it is
        # spelled -- sets the flag instead of raising TypeError before the
        # flag is recorded.
        retry_called["v"] = True
        return (0, 0, 0, 0, [])

    monkeypatch.setattr(tp, "ensure_text_embeddings", no_failures)
    monkeypatch.setattr(
        tp, "ensure_text_embeddings_for_ids", retry_should_not_be_called
    )

    rerun.rerun_embeddings("data/motions.db", model="test-model", retry_missing=True)

    assert retry_called["v"] is False
|
|
|