- pipeline/run_pipeline.py: CLI orchestrator for all 5 pipeline phases with
--dry-run, --skip-*, --window-size, --svd-k, --start/end-date flags
- analysis/{political_axis,trajectory,clustering,visualize}.py: PCA/anchor
ideological axis, MP drift trajectories, UMAP + KMeans clustering, Plotly HTML output
- api_client.py: capture ActorFractie per individual MP vote (comma in ActorNaam)
into mp_vote_parties dict on each motion
- database.insert_motion: auto-insert mp_votes rows with party affiliation for
newly ingested motions when mp_vote_parties is present
- Add scikit-learn to pyproject.toml for KMeans clustering
- tests/test_run_pipeline.py: window generation, dry-run, skip-all paths
- tests/test_analysis.py: PCA axis, anchor axis, trajectory drift, KMeans
Ref: thoughts/shared/plans/2026-03-21-parliamentary-embedding-pipeline-plan.md
main
parent
a36e6cba4e
commit
f2a831dfcf
@ -0,0 +1,8 @@ |
|||||||
|
"""Analysis modules for the parliamentary embedding pipeline. |
||||||
|
|
||||||
|
Modules: |
||||||
|
political_axis — project MP SVD vectors onto ideological axis |
||||||
|
trajectory — compute MP drift across aligned windows |
||||||
|
clustering — UMAP dimensionality reduction + cluster labelling |
||||||
|
visualize — Plotly interactive plots (outputs self-contained HTML) |
||||||
|
""" |
||||||
@ -0,0 +1,130 @@ |
|||||||
|
"""clustering.py — UMAP dimensionality reduction on fused embeddings. |
||||||
|
|
||||||
|
Reduces fused motion embeddings to 2D (or 3D) for visualisation, |
||||||
|
and optionally labels clusters using KMeans. |
||||||
|
|
||||||
|
Requires: umap-learn, scikit-learn (for KMeans) |
||||||
|
""" |
||||||
|
|
||||||
|
import json |
||||||
|
import logging |
||||||
|
from typing import Dict, List, Optional, Tuple |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import duckdb |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _load_fused_vectors(
    db_path: str, window_id: Optional[str] = None
) -> Tuple[List[int], List[str], np.ndarray]:
    """Load fused embeddings from the DB.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: If given, restrict the query to that single window.

    Returns:
        (motion_ids, window_ids, matrix) — one matrix row per successfully
        parsed vector; ([], [], empty array) when nothing could be loaded.
    """
    conn = duckdb.connect(db_path)
    try:
        if window_id:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings WHERE window_id = ?",
                (window_id,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings ORDER BY window_id, motion_id"
            ).fetchall()
    finally:
        # Bug fix: the original closed the connection only on the success
        # path, leaking it when the query raised.
        conn.close()

    motion_ids, window_ids, vectors = [], [], []
    for motion_id, wid, vec_json in rows:
        try:
            vec = json.loads(vec_json)
            motion_ids.append(int(motion_id))
            window_ids.append(wid)
            vectors.append(vec)
        except Exception:
            # Best effort: skip unparseable rows rather than aborting the run.
            _logger.warning("Could not parse fused vector for motion %s", motion_id)

    if not vectors:
        return [], [], np.zeros((0, 0))

    # Pad to common length if needed (shouldn't happen if pipeline is consistent)
    max_len = max(len(v) for v in vectors)
    mat = np.zeros((len(vectors), max_len), dtype=float)
    for i, v in enumerate(vectors):
        mat[i, : len(v)] = v

    return motion_ids, window_ids, mat
||||||
|
|
||||||
|
|
||||||
|
def run_umap(
    db_path: str,
    window_id: Optional[str] = None,
    n_components: int = 2,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
) -> Dict:
    """Reduce fused embeddings to 2D/3D coordinates with UMAP.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Optional single window to restrict the embeddings to.
        n_components: Output dimensionality (2 or 3).
        n_neighbors: UMAP neighbourhood size; shrunk for tiny datasets.
        min_dist: UMAP minimum distance parameter.
        random_state: Seed for reproducible layouts.

    Returns:
        {"motion_ids": [...], "window_ids": [...], "coords": [[x, y], ...],
        "n_components": int}, or {} when umap-learn is missing or no
        embeddings exist.
    """
    try:
        import umap
    except ImportError:
        _logger.error("umap-learn is not installed; cannot run UMAP")
        return {}

    motion_ids, window_ids, mat = _load_fused_vectors(db_path, window_id)
    if mat.size == 0:
        _logger.warning("No fused embeddings found for window_id=%s", window_id)
        return {}

    n_samples = mat.shape[0]
    if n_samples < n_neighbors + 1:
        # UMAP requires at least n_neighbors+1 samples
        n_neighbors = max(2, n_samples - 1)
        _logger.warning(
            "Reduced n_neighbors to %d due to small dataset (%d samples)",
            n_neighbors,
            n_samples,
        )

    embedding = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    ).fit_transform(mat)

    return {
        "motion_ids": motion_ids,
        "window_ids": window_ids,
        "coords": embedding.tolist(),
        "n_components": n_components,
    }
