You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
349 lines
11 KiB
349 lines
11 KiB
"""Tests for scripts/motion_drift.py."""
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
|
|
import duckdb
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
def _setup_test_db(db_path: str, windows: "dict | None" = None):
    """Create a DuckDB test database populated with synthetic SVD data.

    Creates the ``svd_vectors``, ``fused_embeddings``, ``mp_votes`` and
    ``motions`` tables in the database at ``db_path`` and fills them with
    rows derived from ``windows`` plus four fixed vote rows.

    Args:
        db_path: Path of the DuckDB database file to create.
        windows: Mapping ``{window_id: {motion_id: vector_array}}``. When
            ``None``, three windows ("2020", "2021", "2022") with three
            3-dimensional motion vectors each are used.
    """
    if windows is None:
        windows = {
            "2020": {
                1: np.array([1.0, 0.5, 0.2]),
                2: np.array([-0.8, 0.3, 0.1]),
                3: np.array([0.5, -0.9, 0.4]),
            },
            "2021": {
                1: np.array([1.1, 0.6, 0.3]),
                2: np.array([-0.7, 0.4, 0.2]),
                3: np.array([0.6, -0.8, 0.5]),
            },
            "2022": {
                1: np.array([1.2, 0.7, 0.4]),
                2: np.array([-0.6, 0.5, 0.3]),
                3: np.array([0.7, -0.7, 0.6]),
            },
        }

    # Number of zero-valued "text" dimensions appended to each SVD vector
    # to form the fused embedding.
    text_dims = 10

    con = duckdb.connect(db_path)
    try:
        con.execute("""
            CREATE TABLE svd_vectors (
                window_id VARCHAR,
                entity_type VARCHAR,
                entity_id VARCHAR,
                vector VARCHAR,
                model VARCHAR
            )
        """)

        con.execute("""
            CREATE TABLE fused_embeddings (
                motion_id INTEGER,
                window_id VARCHAR,
                vector VARCHAR,
                svd_dims INTEGER,
                text_dims INTEGER
            )
        """)

        con.execute("""
            CREATE TABLE mp_votes (
                id INTEGER,
                motion_id INTEGER,
                mp_name VARCHAR,
                party VARCHAR,
                vote VARCHAR,
                date DATE
            )
        """)

        con.execute("""
            CREATE TABLE motions (
                id INTEGER,
                title VARCHAR,
                body_text VARCHAR,
                date DATE,
                policy_area VARCHAR
            )
        """)

        # The same motion id recurs in every window; track which ids already
        # have a metadata row so ``motions`` ends up with one row per motion.
        seen_motion_ids = set()

        for window_id, motions in windows.items():
            for motion_id, vector in motions.items():
                # Motion vector for this window, stored as a JSON array.
                con.execute(
                    "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector) VALUES (?, 'motion', ?, ?)",
                    [window_id, str(motion_id), json.dumps(vector.tolist())],
                )

                # Fused embedding: the SVD vector padded with zero text dims.
                # svd_dims is taken from the vector itself so the metadata
                # stays correct for caller-supplied vectors of other lengths.
                fused = np.concatenate([vector, np.zeros(text_dims)])
                con.execute(
                    "INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims) VALUES (?, ?, ?, ?, ?)",
                    [
                        motion_id,
                        window_id,
                        json.dumps(fused.tolist()),
                        len(vector),
                        text_dims,
                    ],
                )

                # Motion metadata, inserted once per distinct motion id.
                if motion_id not in seen_motion_ids:
                    seen_motion_ids.add(motion_id)
                    con.execute(
                        "INSERT INTO motions (id, title, date) VALUES (?, ?, '2020-01-01')",
                        [motion_id, f"Motion {motion_id}"],
                    )

        # Fixed voting data referencing motion ids 1-3 (no trailing comma
        # after the last tuple, so the statement is plain standard SQL).
        con.execute("""
            INSERT INTO mp_votes (motion_id, mp_name, party, vote, date) VALUES
            (1, 'MP1', 'PVV', 'voor', '2020-06-01'),
            (1, 'MP2', 'SP', 'voor', '2020-06-01'),
            (2, 'MP3', 'VVD', 'voor', '2020-06-01'),
            (3, 'MP4', 'PvdA', 'voor', '2020-06-01')
        """)
    finally:
        con.close()
