- pipeline/run_pipeline.py: CLI orchestrator for all 5 pipeline phases with
--dry-run, --skip-*, --window-size, --svd-k, --start/end-date flags
- analysis/{political_axis,trajectory,clustering,visualize}.py: PCA/anchor
ideological axis, MP drift trajectories, UMAP + KMeans clustering, Plotly HTML output
- api_client.py: capture ActorFractie per individual MP vote (comma in ActorNaam)
into mp_vote_parties dict on each motion
- database.insert_motion: auto-insert mp_votes rows with party affiliation for
newly ingested motions when mp_vote_parties is present
- Add scikit-learn to pyproject.toml for KMeans clustering
- tests/test_run_pipeline.py: window generation, dry-run, skip-all paths
- tests/test_analysis.py: PCA axis, anchor axis, trajectory drift, KMeans
Ref: thoughts/shared/plans/2026-03-21-parliamentary-embedding-pipeline-plan.md
main
parent
a36e6cba4e
commit
f2a831dfcf
@ -0,0 +1,8 @@ |
||||
"""Analysis modules for the parliamentary embedding pipeline. |
||||
|
||||
Modules: |
||||
political_axis — project MP SVD vectors onto ideological axis |
||||
trajectory — compute MP drift across aligned windows |
||||
clustering — UMAP dimensionality reduction + cluster labelling |
||||
visualize — Plotly interactive plots (outputs self-contained HTML) |
||||
""" |
||||
@ -0,0 +1,130 @@ |
||||
"""clustering.py — UMAP dimensionality reduction on fused embeddings. |
||||
|
||||
Reduces fused motion embeddings to 2D (or 3D) for visualisation, |
||||
and optionally labels clusters using KMeans. |
||||
|
||||
Requires: umap-learn, scikit-learn (for KMeans) |
||||
""" |
||||
|
||||
import json |
||||
import logging |
||||
from typing import Dict, List, Optional, Tuple |
||||
|
||||
import numpy as np |
||||
import duckdb |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _load_fused_vectors(
    db_path: str, window_id: Optional[str] = None
) -> Tuple[List[int], List[str], np.ndarray]:
    """Load fused embeddings from the DB.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: When given, restrict the query to that window.

    Returns:
        (motion_ids, window_ids, matrix) — one matrix row per successfully
        parsed vector, zero-padded to a common width. Empty lists and a
        (0, 0) matrix when nothing could be loaded.
    """
    conn = duckdb.connect(db_path)
    try:
        # Fix: close the connection even if the query raises, so the DB
        # file is not left locked by a leaked handle.
        if window_id:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings WHERE window_id = ?",
                (window_id,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT motion_id, window_id, vector FROM fused_embeddings ORDER BY window_id, motion_id"
            ).fetchall()
    finally:
        conn.close()

    motion_ids: List[int] = []
    window_ids: List[str] = []
    vectors: List[list] = []
    for motion_id, wid, vec_json in rows:
        try:
            vec = json.loads(vec_json)
            mid = int(motion_id)
        except Exception:
            _logger.warning("Could not parse fused vector for motion %s", motion_id)
            continue
        motion_ids.append(mid)
        window_ids.append(wid)
        vectors.append(vec)

    if not vectors:
        return [], [], np.zeros((0, 0))

    # Pad to common length if needed (shouldn't happen if pipeline is consistent)
    max_len = max(len(v) for v in vectors)
    mat = np.zeros((len(vectors), max_len), dtype=float)
    for i, v in enumerate(vectors):
        mat[i, : len(v)] = v

    return motion_ids, window_ids, mat
||||
|
||||
|
||||
def run_umap(
    db_path: str,
    window_id: Optional[str] = None,
    n_components: int = 2,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    random_state: int = 42,
) -> Dict:
    """Reduce fused embeddings to 2D/3D coordinates with UMAP.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Optional window filter passed to the loader.
        n_components: Output dimensionality (2 or 3).
        n_neighbors: UMAP neighbourhood size; shrunk for tiny datasets.
        min_dist: UMAP min_dist parameter.
        random_state: Seed for reproducible layouts.

    Returns:
        {"motion_ids", "window_ids", "coords", "n_components"} on success,
        or {} when umap-learn is missing or no embeddings exist.
    """
    try:
        import umap
    except ImportError:
        _logger.error("umap-learn is not installed; cannot run UMAP")
        return {}

    motion_ids, window_ids, mat = _load_fused_vectors(db_path, window_id)
    if mat.size == 0:
        _logger.warning("No fused embeddings found for window_id=%s", window_id)
        return {}

    n_samples = mat.shape[0]
    if n_samples < n_neighbors + 1:
        # UMAP requires at least n_neighbors + 1 samples; shrink to fit.
        n_neighbors = max(2, n_samples - 1)
        _logger.warning(
            "Reduced n_neighbors to %d due to small dataset (%d samples)",
            n_neighbors,
            n_samples,
        )

    embedding = umap.UMAP(
        n_components=n_components,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    ).fit_transform(mat)

    return {
        "motion_ids": motion_ids,
        "window_ids": window_ids,
        "coords": embedding.tolist(),
        "n_components": n_components,
    }
