You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
2.8 KiB
113 lines
2.8 KiB
"""Tests for pipeline/run_pipeline.py"""
|
|
|
|
import argparse
|
|
import sys
|
|
import pytest
|
|
|
|
from pipeline.run_pipeline import _generate_windows, build_parser, run
|
|
from datetime import date
|
|
|
|
|
|
def test_generate_windows_quarterly():
    """Calendar year 2024 in "quarterly" mode yields four windows, Q1 through Q4.

    Each window is indexed positionally: element 0 is the window id,
    element 1 the start date string and element 2 the end date string.
    """
    windows = _generate_windows(date(2024, 1, 1), date(2024, 12, 31), "quarterly")

    assert len(windows) == 4
    window_ids = [window[0] for window in windows]
    assert window_ids == ["2024-Q1", "2024-Q2", "2024-Q3", "2024-Q4"]

    # First quarter spans January 1 through March 31.
    first_quarter = windows[0]
    assert first_quarter[1] == "2024-01-01"
    assert first_quarter[2] == "2024-03-31"

    # Fourth quarter spans October 1 through December 31.
    fourth_quarter = windows[3]
    assert fourth_quarter[1] == "2024-10-01"
    assert fourth_quarter[2] == "2024-12-31"
|
|
|
|
|
|
def test_generate_windows_annual():
    """A mid-2022 to early-2024 range in "annual" mode yields one window per year.

    The final window is expected to stop at the requested end date
    rather than running to December 31.
    """
    range_start = date(2022, 6, 1)
    range_end = date(2024, 3, 31)

    windows = _generate_windows(range_start, range_end, "annual")

    assert len(windows) == 3
    assert [window[0] for window in windows] == ["2022", "2023", "2024"]

    # The 2024 window is clipped to the end date, not extended to year end.
    last_window = windows[2]
    assert last_window[2] == "2024-03-31"
|
|
|
|
|
|
def test_generate_windows_mid_quarter_start():
    """A start date in the middle of Q2 still produces Q2 and Q3 windows.

    Only the window ids are checked here; the start/end bounds of the
    Q2 window are not asserted.
    """
    windows = _generate_windows(date(2024, 5, 15), date(2024, 9, 30), "quarterly")

    window_ids = [window[0] for window in windows]
    assert "2024-Q2" in window_ids
    assert "2024-Q3" in window_ids
|
|
|
|
|
|
def test_build_parser_defaults():
    """Parsing an empty argument list yields the expected default options."""
    defaults = build_parser().parse_args([])

    assert defaults.db_path == "data/motions.db"
    assert defaults.window_size == "quarterly"
    assert defaults.svd_k == 50
    assert defaults.dry_run is False
|
|
|
|
|
|
def test_run_dry_run(tmp_path, monkeypatch):
    """run() returns exit code 0 when called with dry_run=True.

    No phase is skipped. Only the exit code is asserted; nothing here
    inspects the database or log output afterwards.
    NOTE(review): the ``monkeypatch`` fixture is requested but not used.
    """
    from database import MotionDatabase

    db_file = str(tmp_path / "motions.db")

    # Instantiate the database once before run() is called; presumably this
    # creates the file/schema that run() expects -- confirm against database.py.
    MotionDatabase(db_file)

    options = {
        "db_path": db_file,
        "start_date": "2024-01-01",
        "end_date": "2024-03-31",
        "window_size": "quarterly",
        "svd_k": 10,
        "text_model": None,
        "skip_metadata": False,
        "skip_extract": False,
        "skip_svd": False,
        "skip_text": False,
        "skip_fusion": False,
        "dry_run": True,
    }

    assert run(argparse.Namespace(**options)) == 0
|
|
|
|
|
|
def test_run_skip_all(tmp_path):
    """run() returns exit code 0 when every skip_* flag is set."""
    from database import MotionDatabase

    db_file = str(tmp_path / "motions.db")

    # Instantiate the database once before run() is called; presumably this
    # creates the file/schema that run() expects -- confirm against database.py.
    MotionDatabase(db_file)

    options = {
        "db_path": db_file,
        "start_date": "2024-01-01",
        "end_date": "2024-03-31",
        "window_size": "quarterly",
        "svd_k": 10,
        "text_model": None,
        "skip_metadata": True,
        "skip_extract": True,
        "skip_svd": True,
        "skip_text": True,
        "skip_fusion": True,
        "dry_run": False,
    }

    assert run(argparse.Namespace(**options)) == 0
|
|
|