|
|
|
|
|
|
class TestMotionDriftScript:
    """Tests for the ``main`` entry point of scripts/motion_drift.py."""

    def test_help_exits_cleanly(self):
        """main(["--help"]) raises SystemExit with exit code 0."""
        from scripts.motion_drift import main

        with pytest.raises(SystemExit) as exc_info:
            main(["--help"])
        assert exc_info.value.code == 0

    def test_missing_database_returns_error(self, tmp_path):
        """main(["--db", <path that does not exist>]) returns exit code 1."""
        from scripts.motion_drift import main

        # Anchor the missing path inside pytest's fresh per-test directory so
        # the outcome cannot depend on the current working directory or on a
        # stray "nonexistent.db" left there by another run.
        missing_db = str(tmp_path / "nonexistent.db")

        result = main(["--db", missing_db])
        assert result == 1

    def test_runs_against_test_db(self, tmp_path):
        """main(["--db", <db>, "--output", <dir>]) returns 0 and writes report.md."""
        db_path = str(tmp_path / "test.db")
        _setup_test_db(db_path)

        from scripts.motion_drift import main

        output_dir = str(tmp_path / "output")
        result = main(["--db", db_path, "--output", output_dir])
        assert result == 0
        assert os.path.exists(os.path.join(output_dir, "report.md"))

    def test_schema_validation_catches_missing_tables(self, tmp_path):
        """A database without any tables makes main return exit code 1."""
        db_path = str(tmp_path / "empty.db")
        # Opening and immediately closing a connection leaves an empty
        # database file with none of the expected tables.
        con = duckdb.connect(db_path)
        con.close()

        from scripts.motion_drift import main

        result = main(["--db", db_path])
        assert result == 1