||||
|
||||
|
||||
def cluster_kmeans(
    coords: np.ndarray, n_clusters: int = 8, random_state: int = 42
) -> np.ndarray:
    """Run KMeans on 2D/3D UMAP coordinates.

    Args:
        coords: (n, d) array of reduced coordinates.
        n_clusters: Requested cluster count; clamped to len(coords).
        random_state: Seed for deterministic centroid initialisation.

    Returns:
        Array of integer cluster labels (length = len(coords)). All-zero
        labels when scikit-learn is unavailable.
    """
    # Fix: with empty input, min(n_clusters, 0) == 0 and KMeans raises
    # ValueError on n_clusters < 1 — return an empty label array instead.
    if len(coords) == 0:
        return np.zeros(0, dtype=int)

    try:
        from sklearn.cluster import KMeans
    except ImportError:
        _logger.error("scikit-learn is not installed; cannot run KMeans")
        return np.zeros(len(coords), dtype=int)

    n_clusters = min(n_clusters, len(coords))
    km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    return km.fit_predict(coords)
||||
@ -0,0 +1,125 @@ |
||||
"""political_axis.py — Project MP SVD vectors onto an ideological axis. |
||||
|
||||
Two modes: |
||||
1. PCA mode (default): compute the first principal component of all MP SVD |
||||
vectors for a window and project each MP onto it. The sign is arbitrary |
||||
but consistent within a window. |
||||
|
||||
2. Anchor mode: define the axis as the vector from the centroid of |
||||
``left_parties`` to the centroid of ``right_parties``. Project all MPs |
||||
onto this normalised anchor axis. |
||||
|
||||
Both modes return a dict mapping mp_name → scalar score for the given window. |
||||
""" |
||||
|
||||
import json |
||||
import logging |
||||
from typing import Dict, List, Optional |
||||
|
||||
import numpy as np |
||||
import duckdb |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _load_mp_svd_vectors(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Load all MP SVD vectors for a window from the svd_vectors table.

    Rows whose JSON payload cannot be parsed are skipped with a warning.
    """
    conn = duckdb.connect(db_path)
    try:
        # Fix: close the connection even if the query raises (leak on error).
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        conn.close()

    result: Dict[str, np.ndarray] = {}
    for mp_name, vec_json in rows:
        try:
            result[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            _logger.warning("Could not parse SVD vector for MP %s", mp_name)
    return result
||||
|
||||
|
||||
def compute_pca_axis(db_path: str, window_id: str) -> Dict[str, float]:
    """Score each MP by projection onto the first principal component.

    The sign of the axis is arbitrary but consistent within a window.
    Returns {mp_name: score}; empty dict when fewer than two MPs exist
    or the SVD fails to converge.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if len(mp_vecs) < 2:
        _logger.warning(
            "window %s has only %d MPs; skipping PCA axis", window_id, len(mp_vecs)
        )
        return {}

    names = list(mp_vecs.keys())
    stacked = np.vstack([mp_vecs[n] for n in names])  # shape (n_mps, k)

    # Centre the data so PC1 passes through the centroid.
    centred = stacked - stacked.mean(axis=0)

    # The first right-singular vector of the centred matrix is PC1.
    try:
        _, _, vt = np.linalg.svd(centred, full_matrices=False)
    except np.linalg.LinAlgError:
        _logger.exception("SVD failed in compute_pca_axis for window %s", window_id)
        return {}

    scores = centred.dot(vt[0])
    return dict(zip(names, (float(s) for s in scores)))
||||
|
||||
|
||||
def compute_anchor_axis(
    db_path: str,
    window_id: str,
    left_parties: List[str],
    right_parties: List[str],
) -> Dict[str, float]:
    """Project MP SVD vectors onto a left↔right anchor axis.

    The axis points from the centroid of ``left_parties`` members to the
    centroid of ``right_parties`` members; positive scores lean right.

    Returns {mp_name: score}; empty dict when vectors or anchors are
    missing, or when the two centroids coincide.
    """
    mp_vecs = _load_mp_svd_vectors(db_path, window_id)
    if not mp_vecs:
        return {}

    # Party affiliation comes from mp_metadata (table is not window-scoped).
    conn = duckdb.connect(db_path)
    rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
    conn.close()
    party_of = dict(rows)

    def _anchor_members(parties: List[str]) -> List[np.ndarray]:
        # Only MPs that both belong to an anchor party and have a vector.
        return [
            mp_vecs[mp]
            for mp, party in party_of.items()
            if party in parties and mp in mp_vecs
        ]

    left_vecs = _anchor_members(left_parties)
    right_vecs = _anchor_members(right_parties)

    if not left_vecs or not right_vecs:
        _logger.warning(
            "window %s: insufficient anchor parties (left=%d, right=%d)",
            window_id,
            len(left_vecs),
            len(right_vecs),
        )
        return {}

    axis = np.mean(right_vecs, axis=0) - np.mean(left_vecs, axis=0)
    norm = np.linalg.norm(axis)
    if norm < 1e-10:
        # Degenerate case: the anchor centroids are (almost) identical.
        _logger.warning("Anchor axis has near-zero norm for window %s", window_id)
        return {}
    unit_axis = axis / norm

    return {name: float(np.dot(vec, unit_axis)) for name, vec in mp_vecs.items()}
||||
@ -0,0 +1,123 @@ |
||||
"""trajectory.py — Compute MP political drift across aligned time windows. |
||||
|
||||
For each MP that appears in multiple windows, computes: |
||||
- The aligned SVD vector per window |
||||
- The Euclidean distance between consecutive windows (drift) |
||||
- Total cumulative drift |
||||
|
||||
Returns a dict keyed by mp_name containing per-window positions and drift scores. |
||||
""" |
||||
|
||||
import json |
||||
import logging |
||||
from typing import Dict, List, Optional |
||||
|
||||
import numpy as np |
||||
import duckdb |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _load_window_ids(db_path: str) -> List[str]:
    """Return all distinct MP window IDs from svd_vectors, in lexicographic order."""
    conn = duckdb.connect(db_path)
    try:
        # Fix: close the connection even if the query raises (leak on error).
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id"
        ).fetchall()
    finally:
        conn.close()
    return [r[0] for r in rows]
