You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_motion_drift.py

349 lines
11 KiB

"""Tests for scripts/motion_drift.py."""
import json
import os
import tempfile
import duckdb
import numpy as np
import pytest
def _setup_test_db(db_path: str, windows: dict = None):
    """Create a DuckDB test database populated with synthetic SVD data.

    Creates the ``svd_vectors``, ``fused_embeddings``, ``mp_votes`` and
    ``motions`` tables and fills them from *windows*.

    Args:
        db_path: Path handed to ``duckdb.connect`` for the database to create.
        windows: Mapping ``{window_id: {motion_id: vector_array}}``. When
            ``None``, three windows ("2020", "2021", "2022") holding three
            3-dimensional motion vectors each are used.

    Notes:
        * Each motion gets one ``svd_vectors`` row and one
          ``fused_embeddings`` row per window; the fused vector is the motion
          vector followed by 10 zeros.
        * Each distinct motion id gets exactly one ``motions`` row, even when
          it appears in several windows.
        * The ``mp_votes`` rows are fixed (motion ids 1-3) and do not depend
          on *windows*.
    """
    if windows is None:
        windows = {
            "2020": {
                1: np.array([1.0, 0.5, 0.2]),
                2: np.array([-0.8, 0.3, 0.1]),
                3: np.array([0.5, -0.9, 0.4]),
            },
            "2021": {
                1: np.array([1.1, 0.6, 0.3]),
                2: np.array([-0.7, 0.4, 0.2]),
                3: np.array([0.6, -0.8, 0.5]),
            },
            "2022": {
                1: np.array([1.2, 0.7, 0.4]),
                2: np.array([-0.6, 0.5, 0.3]),
                3: np.array([0.7, -0.7, 0.6]),
            },
        }

    schema_statements = (
        """
        CREATE TABLE svd_vectors (
            window_id VARCHAR,
            entity_type VARCHAR,
            entity_id VARCHAR,
            vector VARCHAR,
            model VARCHAR
        )
        """,
        """
        CREATE TABLE fused_embeddings (
            motion_id INTEGER,
            window_id VARCHAR,
            vector VARCHAR,
            svd_dims INTEGER,
            text_dims INTEGER
        )
        """,
        """
        CREATE TABLE mp_votes (
            id INTEGER,
            motion_id INTEGER,
            mp_name VARCHAR,
            party VARCHAR,
            vote VARCHAR,
            date DATE
        )
        """,
        """
        CREATE TABLE motions (
            id INTEGER,
            title VARCHAR,
            body_text VARCHAR,
            date DATE,
            policy_area VARCHAR
        )
        """,
    )

    con = duckdb.connect(db_path)
    try:
        for statement in schema_statements:
            con.execute(statement)

        # The motions table is keyed by motion id alone, so a motion that
        # appears in several windows must only be inserted once; otherwise
        # any join against motions multiplies its rows.
        seen_motion_ids = set()

        for window_id, motions in windows.items():
            for motion_id, vector in motions.items():
                # Per-window motion vector, stored as a JSON-encoded list.
                con.execute(
                    "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector) VALUES (?, 'motion', ?, ?)",
                    [window_id, str(motion_id), json.dumps(vector.tolist())],
                )

                # Fused embedding: the motion vector padded with 10 zeros
                # standing in for the text dimensions.
                fused = np.concatenate([vector, np.zeros(10)])
                con.execute(
                    "INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims) VALUES (?, ?, ?, 3, 10)",
                    [motion_id, window_id, json.dumps(fused.tolist())],
                )

                # Motion metadata, once per distinct motion id.
                if motion_id not in seen_motion_ids:
                    seen_motion_ids.add(motion_id)
                    con.execute(
                        "INSERT INTO motions (id, title, date) VALUES (?, ?, '2020-01-01')",
                        [motion_id, f"Motion {motion_id}"],
                    )

        # Fixed voting data; no trailing comma after the last tuple so the
        # statement is plain standard SQL.
        con.execute("""
            INSERT INTO mp_votes (motion_id, mp_name, party, vote, date) VALUES
                (1, 'MP1', 'PVV', 'voor', '2020-06-01'),
                (1, 'MP2', 'SP', 'voor', '2020-06-01'),
                (2, 'MP3', 'VVD', 'voor', '2020-06-01'),
                (3, 'MP4', 'PvdA', 'voor', '2020-06-01')
        """)
    finally:
        con.close()