|
|
|
|
|
|
class TestAxisStability:
    """Checks on the result returned by ``compute_axis_stability``."""

    def test_returns_stability_matrix_for_multiple_windows(self, tmp_path):
        """Three windows yield a result carrying the expected keys."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_axis_stability

        # The connection is closed when the ``with`` block exits.
        with duckdb.connect(database, read_only=True) as conn:
            stability = compute_axis_stability(
                conn, ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            assert "stability_matrix" in stability
            # NOTE(review): per the original author, with < 50 motions per
            # window the script falls back to a party-based method that
            # returns empty when mp_metadata doesn't exist - confirm against
            # scripts/motion_drift.py.
            assert "stable_axes" in stability
            assert "avg_stability" in stability

    def test_stability_values_in_valid_range(self, tmp_path):
        """Non-empty stability matrices only hold values within [-1, 1]."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_axis_stability

        with duckdb.connect(database, read_only=True) as conn:
            stability = compute_axis_stability(
                conn, ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            similarities = stability["stability_matrix"]
            # An empty matrix has no min/max, so only check populated ones.
            if similarities.size > 0:
                assert similarities.min() >= -1.0
                assert similarities.max() <= 1.0

    def test_single_window_returns_empty(self, tmp_path):
        """A single window gives an empty matrix and no stable axes."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_axis_stability

        with duckdb.connect(database, read_only=True) as conn:
            stability = compute_axis_stability(
                conn, ["2020"], top_n=3, n_components=3
            )
            assert stability["stability_matrix"].size == 0
            assert stability["stable_axes"] == []
|
|
|
|
|
|
class TestSemanticDrift:
    """Checks on the result returned by ``compute_semantic_drift``."""

    def test_returns_drift_series_for_stable_axes(self, tmp_path):
        """Every drift series has one value per window transition, in [0, 2]."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_semantic_drift

        # The connection is closed when the ``with`` block exits.
        with duckdb.connect(database, read_only=True) as conn:
            drift = compute_semantic_drift(
                conn, [1, 2, 3], ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            assert "drift_series" in drift
            for series in drift["drift_series"].values():
                # 3 windows -> 2 transitions
                assert len(series) == 2
                for distance in series:
                    # cosine distance range
                    assert 0.0 <= distance <= 2.0

    def test_no_inflection_points_for_monotonic_drift(self, tmp_path):
        """A single-axis run completes and reports an inflection_points key."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_semantic_drift

        with duckdb.connect(database, read_only=True) as conn:
            drift = compute_semantic_drift(
                conn, [1], ["2020", "2021", "2022"], top_n=3, n_components=3
            )
            # Three windows give at most two drift values, so this only
            # verifies the call does not crash and the key is present.
            assert "inflection_points" in drift
|
|
|
|
|
|
class TestPartyVoting:
    """Checks on the result returned by ``compute_party_voting``."""

    def test_returns_voting_centroids(self, tmp_path):
        """The synthetic votes yield at least one party trajectory."""
        database = str(tmp_path / "test.db")
        _setup_test_db(database)

        from scripts.motion_drift import compute_party_voting

        # The connection is closed when the ``with`` block exits.
        with duckdb.connect(database, read_only=True) as conn:
            voting = compute_party_voting(conn, [1, 2, 3], ["2020"])
            assert "party_trajectories" in voting
            # The fixture inserts votes for four parties, so the mapping
            # must not be empty.
            assert len(voting["party_trajectories"]) > 0
|
|
|
|
|
|
class TestReportGeneration:
    """Checks on the markdown report written by ``_generate_report``."""

    def test_report_generated_with_all_sections(self, tmp_path):
        """A fully populated input produces a report with every section heading."""
        from scripts.motion_drift import _generate_report

        target_dir = str(tmp_path / "report")

        # Hand-built inputs: two windows, two stable axes, one party.
        stability = {
            "stability_matrix": np.array(
                [[[1.0, 0.8], [0.8, 1.0]], [[1.0, 0.9], [0.9, 1.0]]]
            ),
            "stable_axes": [1, 2],
            "reordered_axes": [],
            "unstable_axes": [],
            "windows": ["2020", "2021"],
        }
        drift = {
            "drift_series": {1: [0.1, 0.15], 2: [0.05, 0.08]},
            "inflection_points": {1: [], 2: []},
            "example_motions": {},
        }
        party = {
            "party_trajectories": {"PVV": {"2020": {"axes": {1: 1.0, 2: 0.5}}}},
            "cross_voting": {},
            "examples": {},
        }

        # NOTE(review): the trailing 20 is passed positionally; its meaning
        # is defined by _generate_report - confirm against the script.
        written_path = _generate_report(
            target_dir, stability, drift, party, ["2020", "2021"], 20
        )

        assert os.path.exists(written_path)
        with open(written_path) as handle:
            text = handle.read()

        expected_headings = (
            "## Summary",
            "## Axis Stability",
            "## Semantic Drift",
            "## Party Voting Analysis",
            "## Methodology",
        )
        for heading in expected_headings:
            assert heading in text

    def test_no_stable_axes_handles_gracefully(self, tmp_path):
        """With no stable axes the report is still written and says so."""
        from scripts.motion_drift import _generate_report

        target_dir = str(tmp_path / "report")

        # Empty analysis results: no stable axes, no drift, no party data.
        stability = {
            "stability_matrix": np.array([]),
            "stable_axes": [],
            "reordered_axes": [],
            "unstable_axes": [1, 2],
            "windows": ["2020"],
        }
        drift = {
            "drift_series": {},
            "inflection_points": {},
            "example_motions": {},
        }
        party = {"party_trajectories": {}, "cross_voting": {}, "examples": {}}

        written_path = _generate_report(
            target_dir, stability, drift, party, ["2020"], 20
        )

        assert os.path.exists(written_path)
        with open(written_path) as handle:
            text = handle.read()

        # Either notice is accepted as evidence the empty case was reported.
        notices = ("No stable axes", "No drift data available")
        assert any(notice in text for notice in notices)
|
|
|