||||
|
||||
|
||||
def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]:
    """Return {mp_name: SVD vector} for one window; unparseable rows are skipped."""
    conn = duckdb.connect(db_path)
    try:
        # Fix: close the connection even if the query raises (leak on error).
        rows = conn.execute(
            "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'",
            (window_id,),
        ).fetchall()
    finally:
        conn.close()
    result: Dict[str, np.ndarray] = {}
    for mp_name, vec_json in rows:
        try:
            result[mp_name] = np.array(json.loads(vec_json), dtype=float)
        except Exception:
            _logger.warning(
                "Could not parse vector for MP %s window %s", mp_name, window_id
            )
    return result
||||
|
||||
|
||||
def compute_trajectories(
    db_path: str,
    window_ids: Optional[List[str]] = None,
) -> Dict[str, Dict]:
    """Compute per-MP trajectories across windows.

    Returns:
        {
            mp_name: {
                "windows": [window_id, ...],
                "vectors": [[...], ...],   # one vector per window
                "drift": [float, ...],     # consecutive Euclidean distances
                "total_drift": float,
            }
        }
    Only MPs present in at least 2 windows are included.
    """
    if window_ids is None:
        window_ids = _load_window_ids(db_path)

    if len(window_ids) < 2:
        _logger.info("Fewer than 2 windows — no trajectories to compute")
        return {}

    # Gather each MP's (window, vector) history in window order.
    history: Dict[str, Dict] = {}
    for wid in window_ids:
        for mp_name, vec in _load_mp_vectors_for_window(db_path, wid).items():
            entry = history.setdefault(mp_name, {"windows": [], "vectors": []})
            entry["windows"].append(wid)
            entry["vectors"].append(vec)

    # Derive drift for MPs observed in two or more windows.
    result: Dict[str, Dict] = {}
    for mp_name, entry in history.items():
        if len(entry["windows"]) < 2:
            continue
        vecs = entry["vectors"]
        drifts = [float(np.linalg.norm(b - a)) for a, b in zip(vecs, vecs[1:])]
        result[mp_name] = {
            "windows": entry["windows"],
            "vectors": [v.tolist() for v in vecs],
            "drift": drifts,
            "total_drift": float(sum(drifts)),
        }

    _logger.info(
        "Trajectories computed for %d MPs across %d windows",
        len(result),
        len(window_ids),
    )
    return result
||||
|
||||
|
||||
def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]:
    """Return the top-n MPs by total drift, sorted descending.

    Each entry: {"mp_name": ..., "total_drift": ..., "windows": [...]}
    """
    by_drift_desc = sorted(
        trajectories.items(), key=lambda kv: kv[1]["total_drift"], reverse=True
    )
    top: List[Dict] = []
    for mp, data in by_drift_desc[:n]:
        top.append(
            {
                "mp_name": mp,
                "total_drift": data["total_drift"],
                "windows": data["windows"],
            }
        )
    return top
||||
@ -0,0 +1,163 @@ |
||||
"""visualize.py — Plotly interactive plots for parliamentary embeddings. |
||||
|
||||
Produces self-contained HTML files. |
||||
|
||||
Functions: |
||||
plot_umap_scatter — 2D scatter of fused motion embeddings, coloured by cluster |
||||
plot_mp_trajectory — Line plot of MP drift across windows |
||||
plot_political_axis — Bar chart of MP scores on the ideological axis |
||||
""" |
||||
|
||||
import logging |
||||
from typing import Dict, List, Optional |
||||
|
||||
import numpy as np |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _require_plotly():
    """Import and return (plotly.graph_objects, plotly.express).

    Raises:
        ImportError: with installation instructions when plotly is missing;
            the original import error is chained for debuggability.
    """
    try:
        import plotly.graph_objects as go
        import plotly.express as px

        return go, px
    except ImportError as exc:
        # Fix: chain the original error so the underlying import failure
        # (e.g. a broken sub-dependency) stays visible in the traceback.
        raise ImportError(
            "plotly is not installed. Install it with: uv add plotly"
        ) from exc