class TestMotionDriftScript:
    """Test the motion_drift.py script."""

    def test_help_exits_cleanly(self):
        """main(["--help"]) raises SystemExit with code 0.

        Only the exit code is checked; the usage text itself is not asserted.
        """
        from scripts.motion_drift import main

        with pytest.raises(SystemExit) as exc_info:
            main(["--help"])
        assert exc_info.value.code == 0

    def test_missing_database_returns_error(self, tmp_path):
        """main(["--db", <missing path>]) returns exit code 1.

        The path lives under ``tmp_path`` so it is guaranteed not to exist,
        independent of the working directory or leftovers from earlier runs.
        """
        from scripts.motion_drift import main

        missing_db = str(tmp_path / "nonexistent.db")
        result = main(["--db", missing_db])
        assert result == 1

    def test_runs_against_test_db(self, tmp_path):
        """main(["--db", <db>, "--output", <dir>]) returns 0 and writes report.md."""
        db_path = str(tmp_path / "test.db")
        _setup_test_db(db_path)
        from scripts.motion_drift import main

        output_dir = str(tmp_path / "output")
        result = main(["--db", db_path, "--output", output_dir])
        assert result == 0
        assert os.path.exists(os.path.join(output_dir, "report.md"))

    def test_schema_validation_catches_missing_tables(self, tmp_path):
        """A database without the expected tables makes main return 1."""
        db_path = str(tmp_path / "empty.db")
        # Opening and closing a connection creates a database with no tables.
        con = duckdb.connect(db_path)
        con.close()
        from scripts.motion_drift import main

        result = main(["--db", db_path])
        assert result == 1