||||||
|
|
||||||
|
|
||||||
|
def cluster_kmeans(
    coords: np.ndarray, n_clusters: int = 8, random_state: int = 42
) -> np.ndarray:
    """Run KMeans on 2D/3D UMAP coordinates.

    Args:
        coords: (n_samples, n_dims) coordinate array.
        n_clusters: Desired cluster count; capped at the number of samples.
        random_state: Seed for deterministic clustering.

    Returns:
        Array of integer cluster labels (length = len(coords)).  Falls back
        to all-zeros labels when scikit-learn is unavailable, and returns an
        empty array for empty input.
    """
    # Bug fix: guard the empty case up front — the original passed
    # n_clusters=min(8, 0)=0 to KMeans, which raises ValueError.
    if len(coords) == 0:
        return np.zeros(0, dtype=int)

    try:
        from sklearn.cluster import KMeans
    except ImportError:
        _logger.error("scikit-learn is not installed; cannot run KMeans")
        return np.zeros(len(coords), dtype=int)

    n_clusters = min(n_clusters, len(coords))
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    return km.fit_predict(coords)
||||||
@ -0,0 +1,125 @@ |
|||||||
|
"""political_axis.py — Project MP SVD vectors onto an ideological axis. |
||||||
|
|
||||||
|
Two modes: |
||||||
|
1. PCA mode (default): compute the first principal component of all MP SVD |
||||||
|
vectors for a window and project each MP onto it. The sign is arbitrary |
||||||
|
but consistent within a window. |
||||||
|
|
||||||
|
2. Anchor mode: define the axis as the vector from the centroid of |
||||||
|
``left_parties`` to the centroid of ``right_parties``. Project all MPs |
||||||
|
onto this normalised anchor axis. |
||||||
|
|
||||||
|
Both modes return a dict mapping mp_name → scalar score for the given window. |
||||||
|
""" |
||||||
|
|
||||||
|
import json |
||||||
|
import logging |
||||||
|
from typing import Dict, List, Optional |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import duckdb |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _load_mp_svd_vectors(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load all MP SVD vectors for a window from the svd_vectors table.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors to load.

    Returns:
        {mp_name: vector}; rows with unparseable JSON are skipped with a
        warning.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        # Bug fix: the original leaked the connection when the query raised.
        conn.close()

    result = {}
    for mp_name, vec_json in rows:
        try:
            result[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            _logger.warning("Could not parse SVD vector for MP %s", mp_name)
    return result
||||||
|
|
||||||
|
|
||||||
|
def compute_pca_axis(db_path: str, window_id: str) -> Dict[str, float]:
    """Project MP SVD vectors onto their first principal component.

    The sign of the axis is arbitrary but consistent within a window.

    Returns:
        {mp_name: score}; empty dict when fewer than 2 MPs exist or the
        SVD fails to converge.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if len(mp_vecs) < 2:
        _logger.warning(
            "window %s has only %d MPs; skipping PCA axis", window_id, len(mp_vecs)
        )
        return {}

    names = list(mp_vecs)
    mat = np.vstack([mp_vecs[name] for name in names])  # (n_mps, k)

    # Centre the data so the first right-singular vector is the first PC.
    centred = mat - mat.mean(axis=0)

    try:
        _, _, vt = np.linalg.svd(centred, full_matrices=False)
    except np.linalg.LinAlgError:
        _logger.exception("SVD failed in compute_pca_axis for window %s", window_id)
        return {}

    first_pc = vt[0]  # (k,)
    scores = centred.dot(first_pc)
    return dict(zip(names, (float(s) for s in scores)))
||||||
|
|
||||||
|
|
||||||
|
def compute_anchor_axis(
    db_path: str,
    window_id: str,
    left_parties: List[str],
    right_parties: List[str],
) -> Dict[str, float]:
    """Project MP SVD vectors onto a left↔right anchor axis.

    The axis runs from the centroid of ``left_parties`` to the centroid of
    ``right_parties``. Positive scores are toward the right.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP SVD vectors are projected.
        left_parties: Party names forming the left anchor centroid.
        right_parties: Party names forming the right anchor centroid.

    Returns:
        {mp_name: score}; empty dict when vectors or anchor parties are
        missing, or the axis is degenerate.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if not mp_vecs:
        return {}

    # Load party affiliation from mp_metadata.
    # NOTE(review): this query is NOT filtered by window; an MP who switched
    # parties keeps whichever row is read last — confirm that is acceptable.
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
    finally:
        # Bug fix: close even on query failure so the connection isn't leaked.
        conn.close()
    party_of = {mp: party for mp, party in rows}

    left_vecs = [
        mp_vecs[mp]
        for mp, party in party_of.items()
        if party in left_parties and mp in mp_vecs
    ]
    right_vecs = [
        mp_vecs[mp]
        for mp, party in party_of.items()
        if party in right_parties and mp in mp_vecs
    ]

    if not left_vecs or not right_vecs:
        _logger.warning(
            "window %s: insufficient anchor parties (left=%d, right=%d)",
            window_id,
            len(left_vecs),
            len(right_vecs),
        )
        return {}

    left_centroid = np.mean(left_vecs, axis=0)
    right_centroid = np.mean(right_vecs, axis=0)
    axis = right_centroid - left_centroid
    norm = np.linalg.norm(axis)
    if norm < 1e-10:
        # Degenerate axis: the anchor centroids coincide.
        _logger.warning("Anchor axis has near-zero norm for window %s", window_id)
        return {}
    axis = axis / norm

    return {name: float(np.dot(vec, axis)) for name, vec in mp_vecs.items()}
||||||
@ -0,0 +1,123 @@ |
|||||||
|
"""trajectory.py — Compute MP political drift across aligned time windows. |
||||||
|
|
||||||
|
For each MP that appears in multiple windows, computes: |
||||||
|
- The aligned SVD vector per window |
||||||
|
- The Euclidean distance between consecutive windows (drift) |
||||||
|
- Total cumulative drift |
||||||
|
|
||||||
|
Returns a dict keyed by mp_name containing per-window positions and drift scores. |
||||||
|
""" |
||||||
|
|
||||||
|
import json |
||||||
|
import logging |
||||||
|
from typing import Dict, List, Optional |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import duckdb |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _load_window_ids(db_path: str) -> List[str]:
    """Return all distinct window IDs from svd_vectors, in lexicographic order."""
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
        ).fetchall()
    finally:
        # Bug fix: the original leaked the connection when the query raised.
        conn.close()
    return [r[0] for r in rows]
||||||
|
|
||||||
|
|
||||||
|
def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load MP SVD vectors for one window from the svd_vectors table.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose MP vectors to load.

    Returns:
        {mp_name: vector}; rows with unparseable JSON are skipped with a
        warning.
    """
    conn = duckdb.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        # Bug fix: the original leaked the connection when the query raised.
        conn.close()
    result = {}
    for mp_name, vec_json in rows:
        try:
            result[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            _logger.warning(
                "Could not parse vector for MP %s window %s", mp_name, window_id
            )
    return result
||||||
|
|
||||||
|
|
||||||
|
def compute_trajectories(
    db_path: str,
    window_ids: Optional[List[str]] = None,
) -> Dict[str, Dict]:
    """Compute per-MP trajectories across windows.

    Args:
        db_path: Path to the DuckDB database file.
        window_ids: Windows to consider; defaults to every window in the DB.

    Returns:
        {mp_name: {"windows": [...], "vectors": [[...], ...],
        "drift": [float, ...], "total_drift": float}}.
        Only MPs present in at least 2 windows are included.
    """
    if window_ids is None:
        window_ids = _load_window_ids(db_path)

    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no trajectories to compute")
        return {}

    # Gather each MP's vector per window, preserving window order.
    per_mp: Dict[str, Dict] = {}
    for wid in window_ids:
        for name, vec in _load_mp_vectors_for_window(db_path, wid).items():
            entry = per_mp.setdefault(name, {"windows": [], "vectors": []})
            entry["windows"].append(wid)
            entry["vectors"].append(vec)

    # Keep MPs seen in 2+ windows; drift is the Euclidean distance between
    # consecutive window vectors.
    result = {}
    for name, entry in per_mp.items():
        vecs = entry["vectors"]
        if len(entry["windows"]) < 2:
            continue
        drifts = [float(np.linalg.norm(b - a)) for a, b in zip(vecs, vecs[1:])]
        result[name] = {
            "windows": entry["windows"],
            "vectors": [v.tolist() for v in vecs],
            "drift": drifts,
            "total_drift": float(sum(drifts)),
        }

    _logger.info(
        "Trajectories computed for %d MPs across %d windows",
        len(result),
        len(window_ids),
    )
    return result
||||||
|
|
||||||
|
|
||||||
|
def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
    """Return the top-n MPs by total drift, sorted descending.

    Each entry: {"mp_name": ..., "total_drift": ..., "windows": [...]}
    """
    by_drift = sorted(
        trajectories, key=lambda mp: trajectories[mp]["total_drift"], reverse=True
    )
    top: List[Dict] = []
    for mp in by_drift[:n]:
        info = trajectories[mp]
        top.append(
            {
                "mp_name": mp,
                "total_drift": info["total_drift"],
                "windows": info["windows"],
            }
        )
    return top
||||||
@ -0,0 +1,163 @@ |
|||||||
|
"""visualize.py — Plotly interactive plots for parliamentary embeddings. |
||||||
|
|
||||||
|
Produces self-contained HTML files. |
||||||
|
|
||||||
|
Functions: |
||||||
|
plot_umap_scatter — 2D scatter of fused motion embeddings, coloured by cluster |
||||||
|
plot_mp_trajectory — Line plot of MP drift across windows |
||||||
|
plot_political_axis — Bar chart of MP scores on the ideological axis |
||||||
|
""" |
||||||
|
|
||||||
|
import logging |
||||||
|
from typing import Dict, List, Optional |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _require_plotly(): |
||||||
|
try: |
||||||
|
import plotly.graph_objects as go |
||||||
|
import plotly.express as px |
||||||
|
|
||||||
|
return go, px |
||||||
|
except ImportError: |
||||||
|
raise ImportError("plotly is not installed. Install it with: uv add plotly") |
||||||
|
|
||||||
|
|
||||||
|
def plot_umap_scatter(
    motion_ids: List[int],
    coords: List[List[float]],
    labels: Optional[List[int]] = None,
    window_id: Optional[str] = None,
    output_path: str = "analysis_umap.html",
) -> str:
    """Produce a 2D scatter plot of UMAP-reduced fused embeddings.

    Args:
        motion_ids: Motion IDs (used as hover labels)
        coords: List of [x, y] coordinates
        labels: Optional cluster labels (integer per motion)
        window_id: Window label for the plot title
        output_path: Where to write the self-contained HTML

    Returns the output_path on success.
    """
    _go, px = _require_plotly()

    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    # Fall back to a single pseudo-cluster when no labels are supplied.
    color = labels if labels is not None else [0] * len(motion_ids)
    # Bug fix: the original used an f-string with no placeholders here.
    title = "UMAP — fused motion embeddings" + (f" ({window_id})" if window_id else "")

    fig = px.scatter(
        x=xs,
        y=ys,
        color=[str(c) for c in color],
        hover_name=[str(mid) for mid in motion_ids],
        title=title,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
    )
    # include_plotlyjs="cdn" keeps the HTML small; needs network to render.
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("UMAP scatter written to %s", output_path)
    return output_path
||||||
|
|
||||||
|
|
||||||
|
def plot_mp_trajectory(
    trajectories: Dict[str, Dict],
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectory.html",
) -> str:
    """Line plot of MP drift across time windows.

    Args:
        trajectories: Output of analysis.trajectory.compute_trajectories()
        mp_names: Subset of MPs to plot (default: all)
        output_path: Output HTML file path

    Returns the output_path on success.
    """
    go, _px = _require_plotly()

    selected = list(trajectories.keys()) if mp_names is None else mp_names

    fig = go.Figure()
    for mp in selected:
        data = trajectories.get(mp)
        if data is None:
            continue
        # Cumulative drift starts at 0 for the MP's first window.
        cumulative = [0.0] + list(np.cumsum(data["drift"]))
        fig.add_trace(
            go.Scatter(
                x=data["windows"][: len(cumulative)],
                y=cumulative,
                mode="lines+markers",
                name=mp,
            )
        )

    fig.update_layout(
        title="MP Political Drift Over Time (Cumulative)",
        xaxis_title="Window",
        yaxis_title="Cumulative Drift",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Trajectory plot written to %s", output_path)
    return output_path
||||||
|
|
||||||
|
|
||||||
|
def plot_political_axis(
    scores: Dict[str, float],
    party_of: Optional[Dict[str, str]] = None,
    window_id: Optional[str] = None,
    n_top: int = 30,
    output_path: str = "analysis_political_axis.html",
) -> str:
    """Horizontal bar chart of MP scores on the ideological axis.

    Args:
        scores: {mp_name: score} from political_axis module
        party_of: Optional {mp_name: party} for colour-coding
        window_id: Window label for the title
        n_top: Show only the top/bottom n MPs by score
        output_path: Output HTML path

    Returns the output_path on success.
    """
    go, px = _require_plotly()

    # Rank by score, then keep only the two extremes when the list is large.
    ranked = sorted(scores.items(), key=lambda kv: kv[1])
    if len(ranked) > 2 * n_top:
        ranked = ranked[:n_top] + ranked[-n_top:]

    names = [mp for mp, _ in ranked]
    vals = [score for _, score in ranked]
    if party_of:
        colors = [party_of.get(mp, "Unknown") for mp in names]
    else:
        colors = ["Unknown"] * len(names)

    title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")

    fig = px.bar(
        x=vals,
        y=names,
        color=colors,
        orientation="h",
        title=title,
        labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political axis chart written to %s", output_path)
    return output_path
||||||
@ -0,0 +1,261 @@ |
|||||||
|
"""CLI orchestrator for the parliamentary embedding pipeline. |
||||||
|
|
||||||
|
Runs all phases in sequence: |
||||||
|
1. fetch_mp_metadata — pull MP party + tenure from OData |
||||||
|
2. extract_mp_votes — parse voting_results JSON → mp_votes rows |
||||||
|
3. svd per window — build vote matrix, SVD, Procrustes-align |
||||||
|
4. text embeddings — fill any gaps in the embeddings table |
||||||
|
5. fuse per window — concatenate SVD + text vectors → fused_embeddings |
||||||
|
|
||||||
|
Usage: |
||||||
|
uv run python -m pipeline.run_pipeline [options] |
||||||
|
|
||||||
|
Options: |
||||||
|
--db-path PATH Path to the DuckDB file (default: data/motions.db) |
||||||
|
--start-date DATE Window start (YYYY-MM-DD, default: 2 years ago) |
||||||
|
--end-date DATE Window end (YYYY-MM-DD, default: today) |
||||||
|
--window-size {quarterly,annual} Time window granularity (default: quarterly) |
||||||
|
--svd-k INT SVD dimensionality (default: 50) |
||||||
|
--text-model TEXT Text embedding model name (default: from ai_provider) |
||||||
|
--skip-metadata Skip fetching MP metadata from OData |
||||||
|
--skip-extract Skip extracting MP votes from voting_results |
||||||
|
--skip-svd Skip SVD computation |
||||||
|
--skip-text Skip text embedding gap-fill |
||||||
|
--skip-fusion Skip vector fusion |
||||||
|
--dry-run Print actions but make no DB writes |
||||||
|
""" |
||||||
|
|
||||||
|
import argparse |
||||||
|
import calendar |
||||||
|
import logging |
||||||
|
import sys |
||||||
|
from datetime import date, timedelta |
||||||
|
from typing import List, Tuple |
||||||
|
|
||||||
|
from database import MotionDatabase |
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def _generate_windows( |
||||||
|
start: date, end: date, granularity: str |
||||||
|
) -> List[Tuple[str, str, str]]: |
||||||
|
"""Return list of (window_id, start_str, end_str) tuples. |
||||||
|
|
||||||
|
window_id format: |
||||||
|
quarterly → "2024-Q1", "2024-Q2", … |
||||||
|
annual → "2024" |
||||||
|
""" |
||||||
|
windows = [] |
||||||
|
cursor = date(start.year, start.month, 1) |
||||||
|
|
||||||
|
if granularity == "annual": |
||||||
|
cursor = date(start.year, 1, 1) |
||||||
|
while cursor <= end: |
||||||
|
year_end = date(cursor.year, 12, 31) |
||||||
|
w_end = min(year_end, end) |
||||||
|
windows.append((str(cursor.year), cursor.isoformat(), w_end.isoformat())) |
||||||
|
cursor = date(cursor.year + 1, 1, 1) |
||||||
|
else: |
||||||
|
# quarterly |
||||||
|
quarter_starts = {1: 1, 2: 4, 3: 7, 4: 10} |
||||||
|
quarter_ends = {1: 3, 2: 6, 3: 9, 4: 12} |
||||||
|
|
||||||
|
# Align cursor to quarter start |
||||||
|
q = (cursor.month - 1) // 3 + 1 |
||||||
|
cursor = date(cursor.year, quarter_starts[q], 1) |
||||||
|
|
||||||
|
while cursor <= end: |
||||||
|
q = (cursor.month - 1) // 3 + 1 |
||||||
|
q_end_month = quarter_ends[q] |
||||||
|
last_day = calendar.monthrange(cursor.year, q_end_month)[1] |
||||||
|
q_end = date(cursor.year, q_end_month, last_day) |
||||||
|
w_end = min(q_end, end) |
||||||
|
window_id = f"{cursor.year}-Q{q}" |
||||||
|
windows.append((window_id, cursor.isoformat(), w_end.isoformat())) |
||||||
|
cursor = q_end + timedelta(days=1) |
||||||
|
|
||||||
|
return windows |
||||||
|
|
||||||
|
|
||||||
|
def run(args: argparse.Namespace) -> int:
    """Execute the pipeline. Returns exit code (0 = success).

    Phases run in order: MP metadata fetch, MP vote extraction, per-window
    SVD, text-embedding gap fill, per-window fusion.  Each phase honours its
    --skip-* flag, and --dry-run logs intended actions without DB writes.

    Args:
        args: Parsed CLI namespace from build_parser().
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    db_path = args.db_path
    dry_run = args.dry_run

    if dry_run:
        _logger.info("DRY RUN — no writes will be made")

    # Resolve date range: default is the two years ending today.
    end_date = date.fromisoformat(args.end_date) if args.end_date else date.today()
    start_date = (
        date.fromisoformat(args.start_date)
        if args.start_date
        else end_date - timedelta(days=730)
    )

    _logger.info(
        "Pipeline run: %s → %s (%s windows), db=%s",
        start_date,
        end_date,
        args.window_size,
        db_path,
    )

    db = MotionDatabase(db_path)

    # Fix: the original called _generate_windows twice with identical
    # arguments (phases 3 and 5); compute the window list once.
    windows = _generate_windows(start_date, end_date, args.window_size)

    # ── Phase 1: MP metadata ────────────────────────────────────────────────
    if not args.skip_metadata:
        _logger.info("Phase 1: fetching MP metadata from OData")
        if not dry_run:
            # Imported lazily so a skipped/dry phase never pays the cost.
            from pipeline.fetch_mp_metadata import fetch_mp_metadata

            fetched, skipped = fetch_mp_metadata(db)
            _logger.info("  mp_metadata: fetched=%d skipped=%d", fetched, skipped)
        else:
            _logger.info("  [dry-run] would call fetch_mp_metadata(db)")
    else:
        _logger.info("Phase 1: skipped (--skip-metadata)")

    # ── Phase 2: Extract MP votes ───────────────────────────────────────────
    if not args.skip_extract:
        _logger.info("Phase 2: extracting MP votes from voting_results")
        if not dry_run:
            from pipeline.extract_mp_votes import extract_mp_votes

            inserted, skipped = extract_mp_votes(db)
            _logger.info(
                "  mp_votes: inserted=%d motions skipped=%d", inserted, skipped
            )
        else:
            _logger.info("  [dry-run] would call extract_mp_votes(db)")
    else:
        _logger.info("Phase 2: skipped (--skip-extract)")

    # ── Phase 3: SVD per window ─────────────────────────────────────────────
    if not args.skip_svd:
        _logger.info("Phase 3: SVD for %d windows (k=%d)", len(windows), args.svd_k)
        from pipeline.svd_pipeline import run_svd_for_window

        for window_id, w_start, w_end in windows:
            _logger.info("  window %s: %s → %s", window_id, w_start, w_end)
            if not dry_run:
                result = run_svd_for_window(
                    db=db,
                    window_id=window_id,
                    start_date=w_start,
                    end_date=w_end,
                    k=args.svd_k,
                )
                _logger.info(
                    "    k_used=%d stored_mp=%d stored_motion=%d",
                    result["k_used"],
                    result["stored_mp"],
                    result["stored_motion"],
                )
            else:
                _logger.info("  [dry-run] would run SVD for window %s", window_id)
    else:
        _logger.info("Phase 3: skipped (--skip-svd)")

    # ── Phase 4: Text embeddings ────────────────────────────────────────────
    if not args.skip_text:
        _logger.info("Phase 4: ensuring text embeddings")
        if not dry_run:
            from pipeline.text_pipeline import ensure_text_embeddings

            stored, existing, no_text, errors = ensure_text_embeddings(
                db_path=db_path, model=args.text_model
            )
            _logger.info(
                "  embeddings: stored=%d existing=%d no_text=%d errors=%d",
                stored,
                existing,
                no_text,
                errors,
            )
        else:
            _logger.info("  [dry-run] would call ensure_text_embeddings")
    else:
        _logger.info("Phase 4: skipped (--skip-text)")

    # ── Phase 5: Fusion per window ──────────────────────────────────────────
    if not args.skip_fusion:
        _logger.info("Phase 5: fusing vectors for %d windows", len(windows))
        from pipeline.fusion import fuse_for_window

        for window_id, _w_start, _w_end in windows:
            if not dry_run:
                result = fuse_for_window(
                    window_id=window_id,
                    db_path=db_path,
                    model=args.text_model,
                )
                _logger.info(
                    "  window %s: fused=%d skipped_no_svd=%d skipped_no_text=%d",
                    window_id,
                    result["fused"],
                    result.get("skipped_no_svd", 0),
                    result.get("skipped_no_text", 0),
                )
            else:
                _logger.info("  [dry-run] would fuse window %s", window_id)
    else:
        _logger.info("Phase 5: skipped (--skip-fusion)")

    _logger.info("Pipeline complete.")
    return 0
||||||
|
|
||||||
|
|
||||||
|
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the pipeline orchestrator."""
    p = argparse.ArgumentParser(
        description="Parliamentary embedding pipeline orchestrator",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--db-path", default="data/motions.db", help="Path to DuckDB file")
    p.add_argument("--start-date", default=None, help="Window start YYYY-MM-DD")
    p.add_argument("--end-date", default=None, help="Window end YYYY-MM-DD")
    p.add_argument(
        "--window-size",
        choices=["quarterly", "annual"],
        default="quarterly",
        help="Time window granularity",
    )
    p.add_argument("--svd-k", type=int, default=50, help="SVD dimensions")
    p.add_argument(
        "--text-model",
        default=None,
        help="Text embedding model (default: ai_provider default)",
    )
    # One skip switch per pipeline phase, in phase order.
    for flag, help_text in [
        ("--skip-metadata", "Skip MP metadata fetch"),
        ("--skip-extract", "Skip MP vote extraction"),
        ("--skip-svd", "Skip SVD computation"),
        ("--skip-text", "Skip text embedding gap-fill"),
        ("--skip-fusion", "Skip vector fusion"),
    ]:
        p.add_argument(flag, action="store_true", help=help_text)
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without writing anything",
    )
    return p
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # CLI entry point: parse arguments and propagate run()'s exit code.
    sys.exit(run(build_parser().parse_args()))
||||||
@ -0,0 +1,195 @@ |
|||||||
|
"""Tests for analysis modules: political_axis, trajectory, clustering.""" |
||||||
|
|
||||||
|
import json |
||||||
|
import numpy as np |
||||||
|
import pytest |
||||||
|
|
||||||
|
duckdb = pytest.importorskip("duckdb") |
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────────── |
||||||
|
|
||||||
|
|
||||||
|
def _setup_svd_vectors(db_path: str, window_ids_mp_vecs: dict):
    """Populate the svd_vectors table with synthetic MP vectors.

    window_ids_mp_vecs: {window_id: {mp_name: np.ndarray}}
    """
    conn = duckdb.connect(db_path)
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS svd_vectors (
            id INTEGER,
            window_id TEXT,
            entity_type TEXT,
            entity_id TEXT,
            vector JSON,
            model TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    # Flatten {window: {mp: vec}} into (window_id, mp_name, json_vector) rows.
    rows = (
        (window_id, mp_name, json.dumps(vec.tolist()))
        for window_id, per_mp in window_ids_mp_vecs.items()
        for mp_name, vec in per_mp.items()
    )
    insert_sql = (
        "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model)"
        " VALUES (?, 'mp', ?, ?, 'test')"
    )
    for row in rows:
        conn.execute(insert_sql, row)
    conn.close()
||||||
|
|
||||||
|
|
||||||
|
def _setup_mp_metadata(db_path: str, mp_party: dict):
    """Write synthetic MP -> party rows into the mp_metadata table."""
    conn = duckdb.connect(db_path)
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS mp_metadata (
            mp_name TEXT,
            party TEXT,
            van DATE,
            tot_en_met DATE,
            persoon_id TEXT
        )
        """
    )
    insert_sql = "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)"
    for mp_name, party in mp_party.items():
        conn.execute(insert_sql, (mp_name, party))
    conn.close()