||||
|
||||
|
||||
def plot_umap_scatter(
    motion_ids: List[int],
    coords: List[List[float]],
    labels: Optional[List[int]] = None,
    window_id: Optional[str] = None,
    output_path: str = "analysis_umap.html",
) -> str:
    """Produce a 2D scatter plot of UMAP-reduced fused embeddings.

    Args:
        motion_ids: Motion IDs (used as hover labels)
        coords: List of [x, y] coordinates
        labels: Optional cluster labels (integer per motion)
        window_id: Window label for the plot title
        output_path: Where to write the self-contained HTML

    Returns the output_path on success.
    """
    go, px = _require_plotly()

    cluster_ids = labels if labels is not None else [0] * len(motion_ids)
    title = "UMAP — fused motion embeddings"
    if window_id:
        title += f" ({window_id})"

    fig = px.scatter(
        x=[point[0] for point in coords],
        y=[point[1] for point in coords],
        color=[str(c) for c in cluster_ids],
        hover_name=[str(mid) for mid in motion_ids],
        title=title,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("UMAP scatter written to %s", output_path)
    return output_path
||||
|
||||
|
||||
def plot_mp_trajectory(
    trajectories: Dict[str, Dict],
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectory.html",
) -> str:
    """Line plot of MP drift across time windows.

    Args:
        trajectories: Output of analysis.trajectory.compute_trajectories()
        mp_names: Subset of MPs to plot (default: all)
        output_path: Output HTML file path

    Returns the output_path on success.
    """
    go, px = _require_plotly()

    selected = list(trajectories.keys()) if mp_names is None else mp_names
    fig = go.Figure()

    for mp in selected:
        data = trajectories.get(mp)
        if data is None:
            continue
        # Cumulative drift: start each MP at 0 in their first window.
        cumulative = [0.0] + list(np.cumsum(data["drift"]))
        window_labels = data["windows"][: len(cumulative)]
        fig.add_trace(
            go.Scatter(
                x=window_labels,
                y=cumulative,
                mode="lines+markers",
                name=mp,
            )
        )

    fig.update_layout(
        title="MP Political Drift Over Time (Cumulative)",
        xaxis_title="Window",
        yaxis_title="Cumulative Drift",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Trajectory plot written to %s", output_path)
    return output_path
||||
|
||||
|
||||
def plot_political_axis(
    scores: Dict[str, float],
    party_of: Optional[Dict[str, str]] = None,
    window_id: Optional[str] = None,
    n_top: int = 30,
    output_path: str = "analysis_political_axis.html",
) -> str:
    """Horizontal bar chart of MP scores on the ideological axis.

    Args:
        scores: {mp_name: score} from political_axis module
        party_of: Optional {mp_name: party} for colour-coding
        window_id: Window label for the title
        n_top: Show only the top/bottom n MPs by score
        output_path: Output HTML path

    Returns the output_path on success.
    """
    go, px = _require_plotly()

    ranked = sorted(scores.items(), key=lambda kv: kv[1])

    # Large lists: keep only the n_top most extreme MPs from each end.
    if len(ranked) > 2 * n_top:
        ranked = ranked[:n_top] + ranked[-n_top:]

    names = [mp for mp, _ in ranked]
    vals = [score for _, score in ranked]
    if party_of:
        colors = [party_of.get(mp, "Unknown") for mp in names]
    else:
        colors = ["Unknown"] * len(names)

    title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")

    fig = px.bar(
        x=vals,
        y=names,
        color=colors,
        orientation="h",
        title=title,
        labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political axis chart written to %s", output_path)
    return output_path
||||
@ -0,0 +1,261 @@ |
||||
"""CLI orchestrator for the parliamentary embedding pipeline. |
||||
|
||||
Runs all phases in sequence: |
||||
1. fetch_mp_metadata — pull MP party + tenure from OData |
||||
2. extract_mp_votes — parse voting_results JSON → mp_votes rows |
||||
3. svd per window — build vote matrix, SVD, Procrustes-align |
||||
4. text embeddings — fill any gaps in the embeddings table |
||||
5. fuse per window — concatenate SVD + text vectors → fused_embeddings |
||||
|
||||
Usage: |
||||
uv run python -m pipeline.run_pipeline [options] |
||||
|
||||
Options: |
||||
--db-path PATH Path to the DuckDB file (default: data/motions.db) |
||||
--start-date DATE Window start (YYYY-MM-DD, default: 2 years ago) |
||||
--end-date DATE Window end (YYYY-MM-DD, default: today) |
||||
--window-size {quarterly,annual} Time window granularity (default: quarterly) |
||||
--svd-k INT SVD dimensionality (default: 50) |
||||
--text-model TEXT Text embedding model name (default: from ai_provider) |
||||
--skip-metadata Skip fetching MP metadata from OData |
||||
--skip-extract Skip extracting MP votes from voting_results |
||||
--skip-svd Skip SVD computation |
||||
--skip-text Skip text embedding gap-fill |
||||
--skip-fusion Skip vector fusion |
||||
--dry-run Print actions but make no DB writes |
||||
""" |
||||
|
||||
import argparse |
||||
import calendar |
||||
import logging |
||||
import sys |
||||
from datetime import date, timedelta |
||||
from typing import List, Tuple |
||||
|
||||
from database import MotionDatabase |
||||
|
||||
_logger = logging.getLogger(__name__) |
||||
|
||||
|
||||
def _generate_windows( |
||||
start: date, end: date, granularity: str |
||||
) -> List[Tuple[str, str, str]]: |
||||
"""Return list of (window_id, start_str, end_str) tuples. |
||||
|
||||
window_id format: |
||||
quarterly → "2024-Q1", "2024-Q2", … |
||||
annual → "2024" |
||||
""" |
||||
windows = [] |
||||
cursor = date(start.year, start.month, 1) |
||||
|
||||
if granularity == "annual": |
||||
cursor = date(start.year, 1, 1) |
||||
while cursor <= end: |
||||
year_end = date(cursor.year, 12, 31) |
||||
w_end = min(year_end, end) |
||||
windows.append((str(cursor.year), cursor.isoformat(), w_end.isoformat())) |
||||
cursor = date(cursor.year + 1, 1, 1) |
||||
else: |
||||
# quarterly |
||||
quarter_starts = {1: 1, 2: 4, 3: 7, 4: 10} |
||||
quarter_ends = {1: 3, 2: 6, 3: 9, 4: 12} |
||||
|
||||
# Align cursor to quarter start |
||||
q = (cursor.month - 1) // 3 + 1 |
||||
cursor = date(cursor.year, quarter_starts[q], 1) |
||||
|
||||
while cursor <= end: |
||||
q = (cursor.month - 1) // 3 + 1 |
||||
q_end_month = quarter_ends[q] |
||||
last_day = calendar.monthrange(cursor.year, q_end_month)[1] |
||||
q_end = date(cursor.year, q_end_month, last_day) |
||||
w_end = min(q_end, end) |
||||
window_id = f"{cursor.year}-Q{q}" |
||||
windows.append((window_id, cursor.isoformat(), w_end.isoformat())) |
||||
cursor = q_end + timedelta(days=1) |
||||
|
||||
return windows |
||||
|
||||
|
||||
def run(args: argparse.Namespace) -> int:
    """Execute the pipeline. Returns exit code (0 = success).

    Phases run in order: MP metadata fetch, MP vote extraction, per-window
    SVD, text-embedding gap-fill, per-window fusion. Each phase can be
    skipped via its --skip-* flag; --dry-run logs intended actions without
    writing. Phase imports are deferred so skipped/dry phases don't pay
    their import cost.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    db_path = args.db_path
    dry_run = args.dry_run

    if dry_run:
        _logger.info("DRY RUN — no writes will be made")

    # Resolve date range (default: the 730 days ≈ 2 years ending today)
    end_date = date.fromisoformat(args.end_date) if args.end_date else date.today()
    start_date = (
        date.fromisoformat(args.start_date)
        if args.start_date
        else end_date - timedelta(days=730)
    )

    _logger.info(
        "Pipeline run: %s → %s (%s windows), db=%s",
        start_date,
        end_date,
        args.window_size,
        db_path,
    )

    # NOTE(review): the DB handle is opened even under --dry-run — presumably
    # MotionDatabase only opens/creates the file without writing; confirm.
    db = MotionDatabase(db_path)

    # ── Phase 1: MP metadata ────────────────────────────────────────────────
    if not args.skip_metadata:
        _logger.info("Phase 1: fetching MP metadata from OData")
        if not dry_run:
            from pipeline.fetch_mp_metadata import fetch_mp_metadata

            fetched, skipped = fetch_mp_metadata(db)
            _logger.info(" mp_metadata: fetched=%d skipped=%d", fetched, skipped)
        else:
            _logger.info(" [dry-run] would call fetch_mp_metadata(db)")
    else:
        _logger.info("Phase 1: skipped (--skip-metadata)")

    # ── Phase 2: Extract MP votes ────────────────────────────────────────────
    if not args.skip_extract:
        _logger.info("Phase 2: extracting MP votes from voting_results")
        if not dry_run:
            from pipeline.extract_mp_votes import extract_mp_votes

            inserted, skipped = extract_mp_votes(db)
            _logger.info(
                " mp_votes: inserted=%d motions skipped=%d", inserted, skipped
            )
        else:
            _logger.info(" [dry-run] would call extract_mp_votes(db)")
    else:
        _logger.info("Phase 2: skipped (--skip-extract)")

    # ── Phase 3: SVD per window ──────────────────────────────────────────────
    if not args.skip_svd:
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info("Phase 3: SVD for %d windows (k=%d)", len(windows), args.svd_k)
        from pipeline.svd_pipeline import run_svd_for_window

        for window_id, w_start, w_end in windows:
            _logger.info(" window %s: %s → %s", window_id, w_start, w_end)
            if not dry_run:
                result = run_svd_for_window(
                    db=db,
                    window_id=window_id,
                    start_date=w_start,
                    end_date=w_end,
                    k=args.svd_k,
                )
                _logger.info(
                    " k_used=%d stored_mp=%d stored_motion=%d",
                    result["k_used"],
                    result["stored_mp"],
                    result["stored_motion"],
                )
            else:
                _logger.info(" [dry-run] would run SVD for window %s", window_id)
    else:
        _logger.info("Phase 3: skipped (--skip-svd)")

    # ── Phase 4: Text embeddings ──────────────────────────────────────────────
    if not args.skip_text:
        _logger.info("Phase 4: ensuring text embeddings")
        if not dry_run:
            from pipeline.text_pipeline import ensure_text_embeddings

            stored, existing, no_text, errors = ensure_text_embeddings(
                db_path=db_path, model=args.text_model
            )
            _logger.info(
                " embeddings: stored=%d existing=%d no_text=%d errors=%d",
                stored,
                existing,
                no_text,
                errors,
            )
        else:
            _logger.info(" [dry-run] would call ensure_text_embeddings")
    else:
        _logger.info("Phase 4: skipped (--skip-text)")

    # ── Phase 5: Fusion per window ────────────────────────────────────────────
    if not args.skip_fusion:
        # Windows are regenerated here so Phase 5 works even when Phase 3
        # was skipped (its `windows` local would not exist).
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info("Phase 5: fusing vectors for %d windows", len(windows))
        from pipeline.fusion import fuse_for_window

        for window_id, _w_start, _w_end in windows:
            if not dry_run:
                result = fuse_for_window(
                    window_id=window_id,
                    db_path=db_path,
                    model=args.text_model,
                )
                _logger.info(
                    " window %s: fused=%d skipped_no_svd=%d skipped_no_text=%d",
                    window_id,
                    result["fused"],
                    result.get("skipped_no_svd", 0),
                    result.get("skipped_no_text", 0),
                )
            else:
                _logger.info(" [dry-run] would fuse window %s", window_id)
    else:
        _logger.info("Phase 5: skipped (--skip-fusion)")

    _logger.info("Pipeline complete.")
    return 0
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the pipeline orchestrator."""
    parser = argparse.ArgumentParser(
        description="Parliamentary embedding pipeline orchestrator",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--db-path", default="data/motions.db", help="Path to DuckDB file"
    )
    parser.add_argument("--start-date", default=None, help="Window start YYYY-MM-DD")
    parser.add_argument("--end-date", default=None, help="Window end YYYY-MM-DD")
    parser.add_argument(
        "--window-size",
        choices=["quarterly", "annual"],
        default="quarterly",
        help="Time window granularity",
    )
    parser.add_argument("--svd-k", type=int, default=50, help="SVD dimensions")
    parser.add_argument(
        "--text-model",
        default=None,
        help="Text embedding model (default: ai_provider default)",
    )

    # One boolean skip flag per pipeline phase.
    for flag, help_text in [
        ("--skip-metadata", "Skip MP metadata fetch"),
        ("--skip-extract", "Skip MP vote extraction"),
        ("--skip-svd", "Skip SVD computation"),
        ("--skip-text", "Skip text embedding gap-fill"),
        ("--skip-fusion", "Skip vector fusion"),
    ]:
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without writing anything",
    )
    return parser
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse arguments and propagate the pipeline's exit code.
    parser = build_parser()
    args = parser.parse_args()
    sys.exit(run(args))
||||
@ -0,0 +1,195 @@ |
||||
"""Tests for analysis modules: political_axis, trajectory, clustering.""" |
||||
|
||||
import json |
||||
import numpy as np |
||||
import pytest |
||||
|
||||
duckdb = pytest.importorskip("duckdb") |
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────── |
||||
|
||||
|
||||
def _setup_svd_vectors(db_path: str, window_ids_mp_vecs: dict):
    """Insert synthetic MP SVD vectors into svd_vectors table.

    window_ids_mp_vecs: {window_id: {mp_name: np.ndarray}}
    """
    conn = duckdb.connect(db_path)
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS svd_vectors (
            id INTEGER,
            window_id TEXT,
            entity_type TEXT,
            entity_id TEXT,
            vector JSON,
            model TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    insert_sql = (
        "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model)"
        " VALUES (?, 'mp', ?, ?, 'test')"
    )
    # Vectors are serialised to JSON, mirroring the production pipeline.
    for wid, mp_vecs in window_ids_mp_vecs.items():
        for mp_name, vec in mp_vecs.items():
            conn.execute(insert_sql, (wid, mp_name, json.dumps(vec.tolist())))
    conn.close()
||||
|
||||
|
||||
def _setup_mp_metadata(db_path: str, mp_party: dict):
    """Insert synthetic MP metadata rows ({mp_name: party})."""
    conn = duckdb.connect(db_path)
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS mp_metadata (
            mp_name TEXT,
            party TEXT,
            van DATE,
            tot_en_met DATE,
            persoon_id TEXT
        )
        """
    )
    # Only name + party are needed by the tests; date columns stay NULL.
    for mp_name, party in mp_party.items():
        conn.execute(
            "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)",
            (mp_name, party),
        )
    conn.close()
||||
|
||||
|
||||
# ── political_axis ──────────────────────────────────────────────────────────── |
||||
|
||||
|
||||
class TestPoliticalAxis:
    """Unit tests for analysis.political_axis (PCA and anchor-axis modes)."""

    def test_pca_axis_basic(self, tmp_path):
        # Random MP vectors: every MP should get a float score with spread.
        np.random.seed(42)
        db_path = str(tmp_path / "test.db")
        n_mps, k = 20, 5

        # Create a low-rank set of MP vectors (they should have a clear first PC)
        vecs = np.random.randn(n_mps, k)
        mp_names = [f"MP_{i}" for i in range(n_mps)]
        _setup_svd_vectors(
            db_path, {"2024-Q1": {mp_names[i]: vecs[i] for i in range(n_mps)}}
        )

        from analysis.political_axis import compute_pca_axis

        scores = compute_pca_axis(db_path, "2024-Q1")
        assert len(scores) == n_mps
        assert all(isinstance(v, float) for v in scores.values())
        # Scores should have non-trivial variance
        vals = list(scores.values())
        assert np.std(vals) > 0.0

    def test_pca_axis_too_few_mps(self, tmp_path):
        # A single MP cannot define a principal component → empty dict.
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(db_path, {"w1": {"MP_A": np.array([1.0, 0.0])}})

        from analysis.political_axis import compute_pca_axis

        scores = compute_pca_axis(db_path, "w1")
        assert scores == {}

    def test_anchor_axis_basic(self, tmp_path):
        # Anchor axis runs SP-centroid → VVD-centroid; ordering must hold.
        db_path = str(tmp_path / "test.db")
        # Two clusters clearly separated on dim 0
        left_vec = np.array([-2.0, 0.0, 0.0])
        right_vec = np.array([2.0, 0.0, 0.0])
        mp_vecs = {
            "Left_A": left_vec + np.array([0.1, 0.0, 0.0]),
            "Left_B": left_vec - np.array([0.1, 0.0, 0.0]),
            "Right_A": right_vec + np.array([0.1, 0.0, 0.0]),
            "Right_B": right_vec - np.array([0.1, 0.0, 0.0]),
            "Centre": np.array([0.0, 0.0, 0.0]),
        }
        _setup_svd_vectors(db_path, {"w1": mp_vecs})
        _setup_mp_metadata(
            db_path,
            {
                "Left_A": "SP",
                "Left_B": "SP",
                "Right_A": "VVD",
                "Right_B": "VVD",
                "Centre": "D66",
            },
        )

        from analysis.political_axis import compute_anchor_axis

        scores = compute_anchor_axis(
            db_path, "w1", left_parties=["SP"], right_parties=["VVD"]
        )
        assert len(scores) == 5
        # Left MPs should have negative scores, Right MPs positive
        assert scores["Left_A"] < scores["Right_A"]
        assert scores["Left_B"] < scores["Right_B"]
||||
|
||||
|

# ── trajectory ───────────────────────────────────────────────────────────────


class TestTrajectory:
    def test_basic_trajectory(self, tmp_path):
        np.random.seed(0)
        db_path = str(tmp_path / "test.db")

        window_1 = {"MP_A": np.array([1.0, 0.0]), "MP_B": np.array([0.0, 1.0])}
        window_2 = {
            "MP_A": np.array([1.5, 0.5]),  # moved between windows
            "MP_B": np.array([0.0, 1.0]),  # stationary
            "MP_C": np.array([2.0, 2.0]),  # present in one window only
        }
        _setup_svd_vectors(db_path, {"2024-Q1": window_1, "2024-Q2": window_2})

        from analysis.trajectory import compute_trajectories, top_drifters

        traj = compute_trajectories(db_path)

        # Trajectories exist only for MPs appearing in >= 2 windows.
        assert "MP_A" in traj
        assert "MP_B" in traj
        assert "MP_C" not in traj

        assert len(traj["MP_A"]["drift"]) == 1
        assert traj["MP_A"]["total_drift"] > 0.0

        # An MP whose vector is unchanged accrues zero drift.
        assert traj["MP_B"]["total_drift"] == pytest.approx(0.0)

        # The biggest mover should rank first among drifters.
        assert top_drifters(traj, n=5)[0]["mp_name"] == "MP_A"

    def test_fewer_than_2_windows(self, tmp_path):
        db_path = str(tmp_path / "test.db")
        _setup_svd_vectors(db_path, {"2024-Q1": {"MP_A": np.array([1.0, 2.0])}})

        from analysis.trajectory import compute_trajectories

        # With a single window there is nothing to compare against.
        assert compute_trajectories(db_path) == {}



# ── clustering ────────────────────────────────────────────────────────────────


class TestClustering:
    def test_cluster_kmeans_basic(self):
        """Each of 20 random points gets one of the requested 3 labels."""
        from analysis.clustering import cluster_kmeans

        # np is already imported at module level; the original's local
        # `import numpy as np` was redundant and has been removed.
        coords = np.random.randn(20, 2)
        labels = cluster_kmeans(coords, n_clusters=3)

        assert len(labels) == 20
        # Every label must come from the requested cluster ids {0, 1, 2}.
        assert set(labels).issubset({0, 1, 2})

    def test_cluster_kmeans_fewer_points_than_clusters(self):
        """Requesting more clusters than points must not crash."""
        from analysis.clustering import cluster_kmeans

        coords = np.array([[0.0, 0.0], [1.0, 1.0]])
        labels = cluster_kmeans(coords, n_clusters=5)
        # Should not crash; n_clusters clamped to len(coords).
        assert len(labels) == 2

# ── (diff residue removed here; the following lines are tests/test_run_pipeline.py) ──
"""Tests for pipeline/run_pipeline.py""" |
||||
|
||||
import argparse |
||||
import sys |
||||
import pytest |
||||
|
||||
from pipeline.run_pipeline import _generate_windows, build_parser, run |
||||
from datetime import date |
||||
|
||||
|
||||
def test_generate_windows_quarterly():
    """A full calendar year yields four quarterly windows with exact bounds."""
    windows = _generate_windows(date(2024, 1, 1), date(2024, 12, 31), "quarterly")

    assert len(windows) == 4
    assert [wid for wid, _, _ in windows] == [
        "2024-Q1",
        "2024-Q2",
        "2024-Q3",
        "2024-Q4",
    ]

    # Q1 spans Jan 1 – Mar 31.
    assert (windows[0][1], windows[0][2]) == ("2024-01-01", "2024-03-31")
    # Q4 spans Oct 1 – Dec 31.
    assert (windows[3][1], windows[3][2]) == ("2024-10-01", "2024-12-31")


def test_generate_windows_annual():
    """Annual windows cover every touched year; the last is clipped at end_date."""
    windows = _generate_windows(date(2022, 6, 1), date(2024, 3, 31), "annual")

    assert len(windows) == 3
    assert [wid for wid, _, _ in windows] == ["2022", "2023", "2024"]

    # The 2024 window must stop at end_date rather than Dec 31.
    assert windows[-1][2] == "2024-03-31"


def test_generate_windows_mid_quarter_start():
    """Starting in the middle of Q2 should still produce a full Q2 window."""
    windows = _generate_windows(date(2024, 5, 15), date(2024, 9, 30), "quarterly")

    window_ids = {wid for wid, _, _ in windows}
    assert "2024-Q2" in window_ids
    assert "2024-Q3" in window_ids


def test_build_parser_defaults():
    """Parsing an empty argv yields the documented CLI defaults."""
    args = build_parser().parse_args([])

    assert args.db_path == "data/motions.db"
    assert args.window_size == "quarterly"
    assert args.svd_k == 50
    assert args.dry_run is False


def test_run_dry_run(tmp_path):
    """Dry-run should log actions and return 0 without touching the DB.

    NOTE(review): the original requested the ``monkeypatch`` fixture but
    never used it; the unused parameter has been removed.
    """
    db_path = str(tmp_path / "motions.db")

    # Create minimal DB so MotionDatabase initialises.
    from database import MotionDatabase

    MotionDatabase(db_path)

    args = argparse.Namespace(
        db_path=db_path,
        start_date="2024-01-01",
        end_date="2024-03-31",
        window_size="quarterly",
        svd_k=10,
        text_model=None,
        skip_metadata=False,
        skip_extract=False,
        skip_svd=False,
        skip_text=False,
        skip_fusion=False,
        dry_run=True,
    )

    exit_code = run(args)
    assert exit_code == 0


def test_run_skip_all(tmp_path):
    """Skipping all phases should still return 0."""
    db_path = str(tmp_path / "motions.db")

    from database import MotionDatabase

    MotionDatabase(db_path)

    # All five skip flags on at once: the pipeline has nothing to do.
    skip_flags = dict.fromkeys(
        ["skip_metadata", "skip_extract", "skip_svd", "skip_text", "skip_fusion"],
        True,
    )
    args = argparse.Namespace(
        db_path=db_path,
        start_date="2024-01-01",
        end_date="2024-03-31",
        window_size="quarterly",
        svd_k=10,
        text_model=None,
        dry_run=False,
        **skip_flags,
    )

    assert run(args) == 0