class TestAxisStability:
    """Checks for compute_axis_stability."""

    def test_returns_stability_matrix_for_multiple_windows(self, tmp_path):
        """The result for three windows exposes the expected report keys."""
        from scripts.motion_drift import compute_axis_stability

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            report = compute_axis_stability(
                connection, ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            # Per the original author's note: with < 50 motions per window the
            # script falls back to a party-based method, which returns empty
            # if mp_metadata doesn't exist. Hence only key presence is checked.
            for expected_key in ("stability_matrix", "stable_axes", "avg_stability"):
                assert expected_key in report
        finally:
            connection.close()

    def test_stability_values_in_valid_range(self, tmp_path):
        """A non-empty stability matrix only holds values within [-1, 1]."""
        from scripts.motion_drift import compute_axis_stability

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            report = compute_axis_stability(
                connection, ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            stability = report["stability_matrix"]
            # An empty matrix has no min/max, so the bounds are only checked
            # when there is at least one element.
            if stability.size > 0:
                lowest = stability.min()
                highest = stability.max()
                assert lowest >= -1.0
                assert highest <= 1.0
        finally:
            connection.close()

    def test_single_window_returns_empty(self, tmp_path):
        """One window alone yields an empty matrix and no stable axes."""
        from scripts.motion_drift import compute_axis_stability

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            report = compute_axis_stability(
                connection, ["2020"], top_n=3, n_components=3
            )
            assert report["stability_matrix"].size == 0
            assert report["stable_axes"] == []
        finally:
            connection.close()
class TestSemanticDrift:
    """Checks for compute_semantic_drift."""

    def test_returns_drift_series_for_stable_axes(self, tmp_path):
        """Every drift series has one value per window transition, each in [0, 2]."""
        from scripts.motion_drift import compute_semantic_drift

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            outcome = compute_semantic_drift(
                connection,
                [1, 2, 3],
                ["2020", "2021", "2022"],
                top_n=3,
                n_components=3,
            )
            assert "drift_series" in outcome
            for series in outcome["drift_series"].values():
                # 3 windows → 2 transitions
                assert len(series) == 2
                # cosine distance range
                assert all(0.0 <= step <= 2.0 for step in series)
        finally:
            connection.close()

    def test_no_inflection_points_for_monotonic_drift(self, tmp_path):
        """The result for a single axis carries an inflection_points entry.

        Only the presence of the key is asserted: with just two drift values
        inflection detection is limited, so the test mainly guards against a
        crash.
        """
        from scripts.motion_drift import compute_semantic_drift

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            outcome = compute_semantic_drift(
                connection, [1], ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            assert "inflection_points" in outcome
        finally:
            connection.close()
class TestPartyVoting:
    """Checks for compute_party_voting."""

    def test_returns_voting_centroids(self, tmp_path):
        """The result contains a non-empty party_trajectories entry."""
        from scripts.motion_drift import compute_party_voting

        database = str(tmp_path / "test.db")
        _setup_test_db(database)
        connection = duckdb.connect(database, read_only=True)
        try:
            outcome = compute_party_voting(connection, [1, 2, 3], ["2020"])
            assert "party_trajectories" in outcome
            # The fixture's votes should produce at least one party.
            trajectories = outcome["party_trajectories"]
            assert len(trajectories) > 0
        finally:
            connection.close()
class TestReportGeneration:
    """Checks for _generate_report."""

    def test_report_generated_with_all_sections(self, tmp_path):
        """A report built from populated results contains every section heading."""
        from scripts.motion_drift import _generate_report

        windows = ["2020", "2021"]
        stability = {
            "stability_matrix": np.array(
                [[[1.0, 0.8], [0.8, 1.0]], [[1.0, 0.9], [0.9, 1.0]]]
            ),
            "stable_axes": [1, 2],
            "reordered_axes": [],
            "unstable_axes": [],
            "windows": ["2020", "2021"],
        }
        drift = {
            "drift_series": {1: [0.1, 0.15], 2: [0.05, 0.08]},
            "inflection_points": {1: [], 2: []},
            "example_motions": {},
        }
        party = {
            "party_trajectories": {"PVV": {"2020": {"axes": {1: 1.0, 2: 0.5}}}},
            "cross_voting": {},
            "examples": {},
        }

        destination = str(tmp_path / "report")
        written_path = _generate_report(
            destination, stability, drift, party, windows, 20
        )

        assert os.path.exists(written_path)
        with open(written_path) as handle:
            text = handle.read()

        expected_headings = (
            "## Summary",
            "## Axis Stability",
            "## Semantic Drift",
            "## Party Voting Analysis",
            "## Methodology",
        )
        for heading in expected_headings:
            assert heading in text

    def test_no_stable_axes_handles_gracefully(self, tmp_path):
        """Without stable axes the report still gets written and says so."""
        from scripts.motion_drift import _generate_report

        stability = {
            "stability_matrix": np.array([]),
            "stable_axes": [],
            "reordered_axes": [],
            "unstable_axes": [1, 2],
            "windows": ["2020"],
        }
        drift = {
            "drift_series": {},
            "inflection_points": {},
            "example_motions": {},
        }
        party = {"party_trajectories": {}, "cross_voting": {}, "examples": {}}

        destination = str(tmp_path / "report")
        written_path = _generate_report(
            destination, stability, drift, party, ["2020"], 20
        )

        assert os.path.exists(written_path)
        with open(written_path) as handle:
            text = handle.read()

        # Either notice is accepted as evidence of the graceful path.
        notices = ("No stable axes", "No drift data available")
        assert any(notice in text for notice in notices)