||||||
|
|
||||||
|
|
||||||
|
# ── political_axis ──────────────────────────────────────────────────────────── |
||||||
|
|
||||||
|
|
||||||
|
class TestPoliticalAxis:
    """Exercise PCA-based and anchor-based ideological axis scoring."""

    def test_pca_axis_basic(self, tmp_path):
        np.random.seed(42)
        db_path = str(tmp_path / "test.db")
        n_mps, k = 20, 5

        # Random MP vectors — PCA should still yield one score per MP with
        # non-trivial spread along the first principal component.
        vecs = np.random.randn(n_mps, k)
        window = {f"MP_{i}": vecs[i] for i in range(n_mps)}
        _setup_svd_vectors(db_path, {"2024-Q1": window})

        from analysis.political_axis import compute_pca_axis

        scores = compute_pca_axis(db_path, "2024-Q1")
        assert len(scores) == n_mps
        assert all(isinstance(v, float) for v in scores.values())
        # Projections must spread out rather than collapse to a constant.
        assert np.std(list(scores.values())) > 0.0

    def test_pca_axis_too_few_mps(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(db_path, {"w1": {"MP_A": np.array([1.0, 0.0])}})

        from analysis.political_axis import compute_pca_axis

        # A single MP cannot define an axis — expect an empty result.
        assert compute_pca_axis(db_path, "w1") == {}

    def test_anchor_axis_basic(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        # Two tight clusters separated along dimension 0, one MP in between.
        mp_vecs = {
            "Left_A": np.array([-1.9, 0.0, 0.0]),
            "Left_B": np.array([-2.1, 0.0, 0.0]),
            "Right_A": np.array([2.1, 0.0, 0.0]),
            "Right_B": np.array([1.9, 0.0, 0.0]),
            "Centre": np.array([0.0, 0.0, 0.0]),
        }
        _setup_svd_vectors(db_path, {"w1": mp_vecs})
        parties = {
            "Left_A": "SP",
            "Left_B": "SP",
            "Right_A": "VVD",
            "Right_B": "VVD",
            "Centre": "D66",
        }
        _setup_mp_metadata(db_path, parties)

        from analysis.political_axis import compute_anchor_axis

        scores = compute_anchor_axis(
            db_path, "w1", left_parties=["SP"], right_parties=["VVD"]
        )
        assert len(scores) == 5
        # Axis orientation: left anchors must score below right anchors.
        assert scores["Left_A"] < scores["Right_A"]
        assert scores["Left_B"] < scores["Right_B"]
||||||
|
|
||||||
|
|
||||||
|
# ── trajectory ─────────────────────────────────────────────────────────────── |
||||||
|
|
||||||
|
|
||||||
|
class TestTrajectory:
    """Exercise MP drift computation across aligned windows."""

    def test_basic_trajectory(self, tmp_path):
        np.random.seed(0)
        db_path = str(tmp_path / "test.db")

        windows = {
            "2024-Q1": {
                "MP_A": np.array([1.0, 0.0]),
                "MP_B": np.array([0.0, 1.0]),
            },
            "2024-Q2": {
                "MP_A": np.array([1.5, 0.5]),
                "MP_B": np.array([0.0, 1.0]),
                "MP_C": np.array([2.0, 2.0]),
            },
        }
        _setup_svd_vectors(db_path, windows)

        from analysis.trajectory import compute_trajectories, top_drifters

        traj = compute_trajectories(db_path)

        # Trajectories require presence in at least two windows.
        assert "MP_A" in traj
        assert "MP_B" in traj
        assert "MP_C" not in traj  # only in one window

        assert len(traj["MP_A"]["drift"]) == 1
        assert traj["MP_A"]["total_drift"] > 0.0

        # A stationary MP accumulates zero drift.
        assert traj["MP_B"]["total_drift"] == pytest.approx(0.0)

        drifters = top_drifters(traj, n=5)
        assert drifters[0]["mp_name"] == "MP_A"

    def test_fewer_than_2_windows(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(db_path, {"2024-Q1": {"MP_A": np.array([1.0, 2.0])}})

        from analysis.trajectory import compute_trajectories

        # With a single window there is nothing to compare across.
        assert compute_trajectories(db_path) == {}
||||||
|
|
||||||
|
|
||||||
|
# ── clustering ──────────────────────────────────────────────────────────────── |
||||||
|
|
||||||
|
|
||||||
|
class TestClustering:
    """Exercise the KMeans labelling helper from analysis.clustering."""

    def test_cluster_kmeans_basic(self):
        # Fix: dropped the redundant function-local `import numpy as np`
        # (numpy is already imported at module level) and seeded the RNG so
        # the test data — and thus any failure — is reproducible.
        from analysis.clustering import cluster_kmeans

        np.random.seed(7)
        coords = np.random.randn(20, 2)
        labels = cluster_kmeans(coords, n_clusters=3)
        # One label per point, drawn from the requested cluster ids.
        assert len(labels) == 20
        assert set(labels).issubset({0, 1, 2})

    def test_cluster_kmeans_fewer_points_than_clusters(self):
        from analysis.clustering import cluster_kmeans

        coords = np.array([[0.0, 0.0], [1.0, 1.0]])
        labels = cluster_kmeans(coords, n_clusters=5)
        # Should not crash; n_clusters clamped to len(coords)
        assert len(labels) == 2
||||||
@ -0,0 +1,113 @@ |
|||||||
|
"""Tests for pipeline/run_pipeline.py""" |
||||||
|
|
||||||
|
import argparse |
||||||
|
import sys |
||||||
|
import pytest |
||||||
|
|
||||||
|
from pipeline.run_pipeline import _generate_windows, build_parser, run |
||||||
|
from datetime import date |
||||||
|
|
||||||
|
|
||||||
|
def test_generate_windows_quarterly():
    """A full calendar year splits into four labelled quarters."""
    windows = _generate_windows(date(2024, 1, 1), date(2024, 12, 31), "quarterly")

    assert len(windows) == 4
    labels = [entry[0] for entry in windows]
    assert labels == ["2024-Q1", "2024-Q2", "2024-Q3", "2024-Q4"]

    # First and last quarter boundaries.
    first, last = windows[0], windows[3]
    assert (first[1], first[2]) == ("2024-01-01", "2024-03-31")
    assert (last[1], last[2]) == ("2024-10-01", "2024-12-31")
||||||
|
|
||||||
|
|
||||||
|
def test_generate_windows_annual():
    """Annual windows span calendar years, truncated at the end date."""
    windows = _generate_windows(date(2022, 6, 1), date(2024, 3, 31), "annual")

    assert len(windows) == 3
    assert [entry[0] for entry in windows] == ["2022", "2023", "2024"]

    # 2024 should end at end_date, not Dec 31
    assert windows[-1][2] == "2024-03-31"
||||||
|
|
||||||
|
|
||||||
|
def test_generate_windows_mid_quarter_start():
    """Starting in the middle of Q2 should still produce a full Q2 window."""
    windows = _generate_windows(date(2024, 5, 15), date(2024, 9, 30), "quarterly")

    labels = [entry[0] for entry in windows]
    assert "2024-Q2" in labels
    assert "2024-Q3" in labels
||||||
|
|
||||||
|
|
||||||
|
def test_build_parser_defaults():
    """Parsing an empty argv yields the documented defaults."""
    args = build_parser().parse_args([])

    assert args.db_path == "data/motions.db"
    assert args.window_size == "quarterly"
    assert args.svd_k == 50
    assert args.dry_run is False
||||||
|
|
||||||
|
|
||||||
|
def test_run_dry_run(tmp_path):
    """Dry-run should log actions and return 0 without touching the DB.

    Fix: removed the `monkeypatch` fixture parameter — it was requested
    but never used, which misleadingly suggests the test patches something.
    """
    db_path = str(tmp_path / "motions.db")

    # Create minimal DB so MotionDatabase initialises
    from database import MotionDatabase

    MotionDatabase(db_path)

    args = argparse.Namespace(
        db_path=db_path,
        start_date="2024-01-01",
        end_date="2024-03-31",
        window_size="quarterly",
        svd_k=10,
        text_model=None,
        skip_metadata=False,
        skip_extract=False,
        skip_svd=False,
        skip_text=False,
        skip_fusion=False,
        dry_run=True,
    )

    exit_code = run(args)
    assert exit_code == 0
||||||
|
|
||||||
|
|
||||||
|
def test_run_skip_all(tmp_path):
    """Skipping all phases should still return 0."""
    db_path = str(tmp_path / "motions.db")
    from database import MotionDatabase

    MotionDatabase(db_path)  # minimal DB so run() can open it

    options = {
        "db_path": db_path,
        "start_date": "2024-01-01",
        "end_date": "2024-03-31",
        "window_size": "quarterly",
        "svd_k": 10,
        "text_model": None,
        "dry_run": False,
    }
    # Every phase flag set to True: nothing should execute.
    for phase in ("metadata", "extract", "svd", "text", "fusion"):
        options[f"skip_{phase}"] = True

    assert run(argparse.Namespace(**options)) == 0
||||||
Loading…
Reference in new issue