You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_run_pipeline.py

113 lines
2.8 KiB

"""Tests for pipeline/run_pipeline.py"""
import argparse
import sys
import pytest
from pipeline.run_pipeline import _generate_windows, build_parser, run
from datetime import date
def test_generate_windows_quarterly():
    """A full 2024 range split quarterly yields Q1..Q4 with calendar-quarter bounds."""
    windows = _generate_windows(date(2024, 1, 1), date(2024, 12, 31), "quarterly")

    assert len(windows) == 4
    window_ids = [window[0] for window in windows]
    assert window_ids == [f"2024-Q{quarter}" for quarter in (1, 2, 3, 4)]

    first_quarter, last_quarter = windows[0], windows[3]
    # First quarter runs January 1 through March 31.
    assert (first_quarter[1], first_quarter[2]) == ("2024-01-01", "2024-03-31")
    # Last quarter runs October 1 through December 31.
    assert (last_quarter[1], last_quarter[2]) == ("2024-10-01", "2024-12-31")
def test_generate_windows_annual():
    """Annual windows cover each calendar year touched, clipping the last one to end_date."""
    windows = _generate_windows(date(2022, 6, 1), date(2024, 3, 31), "annual")

    assert len(windows) == 3
    assert [window[0] for window in windows] == [str(year) for year in range(2022, 2025)]

    # The final (2024) window must stop at end_date rather than running to Dec 31.
    final_window = windows[2]
    assert final_window[2] == "2024-03-31"
def test_generate_windows_mid_quarter_start():
    """Starting in the middle of Q2 should still produce a full Q2 window.

    The range 2024-05-15 .. 2024-09-30 touches exactly Q2 and Q3, so the
    window ids are pinned to exactly those two, in order, matching how the
    full-year quarterly test pins its ids. A bare membership check would
    also pass if spurious extra windows (e.g. a Q4 past ``end``) were
    generated.
    """
    start = date(2024, 5, 15)
    end = date(2024, 9, 30)
    windows = _generate_windows(start, end, "quarterly")
    ids = [w[0] for w in windows]
    assert ids == ["2024-Q2", "2024-Q3"]
    # NOTE(review): the docstring promises a *full* Q2 window, but the Q2
    # bounds are not asserted. Confirm whether the window start is snapped
    # to "2024-04-01" or clipped to start_date, then pin windows[0][1] here.
def test_build_parser_defaults():
    """Parsing an empty argv must fall back to the parser's default options."""
    defaults = build_parser().parse_args([])

    expected = {
        "db_path": "data/motions.db",
        "window_size": "quarterly",
        "svd_k": 50,
    }
    for option, value in expected.items():
        assert getattr(defaults, option) == value, option
    # Identity check: the default must be the bool False, not merely falsy.
    assert defaults.dry_run is False
def test_run_dry_run(tmp_path, monkeypatch):
    """Dry-run should log actions and return 0 without touching the DB.

    Only the exit code is asserted here; logging and DB-untouched behaviour
    are stated intent, not checked.
    """
    # NOTE(review): the monkeypatch fixture is requested but never used.
    from database import MotionDatabase

    db_file = str(tmp_path / "motions.db")
    # Initialise a minimal database at db_file before invoking run().
    MotionDatabase(db_file)

    phases = ("metadata", "extract", "svd", "text", "fusion")
    # Every phase enabled; dry_run alone should keep run() from doing work.
    phase_flags = {f"skip_{phase}": False for phase in phases}

    args = argparse.Namespace(
        db_path=db_file,
        start_date="2024-01-01",
        end_date="2024-03-31",
        window_size="quarterly",
        svd_k=10,
        text_model=None,
        dry_run=True,
        **phase_flags,
    )
    assert run(args) == 0
def test_run_skip_all(tmp_path):
    """Skipping all phases should still return 0."""
    from database import MotionDatabase

    db_file = str(tmp_path / "motions.db")
    # Initialise a minimal database at db_file before invoking run().
    MotionDatabase(db_file)

    phases = ("metadata", "extract", "svd", "text", "fusion")
    options = {
        "db_path": db_file,
        "start_date": "2024-01-01",
        "end_date": "2024-03-31",
        "window_size": "quarterly",
        "svd_k": 10,
        "text_model": None,
        "dry_run": False,
    }
    # Turn every phase off so run() has nothing to execute.
    options.update({f"skip_{phase}": True for phase in phases})

    exit_code = run(argparse.Namespace(**options))
    assert exit_code == 0