diff --git a/analysis/__init__.py b/analysis/__init__.py new file mode 100644 index 0000000..6772f98 --- /dev/null +++ b/analysis/__init__.py @@ -0,0 +1,8 @@ +"""Analysis modules for the parliamentary embedding pipeline. + +Modules: + political_axis — project MP SVD vectors onto ideological axis + trajectory — compute MP drift across aligned windows + clustering — UMAP dimensionality reduction + cluster labelling + visualize — Plotly interactive plots (outputs HTML that loads plotly.js from a CDN) +""" diff --git a/analysis/clustering.py b/analysis/clustering.py new file mode 100644 index 0000000..92d35fb --- /dev/null +++ b/analysis/clustering.py @@ -0,0 +1,130 @@ +"""clustering.py — UMAP dimensionality reduction on fused embeddings. + +Reduces fused motion embeddings to 2D (or 3D) for visualisation, +and optionally labels clusters using KMeans. + +Requires: umap-learn, scikit-learn (for KMeans) +""" + +import json +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import duckdb + +_logger = logging.getLogger(__name__) + + +def _load_fused_vectors( + db_path: str, window_id: Optional[str] = None +) -> Tuple[List[int], List[str], np.ndarray]: + """Load fused embeddings from the DB. + + Returns (motion_ids, window_ids, matrix). + Optionally filter by window_id. 
+ """ + conn = duckdb.connect(db_path) + if window_id: + rows = conn.execute( + "SELECT motion_id, window_id, vector FROM fused_embeddings WHERE window_id = ?", + (window_id,), + ).fetchall() + else: + rows = conn.execute( + "SELECT motion_id, window_id, vector FROM fused_embeddings ORDER BY window_id, motion_id" + ).fetchall() + conn.close() + + motion_ids, window_ids, vectors = [], [], [] + for motion_id, wid, vec_json in rows: + try: + vec = json.loads(vec_json) + motion_ids.append(int(motion_id)) + window_ids.append(wid) + vectors.append(vec) + except Exception: + _logger.warning("Could not parse fused vector for motion %s", motion_id) + + if not vectors: + return [], [], np.zeros((0, 0)) + + # Pad to common length if needed (shouldn't happen if pipeline is consistent) + max_len = max(len(v) for v in vectors) + mat = np.zeros((len(vectors), max_len), dtype=float) + for i, v in enumerate(vectors): + mat[i, : len(v)] = v + + return motion_ids, window_ids, mat + + +def run_umap( + db_path: str, + window_id: Optional[str] = None, + n_components: int = 2, + n_neighbors: int = 15, + min_dist: float = 0.1, + random_state: int = 42, +) -> Dict: + """Run UMAP on fused embeddings and return 2D/3D coordinates. 
+ + Returns: + { + "motion_ids": [...], + "window_ids": [...], + "coords": [[x, y], ...], # or [x, y, z] if n_components=3 + "n_components": int, + } + """ + try: + import umap + except ImportError: + _logger.error("umap-learn is not installed; cannot run UMAP") + return {} + + motion_ids, window_ids, mat = _load_fused_vectors(db_path, window_id) + if mat.size == 0: + _logger.warning("No fused embeddings found for window_id=%s", window_id) + return {} + + if mat.shape[0] < n_neighbors + 1: + # UMAP requires at least n_neighbors+1 samples + n_neighbors = max(2, mat.shape[0] - 1) + _logger.warning( + "Reduced n_neighbors to %d due to small dataset (%d samples)", + n_neighbors, + mat.shape[0], + ) + + reducer = umap.UMAP( + n_components=n_components, + n_neighbors=n_neighbors, + min_dist=min_dist, + random_state=random_state, + ) + coords = reducer.fit_transform(mat) + + return { + "motion_ids": motion_ids, + "window_ids": window_ids, + "coords": coords.tolist(), + "n_components": n_components, + } + + +def cluster_kmeans( + coords: np.ndarray, n_clusters: int = 8, random_state: int = 42 +) -> np.ndarray: + """Run KMeans on 2D/3D UMAP coordinates. + + Returns array of integer cluster labels (length = len(coords)). + """ + try: + from sklearn.cluster import KMeans + except ImportError: + _logger.error("scikit-learn is not installed; cannot run KMeans") + return np.zeros(len(coords), dtype=int) + + n_clusters = min(n_clusters, len(coords)) + km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto") + return km.fit_predict(coords) diff --git a/analysis/political_axis.py b/analysis/political_axis.py new file mode 100644 index 0000000..b90bf56 --- /dev/null +++ b/analysis/political_axis.py @@ -0,0 +1,125 @@ +"""political_axis.py — Project MP SVD vectors onto an ideological axis. + +Two modes: + 1. PCA mode (default): compute the first principal component of all MP SVD + vectors for a window and project each MP onto it. 
The sign is arbitrary + but consistent within a window. + + 2. Anchor mode: define the axis as the vector from the centroid of + ``left_parties`` to the centroid of ``right_parties``. Project all MPs + onto this normalised anchor axis. + +Both modes return a dict mapping mp_name → scalar score for the given window. +""" + +import json +import logging +from typing import Dict, List, Optional + +import numpy as np +import duckdb + +_logger = logging.getLogger(__name__) + + +def _load_mp_svd_vectors(db_path: str, window_id: str) -> Dict[str, np.ndarray]: + """Load all MP SVD vectors for a window from svd_vectors table.""" + conn = duckdb.connect(db_path) + rows = conn.execute( + "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'", + (window_id,), + ).fetchall() + conn.close() + + result = {} + for mp_name, vec_json in rows: + try: + result[mp_name] = np.array(json.loads(vec_json), dtype=float) + except Exception: + _logger.warning("Could not parse SVD vector for MP %s", mp_name) + return result + + +def compute_pca_axis(db_path: str, window_id: str) -> Dict[str, float]: + """Project MP SVD vectors onto their first principal component. + + Returns {mp_name: score}. Returns empty dict if fewer than 2 MPs. 
+ """ + mp_vecs = _load_mp_svd_vectors(db_path, window_id) + if len(mp_vecs) < 2: + _logger.warning( + "window %s has only %d MPs; skipping PCA axis", window_id, len(mp_vecs) + ) + return {} + + names = list(mp_vecs.keys()) + mat = np.vstack([mp_vecs[n] for n in names]) # (n_mps, k) + + # Centre + mat_centred = mat - mat.mean(axis=0) + + # First PC via SVD + try: + _, _, Vt = np.linalg.svd(mat_centred, full_matrices=False) + axis = Vt[0] # (k,) + except np.linalg.LinAlgError: + _logger.exception("SVD failed in compute_pca_axis for window %s", window_id) + return {} + + projections = mat_centred.dot(axis) + return {name: float(score) for name, score in zip(names, projections)} + + +def compute_anchor_axis( + db_path: str, + window_id: str, + left_parties: List[str], + right_parties: List[str], +) -> Dict[str, float]: + """Project MP SVD vectors onto a left↔right anchor axis. + + The axis runs from the centroid of ``left_parties`` to the centroid of + ``right_parties``. Positive scores are toward the right. + + Returns {mp_name: score}. 
+ """ + mp_vecs = _load_mp_svd_vectors(db_path, window_id) + if not mp_vecs: + return {} + + # Load party affiliation from mp_metadata (all rows; the query is not filtered by window) + conn = duckdb.connect(db_path) + rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall() + conn.close() + party_of = {mp: party for mp, party in rows} + + left_vecs = [ + mp_vecs[mp] + for mp, party in party_of.items() + if party in left_parties and mp in mp_vecs + ] + right_vecs = [ + mp_vecs[mp] + for mp, party in party_of.items() + if party in right_parties and mp in mp_vecs + ] + + if not left_vecs or not right_vecs: + _logger.warning( + "window %s: insufficient anchor parties (left=%d, right=%d)", + window_id, + len(left_vecs), + len(right_vecs), + ) + return {} + + left_centroid = np.mean(left_vecs, axis=0) + right_centroid = np.mean(right_vecs, axis=0) + axis = right_centroid - left_centroid + norm = np.linalg.norm(axis) + if norm < 1e-10: + _logger.warning("Anchor axis has near-zero norm for window %s", window_id) + return {} + axis = axis / norm + + return {name: float(np.dot(vec, axis)) for name, vec in mp_vecs.items()} diff --git a/analysis/trajectory.py b/analysis/trajectory.py new file mode 100644 index 0000000..44bbdd5 --- /dev/null +++ b/analysis/trajectory.py @@ -0,0 +1,123 @@ +"""trajectory.py — Compute MP political drift across aligned time windows. + +For each MP that appears in multiple windows, computes: + - The aligned SVD vector per window + - The Euclidean distance between consecutive windows (drift) + - Total cumulative drift + +Returns a dict keyed by mp_name containing per-window positions and drift scores. 
+""" + +import json +import logging +from typing import Dict, List, Optional + +import numpy as np +import duckdb + +_logger = logging.getLogger(__name__) + + +def _load_window_ids(db_path: str) -> List[str]: + """Return all distinct window IDs from svd_vectors, in lexicographic order.""" + conn = duckdb.connect(db_path) + rows = conn.execute( + "SELECT DISTINCT window_id FROM svd_vectors WHERE entity_type = 'mp' ORDER BY window_id" + ).fetchall() + conn.close() + return [r[0] for r in rows] + + +def _load_mp_vectors_for_window(db_path: str, window_id: str) -> Dict[str, np.ndarray]: + conn = duckdb.connect(db_path) + rows = conn.execute( + "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'mp'", + (window_id,), + ).fetchall() + conn.close() + result = {} + for mp_name, vec_json in rows: + try: + result[mp_name] = np.array(json.loads(vec_json), dtype=float) + except Exception: + _logger.warning( + "Could not parse vector for MP %s window %s", mp_name, window_id + ) + return result + + +def compute_trajectories( + db_path: str, + window_ids: Optional[List[str]] = None, +) -> Dict[str, Dict]: + """Compute per-MP trajectories across windows. + + Returns: + { + mp_name: { + "windows": [window_id, ...], + "vectors": [[...], ...], # one vector per window + "drift": [float, ...], # consecutive Euclidean distances + "total_drift": float, + } + } + Only MPs present in at least 2 windows are included. 
+ """ + if window_ids is None: + window_ids = _load_window_ids(db_path) + + if len(window_ids) < 2: + _logger.info("Fewer than 2 windows — no trajectories to compute") + return {} + + # Collect per-window vectors for each MP + mp_data: Dict[str, Dict] = {} + + for wid in window_ids: + vecs = _load_mp_vectors_for_window(db_path, wid) + for mp_name, vec in vecs.items(): + if mp_name not in mp_data: + mp_data[mp_name] = {"windows": [], "vectors": []} + mp_data[mp_name]["windows"].append(wid) + mp_data[mp_name]["vectors"].append(vec) + + # Compute drift for MPs with >= 2 windows + result = {} + for mp_name, data in mp_data.items(): + if len(data["windows"]) < 2: + continue + vecs = data["vectors"] + drifts = [ + float(np.linalg.norm(vecs[i + 1] - vecs[i])) for i in range(len(vecs) - 1) + ] + result[mp_name] = { + "windows": data["windows"], + "vectors": [v.tolist() for v in vecs], + "drift": drifts, + "total_drift": float(sum(drifts)), + } + + _logger.info( + "Trajectories computed for %d MPs across %d windows", + len(result), + len(window_ids), + ) + return result + + +def top_drifters(trajectories: Dict[str, Dict], n: int = 10) -> List[Dict]: + """Return the top-n MPs by total drift, sorted descending. + + Each entry: {"mp_name": ..., "total_drift": ..., "windows": [...]} + """ + ranked = sorted( + trajectories.items(), key=lambda kv: kv[1]["total_drift"], reverse=True + ) + return [ + { + "mp_name": mp, + "total_drift": data["total_drift"], + "windows": data["windows"], + } + for mp, data in ranked[:n] + ] diff --git a/analysis/visualize.py b/analysis/visualize.py new file mode 100644 index 0000000..595624a --- /dev/null +++ b/analysis/visualize.py @@ -0,0 +1,163 @@ +"""visualize.py — Plotly interactive plots for parliamentary embeddings. + +Produces HTML files that load plotly.js from its CDN (not fully self-contained). 
+ +Functions: + plot_umap_scatter — 2D scatter of fused motion embeddings, coloured by cluster + plot_mp_trajectory — Line plot of MP drift across windows + plot_political_axis — Bar chart of MP scores on the ideological axis +""" + +import logging +from typing import Dict, List, Optional + +import numpy as np + +_logger = logging.getLogger(__name__) + + +def _require_plotly(): + try: + import plotly.graph_objects as go + import plotly.express as px + + return go, px + except ImportError: + raise ImportError("plotly is not installed. Install it with: uv add plotly") + + +def plot_umap_scatter( + motion_ids: List[int], + coords: List[List[float]], + labels: Optional[List[int]] = None, + window_id: Optional[str] = None, + output_path: str = "analysis_umap.html", +) -> str: + """Produce a 2D scatter plot of UMAP-reduced fused embeddings. + + Args: + motion_ids: Motion IDs (used as hover labels) + coords: List of [x, y] coordinates + labels: Optional cluster labels (integer per motion) + window_id: Window label for the plot title + output_path: Where to write the HTML (plotly.js is loaded from the CDN) + + Returns the output_path on success. + """ + go, px = _require_plotly() + + xs = [c[0] for c in coords] + ys = [c[1] for c in coords] + color = labels if labels is not None else [0] * len(motion_ids) + title = f"UMAP — fused motion embeddings" + (f" ({window_id})" if window_id else "") + + fig = px.scatter( + x=xs, + y=ys, + color=[str(c) for c in color], + hover_name=[str(mid) for mid in motion_ids], + title=title, + labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"}, + ) + fig.write_html(output_path, include_plotlyjs="cdn") + _logger.info("UMAP scatter written to %s", output_path) + return output_path + + +def plot_mp_trajectory( + trajectories: Dict[str, Dict], + mp_names: Optional[List[str]] = None, + output_path: str = "analysis_trajectory.html", +) -> str: + """Line plot of MP drift across time windows. 
+ + Args: + trajectories: Output of analysis.trajectory.compute_trajectories() + mp_names: Subset of MPs to plot (default: all) + output_path: Output HTML file path + + Returns the output_path on success. + """ + go, px = _require_plotly() + + if mp_names is None: + mp_names = list(trajectories.keys()) + + fig = go.Figure() + + for mp in mp_names: + if mp not in trajectories: + continue + data = trajectories[mp] + windows = data["windows"] + drifts_cumulative = [0.0] + list(np.cumsum(data["drift"])) + # Plot cumulative drift per window transition + x_labels = windows[: len(drifts_cumulative)] + fig.add_trace( + go.Scatter( + x=x_labels, + y=drifts_cumulative, + mode="lines+markers", + name=mp, + ) + ) + + fig.update_layout( + title="MP Political Drift Over Time (Cumulative)", + xaxis_title="Window", + yaxis_title="Cumulative Drift", + ) + fig.write_html(output_path, include_plotlyjs="cdn") + _logger.info("Trajectory plot written to %s", output_path) + return output_path + + +def plot_political_axis( + scores: Dict[str, float], + party_of: Optional[Dict[str, str]] = None, + window_id: Optional[str] = None, + n_top: int = 30, + output_path: str = "analysis_political_axis.html", +) -> str: + """Horizontal bar chart of MP scores on the ideological axis. + + Args: + scores: {mp_name: score} from political_axis module + party_of: Optional {mp_name: party} for colour-coding + window_id: Window label for the title + n_top: Show only the top/bottom n MPs by score + output_path: Output HTML path + + Returns the output_path on success. 
+ """ + go, px = _require_plotly() + + # Sort by score + sorted_items = sorted(scores.items(), key=lambda kv: kv[1]) + + # Take n_top from each end if list is large + if len(sorted_items) > 2 * n_top: + sorted_items = sorted_items[:n_top] + sorted_items[-n_top:] + + names = [item[0] for item in sorted_items] + vals = [item[1] for item in sorted_items] + colors = ( + [party_of.get(n, "Unknown") for n in names] + if party_of + else ["Unknown"] * len(names) + ) + + title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "") + + fig = px.bar( + x=vals, + y=names, + color=colors, + orientation="h", + title=title, + labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"}, + ) + fig.update_layout(yaxis={"categoryorder": "total ascending"}) + fig.write_html(output_path, include_plotlyjs="cdn") + _logger.info("Political axis chart written to %s", output_path) + return output_path diff --git a/api_client.py b/api_client.py index b7fb1d8..b34d213 100644 --- a/api_client.py +++ b/api_client.py @@ -92,7 +92,12 @@ class TweedeKamerAPI: # Group records by Besluit_Id (decision/motion) motion_groups = defaultdict( - lambda: {"votes": {}, "besluit_id": None, "latest_date": None} + lambda: { + "votes": {}, + "mp_vote_parties": {}, + "besluit_id": None, + "latest_date": None, + } ) for record in records: @@ -120,6 +125,14 @@ class TweedeKamerAPI: motion_groups[besluit_id]["votes"][party_name] = vote motion_groups[besluit_id]["besluit_id"] = besluit_id + # For individual MPs (ActorNaam contains comma), also capture their party + if "," in party_name: + actor_fractie = record.get("ActorFractie") + if actor_fractie: + motion_groups[besluit_id]["mp_vote_parties"][party_name] = ( + actor_fractie + ) + # Track the latest date for this motion if ( not motion_groups[besluit_id]["latest_date"] @@ -166,6 +179,7 @@ class TweedeKamerAPI: motion_details["title"], motion_details["description"] ), "voting_results": voting_results, + "mp_vote_parties": 
motion_data["mp_vote_parties"], "winning_margin": winning_margin, "url": f"https://www.tweedekamer.nl/kamerstukken/stemmingsuitslagen/{besluit_id}", "externe_identifier": motion_details.get("externe_identifier"), diff --git a/database.py b/database.py index b794425..0490411 100644 --- a/database.py +++ b/database.py @@ -190,6 +190,33 @@ class MotionDatabase: ) conn.close() + + # Also insert mp_vote rows for individual MPs if party data is available. + # This only runs for brand-new motions (existing motions are rejected above), + # so there is no risk of duplicates — no existence check needed here. + mp_vote_parties = motion_data.get("mp_vote_parties", {}) + voting_results_raw = motion_data.get("voting_results", {}) + if mp_vote_parties: + conn2 = duckdb.connect(self.db_path) + row = conn2.execute( + "SELECT id FROM motions WHERE url = ? LIMIT 1", + (motion_data["url"],), + ).fetchone() + conn2.close() + motion_id = row[0] if row else None + + if motion_id is not None: + motion_date = motion_data.get("date", "") + for mp_name, party in mp_vote_parties.items(): + vote = voting_results_raw.get(mp_name, "afwezig") + self.insert_mp_vote( + motion_id=motion_id, + mp_name=mp_name, + party=party, + vote=vote, + date=motion_date, + ) + return True except Exception as e: diff --git a/pipeline/run_pipeline.py b/pipeline/run_pipeline.py new file mode 100644 index 0000000..8f855d9 --- /dev/null +++ b/pipeline/run_pipeline.py @@ -0,0 +1,261 @@ +"""CLI orchestrator for the parliamentary embedding pipeline. + +Runs all phases in sequence: + 1. fetch_mp_metadata — pull MP party + tenure from OData + 2. extract_mp_votes — parse voting_results JSON → mp_votes rows + 3. svd per window — build vote matrix, SVD, Procrustes-align + 4. text embeddings — fill any gaps in the embeddings table + 5. 
fuse per window — concatenate SVD + text vectors → fused_embeddings + +Usage: + uv run python -m pipeline.run_pipeline [options] + +Options: + --db-path PATH Path to the DuckDB file (default: data/motions.db) + --start-date DATE Window start (YYYY-MM-DD, default: 2 years ago) + --end-date DATE Window end (YYYY-MM-DD, default: today) + --window-size {quarterly,annual} Time window granularity (default: quarterly) + --svd-k INT SVD dimensionality (default: 50) + --text-model TEXT Text embedding model name (default: from ai_provider) + --skip-metadata Skip fetching MP metadata from OData + --skip-extract Skip extracting MP votes from voting_results + --skip-svd Skip SVD computation + --skip-text Skip text embedding gap-fill + --skip-fusion Skip vector fusion + --dry-run Print actions but make no DB writes +""" + +import argparse +import calendar +import logging +import sys +from datetime import date, timedelta +from typing import List, Tuple + +from database import MotionDatabase + +_logger = logging.getLogger(__name__) + + +def _generate_windows( + start: date, end: date, granularity: str +) -> List[Tuple[str, str, str]]: + """Return list of (window_id, start_str, end_str) tuples. 
+ + window_id format: + quarterly → "2024-Q1", "2024-Q2", … + annual → "2024" + """ + windows = [] + cursor = date(start.year, start.month, 1) + + if granularity == "annual": + cursor = date(start.year, 1, 1) + while cursor <= end: + year_end = date(cursor.year, 12, 31) + w_end = min(year_end, end) + windows.append((str(cursor.year), cursor.isoformat(), w_end.isoformat())) + cursor = date(cursor.year + 1, 1, 1) + else: + # quarterly + quarter_starts = {1: 1, 2: 4, 3: 7, 4: 10} + quarter_ends = {1: 3, 2: 6, 3: 9, 4: 12} + + # Align cursor to quarter start + q = (cursor.month - 1) // 3 + 1 + cursor = date(cursor.year, quarter_starts[q], 1) + + while cursor <= end: + q = (cursor.month - 1) // 3 + 1 + q_end_month = quarter_ends[q] + last_day = calendar.monthrange(cursor.year, q_end_month)[1] + q_end = date(cursor.year, q_end_month, last_day) + w_end = min(q_end, end) + window_id = f"{cursor.year}-Q{q}" + windows.append((window_id, cursor.isoformat(), w_end.isoformat())) + cursor = q_end + timedelta(days=1) + + return windows + + +def run(args: argparse.Namespace) -> int: + """Execute the pipeline. 
Returns exit code (0 = success).""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + db_path = args.db_path + dry_run = args.dry_run + + if dry_run: + _logger.info("DRY RUN — no writes will be made") + + # Resolve date range + end_date = date.fromisoformat(args.end_date) if args.end_date else date.today() + start_date = ( + date.fromisoformat(args.start_date) + if args.start_date + else end_date - timedelta(days=730) + ) + + _logger.info( + "Pipeline run: %s → %s (%s windows), db=%s", + start_date, + end_date, + args.window_size, + db_path, + ) + + db = MotionDatabase(db_path) + + # ── Phase 1: MP metadata ──────────────────────────────────────────────── + if not args.skip_metadata: + _logger.info("Phase 1: fetching MP metadata from OData") + if not dry_run: + from pipeline.fetch_mp_metadata import fetch_mp_metadata + + fetched, skipped = fetch_mp_metadata(db) + _logger.info(" mp_metadata: fetched=%d skipped=%d", fetched, skipped) + else: + _logger.info(" [dry-run] would call fetch_mp_metadata(db)") + else: + _logger.info("Phase 1: skipped (--skip-metadata)") + + # ── Phase 2: Extract MP votes ──────────────────────────────────────────── + if not args.skip_extract: + _logger.info("Phase 2: extracting MP votes from voting_results") + if not dry_run: + from pipeline.extract_mp_votes import extract_mp_votes + + inserted, skipped = extract_mp_votes(db) + _logger.info( + " mp_votes: inserted=%d motions skipped=%d", inserted, skipped + ) + else: + _logger.info(" [dry-run] would call extract_mp_votes(db)") + else: + _logger.info("Phase 2: skipped (--skip-extract)") + + # ── Phase 3: SVD per window ────────────────────────────────────────────── + if not args.skip_svd: + windows = _generate_windows(start_date, end_date, args.window_size) + _logger.info("Phase 3: SVD for %d windows (k=%d)", len(windows), args.svd_k) + from pipeline.svd_pipeline import run_svd_for_window + + for window_id, w_start, w_end in 
windows: + _logger.info(" window %s: %s → %s", window_id, w_start, w_end) + if not dry_run: + result = run_svd_for_window( + db=db, + window_id=window_id, + start_date=w_start, + end_date=w_end, + k=args.svd_k, + ) + _logger.info( + " k_used=%d stored_mp=%d stored_motion=%d", + result["k_used"], + result["stored_mp"], + result["stored_motion"], + ) + else: + _logger.info(" [dry-run] would run SVD for window %s", window_id) + else: + _logger.info("Phase 3: skipped (--skip-svd)") + + # ── Phase 4: Text embeddings ────────────────────────────────────────────── + if not args.skip_text: + _logger.info("Phase 4: ensuring text embeddings") + if not dry_run: + from pipeline.text_pipeline import ensure_text_embeddings + + stored, existing, no_text, errors = ensure_text_embeddings( + db_path=db_path, model=args.text_model + ) + _logger.info( + " embeddings: stored=%d existing=%d no_text=%d errors=%d", + stored, + existing, + no_text, + errors, + ) + else: + _logger.info(" [dry-run] would call ensure_text_embeddings") + else: + _logger.info("Phase 4: skipped (--skip-text)") + + # ── Phase 5: Fusion per window ──────────────────────────────────────────── + if not args.skip_fusion: + windows = _generate_windows(start_date, end_date, args.window_size) + _logger.info("Phase 5: fusing vectors for %d windows", len(windows)) + from pipeline.fusion import fuse_for_window + + for window_id, _w_start, _w_end in windows: + if not dry_run: + result = fuse_for_window( + window_id=window_id, + db_path=db_path, + model=args.text_model, + ) + _logger.info( + " window %s: fused=%d skipped_no_svd=%d skipped_no_text=%d", + window_id, + result["fused"], + result.get("skipped_no_svd", 0), + result.get("skipped_no_text", 0), + ) + else: + _logger.info(" [dry-run] would fuse window %s", window_id) + else: + _logger.info("Phase 5: skipped (--skip-fusion)") + + _logger.info("Pipeline complete.") + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + 
description="Parliamentary embedding pipeline orchestrator", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--db-path", default="data/motions.db", help="Path to DuckDB file" + ) + parser.add_argument("--start-date", default=None, help="Window start YYYY-MM-DD") + parser.add_argument("--end-date", default=None, help="Window end YYYY-MM-DD") + parser.add_argument( + "--window-size", + choices=["quarterly", "annual"], + default="quarterly", + help="Time window granularity", + ) + parser.add_argument("--svd-k", type=int, default=50, help="SVD dimensions") + parser.add_argument( + "--text-model", + default=None, + help="Text embedding model (default: ai_provider default)", + ) + parser.add_argument( + "--skip-metadata", action="store_true", help="Skip MP metadata fetch" + ) + parser.add_argument( + "--skip-extract", action="store_true", help="Skip MP vote extraction" + ) + parser.add_argument("--skip-svd", action="store_true", help="Skip SVD computation") + parser.add_argument( + "--skip-text", action="store_true", help="Skip text embedding gap-fill" + ) + parser.add_argument("--skip-fusion", action="store_true", help="Skip vector fusion") + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would happen without writing anything", + ) + return parser + + +if __name__ == "__main__": + parser = build_parser() + args = parser.parse_args() + sys.exit(run(args)) diff --git a/pyproject.toml b/pyproject.toml index 1d58676..a9975b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,4 +15,5 @@ dependencies = [ "requests>=2.32.4", "schedule>=1.2.2", "streamlit>=1.48.0", + "scikit-learn>=1.8.0", ] diff --git a/tests/test_analysis.py b/tests/test_analysis.py new file mode 100644 index 0000000..12285aa --- /dev/null +++ b/tests/test_analysis.py @@ -0,0 +1,195 @@ +"""Tests for analysis modules: political_axis, trajectory, clustering.""" + +import json +import numpy as np +import pytest + +duckdb = 
pytest.importorskip("duckdb") + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _setup_svd_vectors(db_path: str, window_ids_mp_vecs: dict): + """Insert synthetic MP SVD vectors into svd_vectors table. + + window_ids_mp_vecs: {window_id: {mp_name: np.ndarray}} + """ + conn = duckdb.connect(db_path) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS svd_vectors ( + id INTEGER, + window_id TEXT, + entity_type TEXT, + entity_id TEXT, + vector JSON, + model TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + for wid, mp_vecs in window_ids_mp_vecs.items(): + for mp_name, vec in mp_vecs.items(): + conn.execute( + "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES (?, 'mp', ?, ?, 'test')", + (wid, mp_name, json.dumps(vec.tolist())), + ) + conn.close() + + +def _setup_mp_metadata(db_path: str, mp_party: dict): + """Insert synthetic MP metadata rows.""" + conn = duckdb.connect(db_path) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS mp_metadata ( + mp_name TEXT, + party TEXT, + van DATE, + tot_en_met DATE, + persoon_id TEXT + ) + """ + ) + for mp_name, party in mp_party.items(): + conn.execute( + "INSERT INTO mp_metadata (mp_name, party) VALUES (?, ?)", + (mp_name, party), + ) + conn.close() + + +# ── political_axis ──────────────────────────────────────────────────────────── + + +class TestPoliticalAxis: + def test_pca_axis_basic(self, tmp_path): + np.random.seed(42) + db_path = str(tmp_path / "test.db") + n_mps, k = 20, 5 + + # Create a low-rank set of MP vectors (they should have a clear first PC) + vecs = np.random.randn(n_mps, k) + mp_names = [f"MP_{i}" for i in range(n_mps)] + _setup_svd_vectors( + db_path, {"2024-Q1": {mp_names[i]: vecs[i] for i in range(n_mps)}} + ) + + from analysis.political_axis import compute_pca_axis + + scores = compute_pca_axis(db_path, "2024-Q1") + assert len(scores) == n_mps + assert all(isinstance(v, float) for v in scores.values()) + # Scores 
should have non-trivial variance + vals = list(scores.values()) + assert np.std(vals) > 0.0 + + def test_pca_axis_too_few_mps(self, tmp_path): + db_path = str(tmp_path / "test.db") + _setup_svd_vectors(db_path, {"w1": {"MP_A": np.array([1.0, 0.0])}}) + + from analysis.political_axis import compute_pca_axis + + scores = compute_pca_axis(db_path, "w1") + assert scores == {} + + def test_anchor_axis_basic(self, tmp_path): + db_path = str(tmp_path / "test.db") + # Two clusters clearly separated on dim 0 + left_vec = np.array([-2.0, 0.0, 0.0]) + right_vec = np.array([2.0, 0.0, 0.0]) + mp_vecs = { + "Left_A": left_vec + np.array([0.1, 0.0, 0.0]), + "Left_B": left_vec - np.array([0.1, 0.0, 0.0]), + "Right_A": right_vec + np.array([0.1, 0.0, 0.0]), + "Right_B": right_vec - np.array([0.1, 0.0, 0.0]), + "Centre": np.array([0.0, 0.0, 0.0]), + } + _setup_svd_vectors(db_path, {"w1": mp_vecs}) + _setup_mp_metadata( + db_path, + { + "Left_A": "SP", + "Left_B": "SP", + "Right_A": "VVD", + "Right_B": "VVD", + "Centre": "D66", + }, + ) + + from analysis.political_axis import compute_anchor_axis + + scores = compute_anchor_axis( + db_path, "w1", left_parties=["SP"], right_parties=["VVD"] + ) + assert len(scores) == 5 + # Left MPs should have negative scores, Right MPs positive + assert scores["Left_A"] < scores["Right_A"] + assert scores["Left_B"] < scores["Right_B"] + + +# ── trajectory ─────────────────────────────────────────────────────────────── + + +class TestTrajectory: + def test_basic_trajectory(self, tmp_path): + np.random.seed(0) + db_path = str(tmp_path / "test.db") + + vec_w1 = {"MP_A": np.array([1.0, 0.0]), "MP_B": np.array([0.0, 1.0])} + vec_w2 = { + "MP_A": np.array([1.5, 0.5]), + "MP_B": np.array([0.0, 1.0]), + "MP_C": np.array([2.0, 2.0]), + } + _setup_svd_vectors(db_path, {"2024-Q1": vec_w1, "2024-Q2": vec_w2}) + + from analysis.trajectory import compute_trajectories, top_drifters + + traj = compute_trajectories(db_path) + + # Only MPs appearing in >= 2 windows + 
assert "MP_A" in traj + assert "MP_B" in traj + assert "MP_C" not in traj # only in one window + + assert len(traj["MP_A"]["drift"]) == 1 + assert traj["MP_A"]["total_drift"] > 0.0 + + # MP_B didn't move — drift should be 0 + assert traj["MP_B"]["total_drift"] == pytest.approx(0.0) + + drifters = top_drifters(traj, n=5) + assert drifters[0]["mp_name"] == "MP_A" + + def test_fewer_than_2_windows(self, tmp_path): + db_path = str(tmp_path / "test.db") + _setup_svd_vectors(db_path, {"2024-Q1": {"MP_A": np.array([1.0, 2.0])}}) + + from analysis.trajectory import compute_trajectories + + traj = compute_trajectories(db_path) + assert traj == {} + + +# ── clustering ──────────────────────────────────────────────────────────────── + + +class TestClustering: + def test_cluster_kmeans_basic(self): + from analysis.clustering import cluster_kmeans + import numpy as np + + coords = np.random.randn(20, 2) + labels = cluster_kmeans(coords, n_clusters=3) + assert len(labels) == 20 + assert set(labels).issubset({0, 1, 2}) + + def test_cluster_kmeans_fewer_points_than_clusters(self): + from analysis.clustering import cluster_kmeans + + coords = np.array([[0.0, 0.0], [1.0, 1.0]]) + labels = cluster_kmeans(coords, n_clusters=5) + # Should not crash; n_clusters clamped to len(coords) + assert len(labels) == 2 diff --git a/tests/test_run_pipeline.py b/tests/test_run_pipeline.py new file mode 100644 index 0000000..e75de95 --- /dev/null +++ b/tests/test_run_pipeline.py @@ -0,0 +1,113 @@ +"""Tests for pipeline/run_pipeline.py""" + +import argparse +import sys +import pytest + +from pipeline.run_pipeline import _generate_windows, build_parser, run +from datetime import date + + +def test_generate_windows_quarterly(): + start = date(2024, 1, 1) + end = date(2024, 12, 31) + windows = _generate_windows(start, end, "quarterly") + + assert len(windows) == 4 + ids = [w[0] for w in windows] + assert ids == ["2024-Q1", "2024-Q2", "2024-Q3", "2024-Q4"] + + # Q1 bounds + assert windows[0][1] == 
def test_generate_windows_annual():
    """Annual windows cover each calendar year touched by the date range.

    The range 2022-06-01 .. 2024-03-31 yields one window per year, and the
    final window's end bound is clipped to the requested end date.
    """
    start = date(2022, 6, 1)
    end = date(2024, 3, 31)
    windows = _generate_windows(start, end, "annual")

    assert len(windows) == 3
    ids = [w[0] for w in windows]
    assert ids == ["2022", "2023", "2024"]

    # 2024 should end at end_date, not Dec 31
    assert windows[2][2] == "2024-03-31"


def test_generate_windows_mid_quarter_start():
    """Starting in the middle of Q2 should still produce a Q2 window.

    Only the window ids are checked here; the start/end bounds of the Q2
    window are not asserted.
    """
    start = date(2024, 5, 15)
    end = date(2024, 9, 30)
    windows = _generate_windows(start, end, "quarterly")

    ids = [w[0] for w in windows]
    assert "2024-Q2" in ids
    assert "2024-Q3" in ids


def test_build_parser_defaults():
    """Parsing an empty argv yields the documented CLI defaults."""
    parser = build_parser()
    args = parser.parse_args([])
    assert args.db_path == "data/motions.db"
    assert args.window_size == "quarterly"
    assert args.svd_k == 50
    assert args.dry_run is False


def _make_run_args(db_path, **overrides):
    """Build the ``argparse.Namespace`` passed to ``run`` in these tests.

    Every phase is enabled and ``dry_run`` is off unless a keyword override
    says otherwise, so each test states only what it changes.
    """
    fields = {
        "db_path": db_path,
        "start_date": "2024-01-01",
        "end_date": "2024-03-31",
        "window_size": "quarterly",
        "svd_k": 10,
        "text_model": None,
        "skip_metadata": False,
        "skip_extract": False,
        "skip_svd": False,
        "skip_text": False,
        "skip_fusion": False,
        "dry_run": False,
    }
    fields.update(overrides)
    return argparse.Namespace(**fields)


def test_run_dry_run(tmp_path):
    """Dry-run should return exit code 0.

    Only the exit code is asserted; that the DB is left untouched is not
    verified by this test.
    """
    db_path = str(tmp_path / "motions.db")

    # Create minimal DB so MotionDatabase initialises
    from database import MotionDatabase

    MotionDatabase(db_path)

    exit_code = run(_make_run_args(db_path, dry_run=True))
    assert exit_code == 0


def test_run_skip_all(tmp_path):
    """Skipping all phases should still return 0."""
    db_path = str(tmp_path / "motions.db")
    from database import MotionDatabase

    MotionDatabase(db_path)

    args = _make_run_args(
        db_path,
        skip_metadata=True,
        skip_extract=True,
        skip_svd=True,
        skip_text=True,
        skip_fusion=True,
    )

    exit_code = run(args)
    assert exit_code == 0