diff --git a/.gitignore b/.gitignore index 505a3b1..8630cd8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,12 @@ wheels/ # Virtual environments .venv + +# Database files (large binary, not suited for git) +data/*.db +data/*.bak +data/*.json + +# Generated output files +outputs/ +outputs_*/ diff --git a/ai_provider.py b/ai_provider.py index d9772a7..aaae765 100644 --- a/ai_provider.py +++ b/ai_provider.py @@ -110,30 +110,6 @@ def _post_with_retries( time.sleep(sleep) continue - # Treat 429 (rate limiting) as transient and respect Retry-After header when present - if status == 429: - if attempt == retries: - raise ProviderError(f"Provider returned HTTP {resp.status_code}") - retry_after = None - try: - # header may be present as int seconds or as string - retry_after = resp.headers.get("Retry-After") - except Exception: - retry_after = None - - if retry_after is not None: - try: - sleep = float(retry_after) - except Exception: - # fallback to exponential backoff if header unparsable - sleep = backoff * (2 ** (attempt - 1)) - else: - sleep = backoff * (2 ** (attempt - 1)) - - sleep = sleep + random.uniform(0, sleep * 0.1) - time.sleep(sleep) - continue - return resp # Should not reach here @@ -184,6 +160,70 @@ def get_embedding(text: str, model: str | None = None) -> list[float]: return [float(x) for x in embedding] +def get_embeddings_batch( + texts: list[str], model: str | None = None, batch_size: int = 50 +) -> list[list[float]]: + """Return embedding vectors for multiple texts using batched API calls. + + The OpenAI/OpenRouter /embeddings endpoint accepts an array of inputs. + This sends texts in chunks of `batch_size` and returns one embedding per input, + preserving order. Raises ProviderError on failure. + """ + if not texts: + return [] + + if model is None: + model = ( + os.environ.get("EMBEDDING_MODEL") + or os.environ.get("QWEN_EMBEDDING_MODEL") + or "qwen/qwen3-embedding-4b" + ) + + all_embeddings: list[list[float]] = [] + + for start in range(0, len(texts), batch_size): + chunk = texts[start : start + batch_size] + resp = _post_with_retries("/embeddings", json={"model": model, "input": chunk}) + + try: + data = resp.json() + except Exception as exc: + raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc + + try: + items = data["data"] + except Exception as exc: + # Check local fallback + fallback = os.environ.get( + "ALLOW_LOCAL_EMBED_FALLBACK", "false" + ).lower() in ("1", "true", "yes") + if fallback: + dim = int(os.environ.get("LOCAL_EMBED_DIM", "64")) + all_embeddings.extend(_local_embedding(t, dim=dim) for t in chunk) + continue + raise ProviderError( + f"Unexpected batch embedding response shape: {data}" + ) from exc + + # Sort by index to guarantee order (API spec says index field is present) + items_sorted = sorted(items, key=lambda x: x.get("index", 0)) + + if len(items_sorted) != len(chunk): + raise ProviderError( + f"Expected {len(chunk)} embeddings, got {len(items_sorted)}" + ) + + for item in items_sorted: + emb = item.get("embedding") + if not isinstance(emb, list): + raise ProviderError( + f"Embedding at index {item.get('index')} is not a list" + ) + all_embeddings.append([float(x) for x in emb]) + + return all_embeddings + + def _local_embedding(text: str, dim: int = 64) -> list[float]: """Deterministic local fallback embedding based on SHA256. 
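A minimal usage sketch of the new `get_embeddings_batch` helper (the texts are illustrative; provider credentials are assumed to come from the same environment variables this module already reads, and `ALLOW_LOCAL_EMBED_FALLBACK` controls the hash-based local fallback on malformed responses):

```python
from ai_provider import get_embeddings_batch

# Illustrative inputs: any list of strings works.
texts = ["Motie over stikstofbeleid", "Motie over woningbouw"]

# One POST per `batch_size` inputs; output order matches input order.
vecs = get_embeddings_batch(texts, batch_size=200)

assert len(vecs) == len(texts)
dim = len(vecs[0])  # embedding width depends on the model
```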
diff --git a/analysis/political_axis.py b/analysis/political_axis.py
index 1b46cbd..57995f2 100644
--- a/analysis/political_axis.py
+++ b/analysis/political_axis.py
@@ -161,6 +161,11 @@ def compute_2d_axes(
     to load and align windows so the returned coordinates are consistent
     across windows.
     """
+    # Import trajectory helper at runtime so tests can monkeypatch sys.modules
+    import importlib
+
+    _trajectory = importlib.import_module("analysis.trajectory")
+
     if window_ids is None:
         window_ids = _trajectory._load_window_ids(db_path)

@@ -238,6 +243,77 @@ def compute_2d_axes(
         "pca_residual_used": bool(pca_residual or evr1 > 0.85),
     }

+    # Ensure consistent left/right and progressive/conservative orientation
+    # by checking canonical party centroids and flipping axis signs if needed.
+    try:
+        right_parties = {"PVV", "VVD", "FVD", "BBB", "JA21"}
+        left_parties = {"SP", "PvdA", "GroenLinks", "GroenLinks-PvdA", "DENK"}
+        cons_parties = {"PVV", "VVD", "FVD", "CDA", "SGP", "BBB", "JA21"}
+        prog_parties = {
+            "GroenLinks",
+            "PvdA",
+            "PvdD",
+            "SP",
+            "GroenLinks-PvdA",
+            "DENK",
+        }
+
+        # Build mapping of entity -> vector from stacked matrix M
+        ent_to_vec = {ent: vec for (wid, ent), vec in zip(entity_index, M)}
+
+        def _centroid_for_party_set(party_set):
+            vecs = []
+            for p in party_set:
+                if p in ent_to_vec:
+                    vecs.append(ent_to_vec[p])
+            try:
+                conn = duckdb.connect(db_path)
+                rows = conn.execute(
+                    "SELECT mp_name, party FROM mp_metadata"
+                ).fetchall()
+                conn.close()
+            except Exception:
+                rows = []
+            for mp_name, party in rows:
+                if party in party_set and mp_name in ent_to_vec:
+                    vecs.append(ent_to_vec[mp_name])
+            if not vecs:
+                return None
+            return np.mean(np.vstack(vecs), axis=0)
+
+        # X-axis: left vs right
+        left_cent = _centroid_for_party_set(left_parties)
+        right_cent = _centroid_for_party_set(right_parties)
+        if left_cent is not None and right_cent is not None:
+            left_proj = float(np.dot(left_cent - M.mean(axis=0), comp1_hat))
+            right_proj = float(np.dot(right_cent - M.mean(axis=0), comp1_hat))
+            if right_proj < left_proj:
+                _logger.info(
+                    "Flipping PCA x-axis to match canonical left/right orientation (right_proj=%.3f left_proj=%.3f)",
+                    right_proj,
+                    left_proj,
+                )
+                axes["x_axis"] = -axes["x_axis"]
+
+        # Y-axis: progressive vs conservative — prefer positive = progressive
+        prog_cent = _centroid_for_party_set(prog_parties)
+        cons_cent = _centroid_for_party_set(cons_parties)
+        if prog_cent is not None and cons_cent is not None:
+            prog_proj = float(np.dot(prog_cent - M.mean(axis=0), comp2_hat))
+            cons_proj = float(np.dot(cons_cent - M.mean(axis=0), comp2_hat))
+            # We want positive Y to mean 'progressive'. If the progressive
+            # centroid currently projects lower than the conservative centroid,
+            # flip the sign so progressive > conservative.
+            if prog_proj < cons_proj:
+                _logger.info(
+                    "Flipping PCA y-axis so positive Y corresponds to progressive (prog_proj=%.3f cons_proj=%.3f)",
+                    prog_proj,
+                    cons_proj,
+                )
+                axes["y_axis"] = -axes["y_axis"]
+    except Exception:
+        _logger.debug("Could not auto-orient PCA axes; leaving signs as-is")
+
     # warn if PCA is effectively 1-D
     if evr1 > 0.85 and not pca_residual:
         _logger.warning(
diff --git a/analysis/visualize.py b/analysis/visualize.py
index 538ceb1..edf4ad9 100644
--- a/analysis/visualize.py
+++ b/analysis/visualize.py
@@ -27,6 +27,58 @@ def _require_plotly():
         raise ImportError("plotly is not installed.
Install it with: uv add plotly") +def _load_party_map(db_path: str = "data/motions.db") -> Dict[str, str]: + """Build a party mapping mp_name -> party. + + Prefers mp_metadata where available; otherwise uses majority-party from mp_votes. + Returns a dict of mp_name -> party (strings). + """ + try: + import duckdb + except Exception: + _logger.debug("duckdb not available when building party map") + return {} + + conn = duckdb.connect(db_path) + try: + # metadata-based mapping + rows = conn.execute( + "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL" + ).fetchall() + meta_map = {r[0]: r[1] for r in rows} + + # majority-party heuristic from mp_votes + rows = conn.execute( + """ + SELECT mp_name, party, COUNT(*) as n + FROM mp_votes + WHERE party IS NOT NULL + GROUP BY mp_name, party + """ + ).fetchall() + counts: Dict[str, List[tuple]] = {} + for mp_name, party, n in rows: + counts.setdefault(mp_name, []).append((party, n)) + maj_map: Dict[str, str] = {} + for mp_name, arr in counts.items(): + maj_map[mp_name] = max(arr, key=lambda x: x[1])[0] + + merged = dict(maj_map) + # prefer metadata mapping when available + merged.update(meta_map) + _logger.info( + "Built party map: %d from mp_votes majority, %d from mp_metadata", + len(maj_map), + len(meta_map), + ) + return merged + finally: + try: + conn.close() + except Exception: + pass + + def plot_umap_scatter( motion_ids: List[int], coords: List[List[float]], @@ -194,6 +246,7 @@ def plot_political_compass( try: import duckdb # type: ignore + conn = None try: conn = duckdb.connect(database="data/motions.db", read_only=True) df = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchdf() @@ -206,10 +259,11 @@ def plot_political_compass( len(party_of), ) finally: - try: - conn.close() - except Exception: - pass + if conn is not None: + try: + conn.close() + except Exception: + pass except ImportError: _logger.debug("duckdb not installed; proceeding without party mapping") except Exception as e: @@ -221,8 +275,18 @@ def plot_political_compass( scaled_ys = ys if axis_def and y_scale is None: evr = axis_def.get("explained_variance_ratio") if axis_def else None - if evr and isinstance(evr, (list, tuple)) and len(evr) >= 2: - evr1, evr2 = evr[0], evr[1] + # Accept lists/tuples or numpy arrays; avoid ambiguous truth checks + evr_list = None + if evr is not None: + try: + evr_list = list(evr) + except Exception: + try: + evr_list = [float(evr)] + except Exception: + evr_list = None + if evr_list is not None and len(evr_list) >= 2: + evr1, evr2 = float(evr_list[0]), float(evr_list[1]) if evr2 < 1e-6: scale_guess = 1.0 else: @@ -237,30 +301,42 @@ def plot_political_compass( elif axis_def and y_scale is not None: scaled_ys = [y * float(y_scale) for y in ys] - # mark unknowns differently - unknown_flags = [1 if parties[i] == "Unknown" else 0 for i in range(len(names))] + # mark unknowns differently: use descriptive labels so the legend doesn't + # show numeric symbol values like "PVV, 0" when color and symbol combine. 
+ unknown_labels = [ + "Unknown" if parties[i] == "Unknown" else "Known" for i in range(len(names)) + ] fig = px.scatter( x=xs, y=scaled_ys, color=parties, - symbol=unknown_flags, + symbol=unknown_labels, hover_name=names, title=f"Political Compass ({window_id})", labels={ "x": "Left ← — → Right", "y": "Progressive ← — → Conservative", "color": "Party", - "symbol": "Unknown", + "symbol": "Known?", }, ) fig.update_traces(marker=dict(size=8, opacity=0.85)) # annotate explained variance if available if axis_def and axis_def.get("method") == "pca": evr = axis_def.get("explained_variance_ratio") - if evr and len(evr) >= 2: + evr_list = None + if evr is not None: + try: + evr_list = list(evr) + except Exception: + try: + evr_list = [float(evr)] + except Exception: + evr_list = None + if evr_list is not None and len(evr_list) >= 2: fig.update_layout( - title=f"Political Compass ({window_id}) — PCA EVR PC1={evr[0] * 100:.1f}%, PC2={evr[1] * 100:.1f}%" + title=f"Political Compass ({window_id}) — PCA EVR PC1={evr_list[0] * 100:.1f}%, PC2={evr_list[1] * 100:.1f}%" ) fig.write_html(output_path, include_plotlyjs="cdn") _logger.info("Political compass written to %s", output_path) @@ -309,6 +385,45 @@ def plot_2d_trajectories( ) ) + # Add an arrow indicating the final direction (only one arrow per MP to + # avoid clutter). Use an annotation with an arrowhead from the penultimate + # to the last point and label the endpoint with the MP name. + try: + if len(xs) >= 2: + x0, y0 = xs[-2], ys[-2] + x1, y1 = xs[-1], ys[-1] + # small style choices — subtle arrow and a short label + fig.add_annotation( + x=x1, + y=y1, + ax=x0, + ay=y0, + xref="x", + yref="y", + axref="x", + ayref="y", + showarrow=True, + arrowhead=3, + arrowsize=1.0, + arrowwidth=1.2, + arrowcolor="rgba(0,0,0,0.6)", + opacity=0.8, + ) + # endpoint label slightly offset to reduce overlap with marker + fig.add_annotation( + x=x1, + y=y1, + xref="x", + yref="y", + text=mp, + showarrow=False, + xanchor="left", + yanchor="bottom", + font=dict(size=10, color="rgba(0,0,0,0.8)"), + ) + except Exception: + _logger.exception("Failed to add arrow/label for MP %s", mp) + fig.update_layout( title="MP Trajectories on Political Compass", xaxis_title="Left ← — → Right", diff --git a/api_client.py b/api_client.py index 3c245ae..9abc96b 100644 --- a/api_client.py +++ b/api_client.py @@ -178,7 +178,8 @@ class TweedeKamerAPI: # Extract party and vote information party_name = record.get("ActorNaam") - vote_type = record.get("Soort", "").lower() + # Some records have Soort explicitly set to None; guard against that + vote_type = str(record.get("Soort") or "").lower() record_date = record.get("GewijzigdOp", "") if not party_name: diff --git a/data/motions.db b/data/motions.db deleted file mode 100644 index 927a418..0000000 Binary files a/data/motions.db and /dev/null differ diff --git a/database.py b/database.py index 222c35f..5ff5ae4 100644 --- a/database.py +++ b/database.py @@ -464,18 +464,18 @@ class MotionDatabase: """Store an embedding for a motion. 
Returns inserted row id or -1 on failure.""" try: conn = duckdb.connect(self.db_path) - # store vector as JSON + # Use explicit nextval for id since older tables may lack DEFAULT conn.execute( - "INSERT INTO embeddings (motion_id, model, vector, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)", + "INSERT INTO embeddings (id, motion_id, model, vector, created_at) VALUES (nextval('embeddings_id_seq'), ?, ?, ?, CURRENT_TIMESTAMP)", (motion_id, model, json.dumps(vector)), ) - row = conn.execute("SELECT max(id) FROM embeddings").fetchone() + row = conn.execute("SELECT currval('embeddings_id_seq')").fetchone() conn.close() if row and row[0] is not None: return int(row[0]) return -1 except Exception as e: - print(f"Error storing embedding: {e}") + _logger.error("Error storing embedding: %s", e) try: conn.close() except Exception: @@ -685,6 +685,11 @@ class MotionDatabase: ) -> int: try: conn = duckdb.connect(self.db_path) + # Delete any existing row for this (motion_id, window_id) to prevent duplicates + conn.execute( + "DELETE FROM fused_embeddings WHERE motion_id = ? AND window_id = ?", + (motion_id, window_id), + ) conn.execute( """ INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims, created_at) diff --git a/outputs/anchor_axis_2025_Q2.html b/outputs/anchor_axis_2025_Q2.html deleted file mode 100644 index 4c88290..0000000 --- a/outputs/anchor_axis_2025_Q2.html +++ /dev/null @@ -1,7 +0,0 @@ - - - -
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/anchor_axis_2025_Q3.html b/outputs/anchor_axis_2025_Q3.html
deleted file mode 100644
index c8c7378..0000000
--- a/outputs/anchor_axis_2025_Q3.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/anchor_axis_2025_Q4.html b/outputs/anchor_axis_2025_Q4.html
deleted file mode 100644
index ad1c776..0000000
--- a/outputs/anchor_axis_2025_Q4.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/anchor_axis_2026_Q1.html b/outputs/anchor_axis_2026_Q1.html
deleted file mode 100644
index 8075e46..0000000
--- a/outputs/anchor_axis_2026_Q1.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/political_axis_2025_Q1.html b/outputs/political_axis_2025_Q1.html
deleted file mode 100644
index a082342..0000000
--- a/outputs/political_axis_2025_Q1.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/political_axis_2025_Q2.html b/outputs/political_axis_2025_Q2.html
deleted file mode 100644
index 4e62e9b..0000000
--- a/outputs/political_axis_2025_Q2.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/political_axis_2025_Q3.html b/outputs/political_axis_2025_Q3.html
deleted file mode 100644
index 682f724..0000000
--- a/outputs/political_axis_2025_Q3.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/political_axis_2025_Q4.html b/outputs/political_axis_2025_Q4.html
deleted file mode 100644
index b77c0a9..0000000
--- a/outputs/political_axis_2025_Q4.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/political_axis_2026_Q1.html b/outputs/political_axis_2026_Q1.html
deleted file mode 100644
index 5a07829..0000000
--- a/outputs/political_axis_2026_Q1.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/trajectories_normalized_top15.html b/outputs/trajectories_normalized_top15.html
deleted file mode 100644
index 2124e2e..0000000
--- a/outputs/trajectories_normalized_top15.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/trajectories_party_aligned.html b/outputs/trajectories_party_aligned.html
deleted file mode 100644
index c659e10..0000000
--- a/outputs/trajectories_party_aligned.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
\ No newline at end of file
diff --git a/outputs/trajectories_top15.html b/outputs/trajectories_top15.html
deleted file mode 100644
index 1aacf5d..0000000
--- a/outputs/trajectories_top15.html
+++ /dev/null
@@ -1,7 +0,0 @@
-[stripped plotly HTML]
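The `store_fused_embedding` fix in database.py above is a delete-then-insert idempotency pattern; a condensed sketch of the same idea, assuming the column list shown in the diff and a caller-managed DuckDB connection (the `upsert_fused` name is hypothetical, for illustration only):

```python
import json

import duckdb


def upsert_fused(
    conn: duckdb.DuckDBPyConnection,
    motion_id: int,
    window_id: str,
    vector: list[float],
    svd_dims: int,
    text_dims: int,
) -> None:
    # The DELETE guarantees at most one row per (motion_id, window_id),
    # so re-running the fusion step cannot double-count rows even though
    # the table has no unique constraint.
    conn.execute(
        "DELETE FROM fused_embeddings WHERE motion_id = ? AND window_id = ?",
        (motion_id, window_id),
    )
    conn.execute(
        "INSERT INTO fused_embeddings "
        "(motion_id, window_id, vector, svd_dims, text_dims, created_at) "
        "VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)",
        (motion_id, window_id, json.dumps(vector), svd_dims, text_dims),
    )
```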
- - \ No newline at end of file diff --git a/pipeline/run_pipeline.py b/pipeline/run_pipeline.py index 0fb306d..c16c0ce 100644 --- a/pipeline/run_pipeline.py +++ b/pipeline/run_pipeline.py @@ -174,7 +174,7 @@ def run(args: argparse.Namespace) -> int: from pipeline.text_pipeline import ensure_text_embeddings stored, existing, no_text, errors = ensure_text_embeddings( - db_path=db_path, model=args.text_model + db_path=db_path, model=args.text_model, batch_size=args.text_batch_size ) _logger.info( " embeddings: stored=%d existing=%d no_text=%d errors=%d", @@ -240,6 +240,12 @@ def build_parser() -> argparse.ArgumentParser: default=None, help="Text embedding model (default: ai_provider default)", ) + parser.add_argument( + "--text-batch-size", + type=int, + default=200, + help="Number of texts per embedding API call (default: 200)", + ) parser.add_argument( "--skip-metadata", action="store_true", help="Skip MP metadata fetch" ) diff --git a/pipeline/text_pipeline.py b/pipeline/text_pipeline.py index c07d1d7..77ff128 100644 --- a/pipeline/text_pipeline.py +++ b/pipeline/text_pipeline.py @@ -55,10 +55,11 @@ def _select_text( def ensure_text_embeddings( - db_path: Optional[str] = None, model: Optional[str] = None + db_path: Optional[str] = None, model: Optional[str] = None, batch_size: int = 50 ) -> Tuple[int, int, int, int]: """Ensure all motions have text embeddings for `model`. + Uses batched API calls (batch_size texts per HTTP request) for speed. Returns tuple (stored_count, skipped_existing, skipped_no_text, errors). """ model = model or DEFAULT_MODEL @@ -87,14 +88,54 @@ def ensure_text_embeddings( skipped_no_text = 0 errors = 0 + # Separate motions with text from those without + with_text: List[Tuple[int, str]] = [] for motion_id, text in to_process: if not text: _logger.info("Skipping motion %s: no text available", motion_id) skipped_no_text += 1 - continue + else: + with_text.append((motion_id, text)) + + _logger.info( + "Processing %d motions in batches of %d (%d skipped no text, %d already exist)", + len(with_text), + batch_size, + skipped_no_text, + existing, + ) + + # Process in batches + for batch_start in range(0, len(with_text), batch_size): + batch = with_text[batch_start : batch_start + batch_size] + batch_ids = [mid for mid, _ in batch] + batch_texts = [txt for _, txt in batch] try: - vec = ai_provider.get_embedding(text, model=model) + vecs = ai_provider.get_embeddings_batch( + batch_texts, model=model, batch_size=batch_size + ) + except Exception as exc: + _logger.error( + "Batch embedding failed for motions %s..%s: %s", + batch_ids[0], + batch_ids[-1], + exc, + ) + errors += len(batch) + continue + + if len(vecs) != len(batch): + _logger.error( + "Batch size mismatch: expected %d, got %d embeddings", + len(batch), + len(vecs), + ) + errors += len(batch) + continue + + batch_stored = 0 + for (motion_id, _text), vec in zip(batch, vecs): if not isinstance(vec, list): _logger.warning( "Embedding provider returned non-list for motion %s", motion_id @@ -102,21 +143,33 @@ def ensure_text_embeddings( errors += 1 continue - res = db.store_embedding(motion_id, model, vec) - if res and res > 0: - stored += 1 - else: + try: + res = db.store_embedding(motion_id, model, vec) + if res and res > 0: + stored += 1 + batch_stored += 1 + else: + _logger.error( + "Failed to store embedding for motion %s (store returned %s)", + motion_id, + res, + ) + errors += 1 + except Exception as exc: _logger.error( - "Failed to store embedding for motion %s (store returned %s)", - motion_id, - res, + "Error 
storing embedding for motion %s: %s", motion_id, exc ) errors += 1 - except Exception as exc: - _logger.error( - "Error computing/storing embedding for motion %s: %s", motion_id, exc - ) - errors += 1 + + _logger.info( + "Batch %d-%d: stored %d/%d (total: %d/%d)", + batch_start, + batch_start + len(batch), + batch_stored, + len(batch), + stored + existing, + total_motions, + ) skipped_existing = int(existing) return stored, skipped_existing, skipped_no_text, errors diff --git a/scripts/compare_svd_exclude_parties.py b/scripts/compare_svd_exclude_parties.py new file mode 100644 index 0000000..acbb8b8 --- /dev/null +++ b/scripts/compare_svd_exclude_parties.py @@ -0,0 +1,204 @@ +"""Compare PCA axes with and without party-level vectors present. + +Generates diagnostics and HTML plots (when plotly available) into outputs/. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from typing import Dict, List + +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger("compare_svd_exclude_parties") + + +def main(argv: List[str] | None = None): + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db") + p.add_argument("--out", default="outputs") + args = p.parse_args(argv) + + os.makedirs(args.out, exist_ok=True) + + try: + from analysis import trajectory as traj + from analysis.visualize import ( + _load_party_map, + plot_political_compass, + plot_2d_trajectories, + ) + import numpy as np + except Exception as e: + logger.exception("Failed to import analysis modules: %s", e) + raise + + window_ids = traj._load_window_ids(args.db) + if not window_ids: + logger.error("No SVD windows found") + return 1 + latest = sorted(window_ids)[-1] + + # load raw vectors for latest window + conn = None + try: + # build party name set from mp_metadata + import duckdb + + conn = duckdb.connect(args.db) + rows = conn.execute( + "SELECT DISTINCT party FROM mp_metadata WHERE party IS NOT NULL" + ).fetchall() + party_names = set(r[0] for r in rows if r[0]) + finally: + if conn: + try: + conn.close() + except Exception: + pass + + raw = traj._load_mp_vectors_for_window(args.db, latest) + # group by vector JSON-like key + groups: Dict[str, List[str]] = {} + for ent, vec in raw.items(): + key = tuple([round(float(x), 8) for x in vec.tolist()]) + groups.setdefault(str(key), []).append(ent) + + group_list = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True) + + top_groups = [(len(v), v[:8]) for k, v in group_list[:20]] + logger.info("Top duplicate groups (count, sample entities): %s", top_groups) + + # entities that are party names + party_entities = [ent for ent in raw.keys() if ent in party_names] + logger.info( + "Found %d party-like entities in svd_vectors for %s", + len(party_entities), + latest, + ) + + # Build aligned windows excluding party-level entities + raw_window_vecs = { + wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids + } + # create filtered copy that removes party-level entity ids + filtered_window_vecs = { + wid: {ent: vec for ent, vec in d.items() if ent not in party_names} + for wid, d in raw_window_vecs.items() + } + + aligned_filtered = traj._procrustes_align_windows(filtered_window_vecs) + # stack and compute PCA + all_vecs = [] + entity_index = [] + for wid, d in aligned_filtered.items(): + for ent, 
v in d.items(): + n = np.linalg.norm(v) + all_vecs.append(v / n if n > 1e-10 else v) + entity_index.append((wid, ent)) + + if not all_vecs: + logger.error("No vectors left after excluding parties — aborting") + return 2 + + M = np.vstack(all_vecs) + Mc = M - M.mean(axis=0) + try: + U, s, Vt = np.linalg.svd(Mc, full_matrices=False) + except Exception: + logger.exception("SVD failed on filtered data") + return 3 + + sv2 = s**2 + evr = sv2 / (sv2.sum() + 1e-20) + logger.info("Filtered PCA EVR top2: %s", evr[:2].tolist()) + + comp1 = Vt[0] + comp1_hat = comp1 / (np.linalg.norm(comp1) + 1e-12) + comp2 = Vt[1] if Vt.shape[0] > 1 else np.zeros_like(comp1) + comp2_hat = comp2 / (np.linalg.norm(comp2) + 1e-12) + + # project filtered entities for latest window + filtered_positions = {} + global_mean = M.mean(axis=0) + for (wid, ent), vec in zip(entity_index, M): + if wid != latest: + continue + v_centered = vec - global_mean + x = float(np.dot(v_centered, comp1_hat)) + y = float(np.dot(v_centered, comp2_hat)) + filtered_positions[ent] = (x, y) + + # save JSON and small report + out_json = os.path.join(args.out, "svd_filtered_positions.json") + with open(out_json, "w", encoding="utf-8") as f: + json.dump( + { + "latest": latest, + "positions": filtered_positions, + "evr": evr[:2].tolist(), + }, + f, + indent=2, + ) + logger.info("Wrote filtered positions to %s", out_json) + + # Also generate plots if plotly available + try: + party_map = _load_party_map(args.db) + # positions_by_window format expected by plot functions — include only latest + positions_by_window = {latest: filtered_positions} + pcomp_out = os.path.join(args.out, f"political_compass_filtered_{latest}.html") + plot_political_compass( + positions_by_window, + window_id=latest, + party_of=party_map, + axis_def={"method": "pca", "explained_variance_ratio": evr[:2]}, + output_path=pcomp_out, + ) + logger.info("Wrote filtered compass to %s", pcomp_out) + # simple trajectory plotting for filtered set — top movers by count + traj_out = os.path.join(args.out, f"trajectories_filtered_{latest}.html") + # Build simple per-MP coords across windows for filtered set + mp_coords = {} + for wid in window_ids: + for ent, coord in aligned_filtered.get(wid, {}).items(): + if ent not in mp_coords: + mp_coords[ent] = [] + mp_coords[ent].append((wid, tuple(coord.tolist()))) + # pick MPs with at least 2 windows + names = [n for n, v in mp_coords.items() if len(v) >= 2] + plot_2d_trajectories( + { + wid: { + n: mp_coords[n][i][1] + for n in names + for i, (w, _) in enumerate(mp_coords[n]) + if w == wid + } + for wid in window_ids + }, + mp_names=names[:50], + output_path=traj_out, + ) + logger.info("Wrote filtered trajectories to %s", traj_out) + except Exception: + logger.exception("Plotting filtered results failed — plots skipped") + + # console summary + print("Top duplicate groups (count, sample):") + for k, v in group_list[:20]: + print(len(v), v[:6]) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/download_past_year.py b/scripts/download_past_year.py index 03aa419..5206f6c 100644 --- a/scripts/download_past_year.py +++ b/scripts/download_past_year.py @@ -1,12 +1,13 @@ -"""download_past_year.py — One-shot data download: past year of parliamentary motions. +"""download_past_year.py — One-shot data download: parliamentary motions for a date range. 
-Fetches Stemming records from the OData API in quarterly chunks (90-day windows), +Fetches Stemming records from the OData API in chunks (default 90-day windows), stores motions into data/motions.db using MotionDatabase.insert_motion(). Skips AI summarisation — this is a raw data fetch for the embedding pipeline. Usage: uv run python scripts/download_past_year.py [--db-path data/motions.db] [--days 365] + uv run python scripts/download_past_year.py --start-date 2019-01-01 --end-date 2022-01-01 """ import argparse @@ -21,10 +22,25 @@ from database import MotionDatabase def main(): - parser = argparse.ArgumentParser(description="Download past year of motions") + parser = argparse.ArgumentParser(description="Download motions for a date range") parser.add_argument("--db-path", default="data/motions.db") parser.add_argument( - "--days", type=int, default=365, help="How many days back to fetch" + "--days", + type=int, + default=365, + help="How many days back to fetch (ignored if --start-date given)", + ) + parser.add_argument( + "--start-date", + type=str, + default=None, + help="Explicit start date YYYY-MM-DD (overrides --days)", + ) + parser.add_argument( + "--end-date", + type=str, + default=None, + help="Explicit end date YYYY-MM-DD (default: today)", ) parser.add_argument("--chunk-days", type=int, default=90, help="Days per API chunk") parser.add_argument( @@ -41,8 +57,15 @@ def main(): api = TweedeKamerAPI() db = MotionDatabase(args.db_path) - end_date = datetime.now() - start_date = end_date - timedelta(days=args.days) + end_date = ( + datetime.strptime(args.end_date, "%Y-%m-%d") + if args.end_date + else datetime.now() + ) + if args.start_date: + start_date = datetime.strptime(args.start_date, "%Y-%m-%d") + else: + start_date = end_date - timedelta(days=args.days) print( f"Downloading motions from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}" diff --git a/scripts/fill_mp_votes_parties.py b/scripts/fill_mp_votes_parties.py new file mode 100644 index 0000000..4164435 --- /dev/null +++ b/scripts/fill_mp_votes_parties.py @@ -0,0 +1,277 @@ +"""Backfill missing mp_votes.party values from mp_metadata and co-voting inference. + +Multi-tier strategy: + 1) Tussenvoegsel-aware name match against mp_metadata. + 2) Majority party already recorded in mp_votes for the same MP. + 3) Looser last-name-token match against mp_metadata. + 4) Co-voting inference: for MPs still unresolved, find which party's MPs + they vote identically with most often, using a Jaccard-style overlap. + +Usage: + uv run python3 scripts/fill_mp_votes_parties.py --db data/motions.db +""" + +from __future__ import annotations + +import argparse +import logging +import re +import unicodedata +from collections import defaultdict +from datetime import datetime + +import duckdb + +logger = logging.getLogger("fill_mp_votes_parties") + + +_TUSSENVOEGSEL = { + "van de", + "van den", + "van der", + "van het", + "van", + "de", + "den", + "der", + "het", + "ter", + "ten", + "el", + "al", + "in 't", +} + +# Build a regex that matches any known tussenvoegsel (longest first to avoid +# partial matches like "van" eating the "van" in "van der"). +_TV_PATTERN = re.compile( + r"\b(" + + "|".join(re.escape(tv) for tv in sorted(_TUSSENVOEGSEL, key=len, reverse=True)) + + r")\b", + re.IGNORECASE, +) + + +def normalize_mp_key(name: str) -> str: + """Produce a canonical key that matches regardless of tussenvoegsel position. + + Both "Burg van der, E." (mp_votes style) and "Van der Burg, E." 
+ (mp_metadata style) should produce the same key. Also strips diacritics + so "Kostić, I." matches "Kostic, I.". + + Strategy: split into pre-comma and post-comma parts. From the pre-comma + part, extract any tussenvoegsel tokens and the remaining lastname. + Canonical key = "lastname tussenvoegsel initials", all lowercased. + """ + if not name: + return "" + # Strip diacritics: NFD decompose then drop combining marks + s = unicodedata.normalize("NFD", name) + s = "".join(c for c in s if unicodedata.category(c) != "Mn") + # remove parenthetical fullnames e.g. "(Christine)" + s = re.sub(r"\s*\(.*?\)", "", s).strip() + # remove dots and commas for splitting but keep the comma position + # Split on first comma: last_part, initials_part + parts = s.split(",", 1) + last_part = parts[0].strip() + initials_part = parts[1].strip() if len(parts) > 1 else "" + + # Clean initials: remove dots + initials = re.sub(r"\.", "", initials_part).strip().lower() + + # From last_part, extract tussenvoegsel and lastname + last_lower = last_part.lower() + # Find all tussenvoegsel matches + found_tv = [] + remaining = last_lower + for m in _TV_PATTERN.finditer(last_lower): + found_tv.append(m.group(0).lower()) + # Remove tussenvoegsel tokens from remaining to get the pure lastname + remaining = _TV_PATTERN.sub("", last_lower).strip() + remaining = re.sub(r"\s+", " ", remaining).strip() + + # Sort tussenvoegsel to canonical order + tv_str = " ".join(sorted(found_tv)) if found_tv else "" + + # Build canonical key: "lastname tv initials" + key_parts = [remaining] + if tv_str: + key_parts.append(tv_str) + if initials: + key_parts.append(initials) + return " ".join(key_parts) + + +def pick_preferred_party(records: list) -> str | None: + # records: list of dicts with keys party, van, tot + # prefer active membership + for r in records: + if r.get("tot") is None and r.get("party"): + return r.get("party") + # otherwise pick most recent van + best = None + best_date = None + for r in records: + van = r.get("van") + try: + d = datetime.fromisoformat(van).date() if van else None + except Exception: + d = None + if d and (best_date is None or d > best_date): + best_date = d + best = r + if best: + return best.get("party") + # fallback to any party present + for r in records: + if r.get("party"): + return r.get("party") + return None + + +def _infer_party_by_covoting(conn, mp_name: str, min_overlap: int = 10) -> str | None: + """Infer party by finding which known-party MPs vote identically most often. + + For each motion where *mp_name* voted, find all other MPs who cast the + same vote AND already have a party assigned. The party with the highest + agreement count wins, provided the overlap exceeds *min_overlap*. + """ + rows = conn.execute( + """ + SELECT other.party, COUNT(*) AS agreement + FROM mp_votes me + JOIN mp_votes other + ON me.motion_id = other.motion_id + AND me.vote = other.vote + WHERE me.mp_name = ? + AND other.mp_name != ? 
+ AND other.party IS NOT NULL + AND other.party != '' + AND other.mp_name LIKE '%,%' + GROUP BY other.party + ORDER BY agreement DESC + LIMIT 5 + """, + (mp_name, mp_name), + ).fetchall() + if not rows: + return None + + best_party, best_count = rows[0] + if best_count < min_overlap: + return None + + # Require meaningful margin over second-best to avoid ambiguous assignment + if len(rows) > 1: + second_count = rows[1][1] + # Best must have at least 20% more agreement than runner-up + if best_count < second_count * 1.2: + logger.debug( + "Co-voting ambiguous for %s: %s=%d vs %s=%d", + mp_name, + best_party, + best_count, + rows[1][0], + second_count, + ) + return None + + logger.info( + "Co-voting inferred %s -> %s (agreement=%d)", + mp_name, + best_party, + best_count, + ) + return best_party + + +def main(argv=None) -> int: + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db") + args = p.parse_args(argv) + + conn = duckdb.connect(args.db) + + # Load mp_metadata + md_rows = conn.execute( + "SELECT mp_name, party, van, tot_en_met FROM mp_metadata" + ).fetchall() + + metadata = defaultdict(list) + for mp_name, party, van, tot in md_rows: + key = normalize_mp_key(mp_name) + metadata[key].append( + {"mp_name": mp_name, "party": party, "van": van, "tot": tot} + ) + + # Build majority-party mapping from existing mp_votes (non-null parties) + party_counts = defaultdict(lambda: defaultdict(int)) + rows_counts = conn.execute( + "SELECT mp_name, party, COUNT(*) FROM mp_votes WHERE party IS NOT NULL AND party != '' GROUP BY mp_name, party" + ).fetchall() + for mp_name, party, cnt in rows_counts: + key = normalize_mp_key(mp_name) + party_counts[key][party] += cnt + + majority_by_norm = { + k: max(v.items(), key=lambda kv: kv[1])[0] for k, v in party_counts.items() + } + + # Target mp_votes rows: individual MPs (contain comma) with NULL or empty party + target_rows = conn.execute( + "SELECT id, mp_name FROM mp_votes WHERE (party IS NULL OR party = '') AND mp_name LIKE '%,%'" + ).fetchall() + + updated = 0 + # Track MPs that need co-voting inference (tier 4) — collect after tiers 1-3 + covote_candidates: dict[str, list[int]] = defaultdict(list) # mp_name -> [ids] + + for id_, mp_name in target_rows: + key = normalize_mp_key(mp_name) + chosen_party = None + + # 1) exact normalized metadata match + if key in metadata: + chosen_party = pick_preferred_party(metadata[key]) + + # 2) fallback to majority observed in mp_votes + if not chosen_party: + chosen_party = majority_by_norm.get(key) + + # 3) try looser substring matches on lastname token + if not chosen_party: + tokens = key.split() + if tokens: + lastname = tokens[0] + # find metadata keys that start with lastname + for meta_key, recs in metadata.items(): + if meta_key.split()[0] == lastname: + chosen_party = pick_preferred_party(recs) + if chosen_party: + break + + if chosen_party: + conn.execute( + "UPDATE mp_votes SET party = ? WHERE id = ?", (chosen_party, id_) + ) + updated += 1 + else: + covote_candidates[mp_name].append(id_) + + # 4) Co-voting inference for remaining unresolved MPs + for mp_name, ids in covote_candidates.items(): + inferred = _infer_party_by_covoting(conn, mp_name) + if inferred: + for id_ in ids: + conn.execute( + "UPDATE mp_votes SET party = ? 
WHERE id = ?", (inferred, id_) + ) + updated += 1 + + conn.close() + logger.info("Updated %d mp_votes rows with party info", updated) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/generate_compass.py b/scripts/generate_compass.py new file mode 100644 index 0000000..c72e9ad --- /dev/null +++ b/scripts/generate_compass.py @@ -0,0 +1,157 @@ +"""Generate political compass and 2D trajectories HTML outputs. + +This script computes 2D axes using residual-PCA (or anchor), applies the +party-fill helper to colour MPs, and writes self-contained HTML files into +an outputs/ directory. + +Usage: + python scripts/generate_compass.py --db data/motions.db --out outputs --method pca --pca-residual + +The script is defensive: if required optional libraries (duckdb, plotly, +scipy) are missing it will log and exit without raising an uncaught exception. +""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +from typing import Optional + +# Ensure project root is on sys.path so `import analysis.*` works when the +# script is executed from the repository root or from scripts/ directly. +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + + +logger = logging.getLogger("generate_compass") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def main(argv: Optional[list] = None): + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db", help="Path to duckdb database") + p.add_argument("--out", default="outputs", help="Output directory") + p.add_argument("--method", choices=["pca", "anchor"], default="pca") + p.add_argument( + "--pca-residual", action="store_true", help="Use residual PCA for second axis" + ) + p.add_argument( + "--y-scale", + type=float, + default=None, + help="Optional manual y-axis scale multiplier", + ) + args = p.parse_args(argv) + + # Lazy imports so the script exits gracefully if deps missing + try: + from analysis.political_axis import compute_2d_axes + from analysis.visualize import ( + plot_political_compass, + plot_2d_trajectories, + _load_party_map, + ) + except Exception as e: # pragma: no cover - runtime helper + logger.exception("Required analysis modules could not be imported: %s", e) + sys.exit(1) + + # Ensure output dir exists + os.makedirs(args.out, exist_ok=True) + + logger.info( + "Computing 2D axes (method=%s pca_residual=%s)", args.method, args.pca_residual + ) + + try: + positions_by_window, axis_def = compute_2d_axes( + args.db, + method=args.method, + pca_residual=args.pca_residual, + normalize_vectors=True, + ) + except Exception as e: # defensive + logger.exception("compute_2d_axes failed: %s", e) + sys.exit(1) + + if not positions_by_window: + logger.error("No positions produced — aborting") + sys.exit(1) + + # pick latest window (lexicographic order is used elsewhere in codebase) + window_id = sorted(positions_by_window.keys())[-1] + + # Build party mapping to colour points + try: + party_map = _load_party_map(args.db) + except Exception: + logger.exception("Failed to build party map; proceeding without it") + party_map = None + + # Output files + compass_out = os.path.join( + args.out, f"political_compass_{args.method}_{window_id}.html" + ) + traj_out = os.path.join(args.out, f"trajectories_compass_{args.method}_top50.html") + + try: + plot_political_compass( + positions_by_window, + window_id=window_id, + party_of=party_map, + 
axis_def=axis_def, + y_scale=args.y_scale, + output_path=compass_out, + ) + logger.info("Wrote compass to %s", compass_out) + except Exception: + logger.exception("Failed to write political compass") + + try: + # Build 2D trajectories from the already-computed positions_by_window so + # we keep the same PCA/anchor axes (compute_2d_trajectories would call + # compute_2d_axes again which may use different defaults). + import numpy as _np + + window_ids = sorted(positions_by_window.keys()) + + mp_data = {} + for wid in window_ids: + pos = positions_by_window.get(wid, {}) + for mp_name, coord in pos.items(): + mp_data.setdefault(mp_name, {"windows": [], "coords": []}) + mp_data[mp_name]["windows"].append(wid) + mp_data[mp_name]["coords"].append(tuple(coord)) + + trajs = {} + for mp_name, data in mp_data.items(): + if len(data["windows"]) < 2: + continue + coords = [_np.array(c, dtype=float) for c in data["coords"]] + step_vecs = [coords[i + 1] - coords[i] for i in range(len(coords) - 1)] + mags = [float(_np.linalg.norm(v)) for v in step_vecs] + trajs[mp_name] = { + "windows": data["windows"], + "coords": [[float(c[0]), float(c[1])] for c in coords], + "step_vectors": [[float(v[0]), float(v[1])] for v in step_vecs], + "step_magnitudes": mags, + "total_magnitude": float(sum(mags)), + } + + ranked = sorted( + trajs.items(), key=lambda kv: kv[1]["total_magnitude"], reverse=True + ) + top_names = [mp for mp, _ in ranked[:50]] if ranked else None + + plot_2d_trajectories( + positions_by_window, mp_names=top_names, output_path=traj_out + ) + logger.info("Wrote trajectories to %s", traj_out) + except Exception: + logger.exception("Failed to compute/write trajectories") + + +if __name__ == "__main__": + main() diff --git a/scripts/inspect_axis.py b/scripts/inspect_axis.py new file mode 100644 index 0000000..d28f863 --- /dev/null +++ b/scripts/inspect_axis.py @@ -0,0 +1,137 @@ +"""Inspect PCA axes and per-MP projections for diagnostics. 
+ +Usage: + uv run python3 scripts/inspect_axis.py --db data/motions.db --out outputs +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from typing import Dict, List + +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger("inspect_axis") + + +def main(argv: List[str] | None = None): + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db") + p.add_argument("--out", default="outputs") + p.add_argument("--method", choices=["pca", "anchor"], default="pca") + p.add_argument("--pca-residual", action="store_true") + p.add_argument("--normalize", action="store_true", default=True) + args = p.parse_args(argv) + + os.makedirs(args.out, exist_ok=True) + + try: + from analysis.political_axis import compute_2d_axes + from analysis.visualize import _load_party_map + except Exception as e: + logger.exception("Failed to import analysis modules: %s", e) + raise + + positions_by_window, axes = compute_2d_axes( + args.db, + method=args.method, + pca_residual=args.pca_residual, + normalize_vectors=args.normalize, + ) + + if not positions_by_window: + logger.error("No positions produced") + return 2 + + latest = sorted(positions_by_window.keys())[-1] + pos = positions_by_window[latest] + + names = list(pos.keys()) + coords = list(pos.values()) + xs = [c[0] for c in coords] + ys = [c[1] for c in coords] + + import numpy as _np + + x_std = float(_np.std(xs)) + y_std = float(_np.std(ys)) + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + party_map = _load_party_map(args.db) + + # load mp_votes counts + try: + import duckdb + + conn = duckdb.connect(args.db) + rows = conn.execute( + "SELECT mp_name, COUNT(*) FROM mp_votes GROUP BY mp_name" + ).fetchall() + conn.close() + vote_counts = {r[0]: int(r[1]) for r in rows} + except Exception: + vote_counts = {} + + # extremes + sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0]) + sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1]) + + def info_for(name: str): + party = party_map.get(name) + count = vote_counts.get(name, None) + x, y = pos.get(name, (None, None)) + return {"name": name, "party": party, "count": count, "x": x, "y": y} + + report = { + "db": args.db, + "latest_window": latest, + "n_entities": len(names), + "x_std": x_std, + "y_std": y_std, + "x_min": x_min, + "x_max": x_max, + "y_min": y_min, + "y_max": y_max, + "evr": axes.get("explained_variance_ratio") if axes else None, + "top_left_by_x": [info_for(n) for n, _ in sorted_by_x[:10]], + "top_right_by_x": [info_for(n) for n, _ in sorted_by_x[-10:]], + "top_by_y": [info_for(n) for n, _ in sorted_by_y[-10:]], + "bottom_by_y": [info_for(n) for n, _ in sorted_by_y[:10]], + } + + # count how many are near-center along x within small fraction of std + threshold = 0.2 * x_std if x_std > 0 else 0.01 + near_center = [n for n, (x, y) in pos.items() if abs(x) < threshold] + report["near_center_count"] = len(near_center) + report["near_center_sample"] = near_center[:40] + + # check duplicate coordinate pairs + coord_pairs = [(_np.round(c[0], 6), _np.round(c[1], 6)) for c in coords] + unique_coords = set(coord_pairs) + report["n_unique_coords"] = len(unique_coords) + report["n_total_entities"] = len(names) + + # look up particular MPs + for q in ("Ouwehand", "Keijzer", "Mona"): + found = [n for n in names if 
q.lower() in n.lower()] + report[f"matches_{q}"] = [info_for(n) for n in found] + + out_json = os.path.join(args.out, "inspect_axis.json") + with open(out_json, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2) + + logger.info("Wrote inspection to %s", out_json) + print(json.dumps(report, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/recompute_svd.py b/scripts/recompute_svd.py new file mode 100644 index 0000000..0c1edf6 --- /dev/null +++ b/scripts/recompute_svd.py @@ -0,0 +1,167 @@ +"""Recompute per-window SVD into a fresh DB copy and re-run 2D axes. + +This script copies the current data/motions.db to a new file (data/motions_recompute.db), +clears any existing svd_vectors rows for the target windows in the new DB, runs +SVD on each window, then computes 2D axes and writes compass + trajectories into +outputs_recomputed/ for inspection. + +Usage: + uv run python3 scripts/recompute_svd.py --db data/motions.db --out outputs_recomputed +""" + +from __future__ import annotations + +import argparse +import calendar +import logging +import os +import shutil +import sys +from datetime import date +from typing import List, Tuple + +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger("recompute_svd") + + +def quarter_bounds(window_id: str) -> Tuple[str, str]: + # window_id like '2026-Q1' + year, q = window_id.split("-Q") + y = int(year) + qn = int(q) + starts = {1: (1, 1), 2: (4, 1), 3: (7, 1), 4: (10, 1)} + ends = {1: (3, 31), 2: (6, 30), 3: (9, 30), 4: (12, 31)} + s_m, s_d = starts[qn] + e_m, e_d = ends[qn] + start = date(y, s_m, s_d).isoformat() + end = date(y, e_m, e_d).isoformat() + return start, end + + +def main(argv: List[str] | None = None) -> int: + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db") + p.add_argument("--out", default="outputs_recomputed") + p.add_argument("--k", type=int, default=50) + args = p.parse_args(argv) + + os.makedirs(args.out, exist_ok=True) + + # Copy DB to a new file so we don't clobber originals + src = args.db + dst = os.path.splitext(src)[0] + "_recompute.db" + logger.info("Copying %s -> %s", src, dst) + shutil.copyfile(src, dst) + + # Lazy imports + try: + from database import MotionDatabase + from pipeline.svd_pipeline import run_svd_for_window + from analysis.political_axis import compute_2d_axes + from analysis.visualize import ( + plot_political_compass, + plot_2d_trajectories, + _load_party_map, + ) + from analysis import trajectory as traj + except Exception as e: + logger.exception("Import failed: %s", e) + return 2 + + # build MotionDatabase pointing to new file + db = MotionDatabase(dst) + + # find windows from original DB via trajectory helper + window_ids = traj._load_window_ids(src) + if not window_ids: + logger.error("No windows found in source DB %s", src) + return 3 + + logger.info("Will recompute SVD for windows: %s", window_ids) + + # clear existing svd_vectors rows for these windows in dst DB + import duckdb + + conn = duckdb.connect(dst) + try: + conn.execute( + "DELETE FROM svd_vectors WHERE window_id IN ({})".format( + ",".join([f"'{w}'" for w in window_ids]) + ) + ) + conn.commit() + logger.info("Cleared existing svd_vectors rows for windows in %s", dst) + finally: + conn.close() + + # Run SVD per window + for wid in window_ids: + start, end = 
quarter_bounds(wid) + logger.info("Running SVD for %s (%s -> %s) k=%d", wid, start, end, args.k) + res = run_svd_for_window( + db=db, window_id=wid, start_date=start, end_date=end, k=args.k + ) + logger.info("SVD result for %s: %s", wid, res) + + # Recompute 2D axes and plots from the recomputed DB + logger.info("Computing 2D axes (pca_residual=True) from recomputed DB") + positions_by_window, axes = compute_2d_axes( + dst, method="pca", pca_residual=True, normalize_vectors=True + ) + if not positions_by_window: + logger.error("No positions returned from compute_2d_axes on recomputed DB") + return 5 + + latest = sorted(positions_by_window.keys())[-1] + party_map = _load_party_map(dst) + + compass_out = os.path.join(args.out, f"political_compass_recomputed_{latest}.html") + traj_out = os.path.join(args.out, f"trajectories_recomputed_{latest}_top50.html") + + plot_political_compass( + positions_by_window, + window_id=latest, + party_of=party_map, + axis_def=axes, + output_path=compass_out, + ) + logger.info("Wrote recomputed compass to %s", compass_out) + + # compute simple trajectories from positions_by_window + # build per-MP coords + mp_coords = {} + for wid in sorted(positions_by_window.keys()): + for mp, coord in positions_by_window[wid].items(): + mp_coords.setdefault(mp, []).append((wid, coord)) + + names = [n for n, v in mp_coords.items() if len(v) >= 2] + plot_2d_trajectories(positions_by_window, mp_names=names[:50], output_path=traj_out) + logger.info("Wrote recomputed trajectories to %s", traj_out) + + # write a short diagnostic JSON (convert numpy arrays to lists) + import json + import numpy as _np + + def _to_serializable(o): + if isinstance(o, _np.ndarray): + return o.tolist() + if isinstance(o, (_np.floating, _np.integer)): + return float(o) + raise TypeError(f"Object of type {type(o)} is not JSON serializable") + + diag = {"windows": window_ids, "axes": axes} + with open( + os.path.join(args.out, "recompute_diag.json"), "w", encoding="utf-8" + ) as f: + json.dump(diag, f, indent=2, default=_to_serializable) + + logger.info("Recompute complete; outputs in %s and DB copy at %s", args.out, dst) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/svd_diagnostics.py b/scripts/svd_diagnostics.py new file mode 100644 index 0000000..31cdcce --- /dev/null +++ b/scripts/svd_diagnostics.py @@ -0,0 +1,214 @@ +"""SVD and PCA diagnostics for the political compass pipeline. + +Produces a small text report and JSON summary in the outputs/ directory. 
+ +Usage: + uv run python3 scripts/svd_diagnostics.py --db data/motions.db --out outputs +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from statistics import mean +from typing import Dict, List, Optional, Tuple + +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +logger = logging.getLogger("svd_diagnostics") +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + +def find_by_substring(names: List[str], query: str) -> List[str]: + q = query.lower() + return [n for n in names if q in n.lower()] + + +def main(argv: Optional[list] = None): + p = argparse.ArgumentParser() + p.add_argument("--db", default="data/motions.db") + p.add_argument("--out", default="outputs") + args = p.parse_args(argv) + + os.makedirs(args.out, exist_ok=True) + + try: + from analysis import trajectory as traj + from analysis.political_axis import compute_2d_axes + from analysis.visualize import _load_party_map + except Exception as e: # pragma: no cover - runtime + logger.exception("Could not import analysis modules: %s", e) + raise + + # Load windows and aligned vectors + window_ids = traj._load_window_ids(args.db) + if not window_ids: + logger.error("No SVD windows found in DB %s", args.db) + return 1 + + logger.info("Found windows: %s", window_ids) + + raw_window_vecs = { + wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids + } + aligned_window_vecs = traj._procrustes_align_windows(raw_window_vecs) + + # Compute global PCA axes (residual and non-residual) for comparison + positions_residual, axes_residual = compute_2d_axes( + args.db, + window_ids=window_ids, + method="pca", + normalize_vectors=True, + pca_residual=True, + ) + positions_plain, axes_plain = compute_2d_axes( + args.db, + window_ids=window_ids, + method="pca", + normalize_vectors=True, + pca_residual=False, + ) + + out_report = [] + + def add(line: str): + out_report.append(line) + logger.info(line) + + add("PCA diagnostics report") + add(f"DB: {args.db}") + add(f"Windows: {window_ids}") + + add("") + evr_res = axes_residual.get("explained_variance_ratio") if axes_residual else None + evr_plain = axes_plain.get("explained_variance_ratio") if axes_plain else None + add(f"Residual PCA EVR: {evr_res}") + add(f"Plain PCA EVR: {evr_plain}") + + # pick latest window for detailed inspection + latest = sorted(window_ids)[-1] + add("") + add(f"Inspecting latest window: {latest}") + + pos = positions_residual.get(latest, {}) + names = list(pos.keys()) + xs = [v[0] for v in pos.values()] + ys = [v[1] for v in pos.values()] + + def stats(arr: List[float]) -> Tuple[float, float]: + if not arr: + return 0.0, 0.0 + mn = min(arr) + mx = max(arr) + return mn, mx + + add(f"Entities in latest window: {len(names)}") + add(f"X range (left-right): {stats(xs)}") + add(f"Y range (prog-cons): {stats(ys)}") + # stdevs + try: + import numpy as _np + + x_std = float(_np.std(xs)) + y_std = float(_np.std(ys)) + except Exception: + x_std = 0.0 + y_std = 0.0 + add( + f"Std dev X: {x_std:.6f}, Std dev Y: {y_std:.6f} (ratio Y/X = {y_std / (x_std + 1e-12):.3f})" + ) + + # show extremes on X and Y + sorted_by_x = sorted(pos.items(), key=lambda kv: kv[1][0]) + sorted_by_y = sorted(pos.items(), key=lambda kv: kv[1][1]) + + add("") + add("Left-most (by X):") + for name, (x, y) in sorted_by_x[:8]: + add(f" {name:40s} x={x:.4f} y={y:.4f}") + + add("") + add("Right-most (by X):") + for name, 
(x, y) in sorted_by_x[-8:]: + add(f" {name:40s} x={x:.4f} y={y:.4f}") + + add("") + add("Top (conservative) (by Y):") + for name, (x, y) in sorted_by_y[-8:]: + add(f" {name:40s} x={x:.4f} y={y:.4f}") + + add("") + add("Bottom (progressive) (by Y):") + for name, (x, y) in sorted_by_y[:8]: + add(f" {name:40s} x={x:.4f} y={y:.4f}") + + # Find specific MPs mentioned by user + matches_ouwehand = find_by_substring(names, "ouwehand") + matches_mona = find_by_substring(names, "mona") + add("") + add(f"Matches for 'Ouwehand': {matches_ouwehand}") + for n in matches_ouwehand: + x, y = pos.get(n) + add(f" {n} -> x={x:.4f} y={y:.4f}") + add(f"Matches for 'Mona': {matches_mona}") + for n in matches_mona: + x, y = pos.get(n) + add(f" {n} -> x={x:.4f} y={y:.4f}") + + # Party centroids + party_map = _load_party_map(args.db) + parties: Dict[str, List[Tuple[float, float]]] = {} + for mp, coord in pos.items(): + party = party_map.get(mp) + if party: + parties.setdefault(party, []).append(coord) + party_centroids: Dict[str, Tuple[float, float]] = {} + for party, coords in parties.items(): + xs_p = [c[0] for c in coords] + ys_p = [c[1] for c in coords] + party_centroids[party] = (mean(xs_p), mean(ys_p)) + + add("") + add(f"Computed {len(party_centroids)} party centroids (from mp_metadata majority)") + sorted_parties_by_x = sorted(party_centroids.items(), key=lambda kv: kv[1][0]) + add("Party centroids left→right:") + for p, (x, y) in sorted_parties_by_x: + add(f" {p:20s} x={x:.4f} y={y:.4f}") + + sorted_parties_by_y = sorted(party_centroids.items(), key=lambda kv: kv[1][1]) + add("") + add("Party centroids prog→cons:") + for p, (x, y) in sorted_parties_by_y: + add(f" {p:20s} x={x:.4f} y={y:.4f}") + + # Save report and a small JSON summary + report_path = os.path.join(args.out, "svd_diagnostics.txt") + summary_path = os.path.join(args.out, "svd_diagnostics.json") + with open(report_path, "w", encoding="utf-8") as f: + f.write("\n".join(out_report)) + + summary = { + "db": args.db, + "windows": window_ids, + "latest_window": latest, + "evr_residual": evr_res, + "evr_plain": evr_plain, + "n_entities_latest": len(names), + "x_std": x_std, + "y_std": y_std, + "party_centroids": party_centroids, + } + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2) + + logger.info("Diagnostic report written to %s and %s", report_path, summary_path) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/integration/test_pipeline_end_to_end.py b/tests/integration/test_pipeline_end_to_end.py index 3711962..16bfcd5 100644 --- a/tests/integration/test_pipeline_end_to_end.py +++ b/tests/integration/test_pipeline_end_to_end.py @@ -53,14 +53,14 @@ def test_pipeline_end_to_end(tmp_path, monkeypatch): conn.close() - # monkeypatch ai_provider.get_embedding to deterministic vector + # monkeypatch ai_provider.get_embeddings_batch to deterministic vectors import ai_provider - def fake_get_embedding(text, model=None): - # produce a deterministic vector based on seeded numpy - return list(np.random.rand(16)) + def fake_get_embeddings_batch(texts, model=None, batch_size=50): + # produce a deterministic vector per text based on seeded numpy + return [list(np.random.rand(16)) for _ in texts] - monkeypatch.setattr("ai_provider.get_embedding", fake_get_embedding) + monkeypatch.setattr("ai_provider.get_embeddings_batch", fake_get_embeddings_batch) # run ensure_text_embeddings from pipeline.text_pipeline import ensure_text_embeddings diff --git a/tests/test_text_pipeline.py 
diff --git a/tests/test_text_pipeline.py b/tests/test_text_pipeline.py
index ff16108..d701cf9 100644
--- a/tests/test_text_pipeline.py
+++ b/tests/test_text_pipeline.py
@@ -49,11 +49,11 @@ def test_ensure_text_embeddings_monkeypatch(tmp_path, monkeypatch):
 
     conn.close()
 
-    # monkeypatch ai_provider.get_embedding
-    def fake_get_embedding(text, model=None):
-        return [0.1] * 16
+    # monkeypatch ai_provider.get_embeddings_batch (used by batched pipeline)
+    def fake_get_embeddings_batch(texts, model=None, batch_size=50):
+        return [[0.1] * 16 for _ in texts]
 
-    monkeypatch.setattr("ai_provider.get_embedding", fake_get_embedding)
+    monkeypatch.setattr("ai_provider.get_embeddings_batch", fake_get_embeddings_batch)
 
     # run ensure_text_embeddings
     from pipeline.text_pipeline import ensure_text_embeddings
diff --git a/thoughts/ledgers/CONTINUITY_stemwijzer.md b/thoughts/ledgers/CONTINUITY_stemwijzer.md
index 5d56dd4..ed3e15f 100644
--- a/thoughts/ledgers/CONTINUITY_stemwijzer.md
+++ b/thoughts/ledgers/CONTINUITY_stemwijzer.md
@@ -1,50 +1,79 @@
-# Session: stemwijzer
-Updated: 2026-03-20T00:23:33Z
+# Session: stemwijzer — Parliamentary Embedding Pipeline
+Updated: 2026-03-22T16:00:00Z
 
 ## Goal
-Preserve the minimal session state required to resume work on the stemwijzer project after context clears (success = ledger exists and is kept up-to-date).
+2D political compass + motion similarity search built from parliamentary votes + motion text.
+Full historical coverage 2016–2026, precomputed similarity cache, fused (SVD + text) embeddings.
 
 ## Constraints
-- Keep the ledger CONCISE — only essential information
-- Focus on WHAT and WHY, not HOW
-- Mark uncertain information as UNCONFIRMED
-- Include git branch and key file paths
+- DuckDB only (`data/motions.db`); open/close `duckdb.connect(self.db_path)` per method
+- Vectors stored as JSON text (no external vector DB)
+- Logging via `logging.getLogger(__name__)`; no `print()` in library modules
+- Tests run offline (network monkeypatched) — use `.venv/bin/python -m pytest -q`
+- Do NOT modify `app.py` or `scheduler.py`
+- Use `.venv/bin/python` (Arch Linux system Python is externally managed)
 
-## Progress
-### Done
-- [x] Create initial continuity ledger file
+## Current DB State (verified 2026-03-22 ~16:00)
 
-### In Progress
-- [ ] Capture ongoing session context and update ledger after each meaningful change
+| Table | Rows |
+|---|---|
+| motions | 10,613 |
+| embeddings | 10,753 |
+| svd_vectors | 24,528 |
+| fused_embeddings | **10,613** (1:1 with motions, 0 duplicates) |
+| similarity_cache | **212,206** (top_k=20, all annual windows) |
+| mp_votes | 199,967 |
+| mp_metadata | 798 |
 
-### Blocked
-- None currently
+## Annual Window Coverage
+
+| Year | Motions | Fused | Similarity |
+|---|---|---|---|
+| 2016 | 132 | 132 | 2,640 |
+| 2017 | 30 | 30 | 600 |
+| 2018 | 100 | 100 | 2,000 |
+| 2019 | 3 | 3 | 6 |
+| 2020 | 0 | 0 | 0 (no data) |
+| 2021 | 0 | 0 | 0 (no data) |
+| 2022 | 4,116 | 4,116 | 82,320 |
+| 2023 | 621 | 621 | 12,420 |
+| 2024 | 948 | 948 | 18,960 |
+| 2025 | 3,715 | 3,715 | 74,300 |
+| 2026 | 948 | 948 | 18,960 |
+
+## Completed This Session
+- [x] Text embeddings: ran with the real OpenRouter API at batch_size=200 → 10,753 embedding rows
+- [x] Re-ran `extract_mp_votes` on all motions → 111,978 new rows (party-level votes backfilled)
+- [x] SVD re-run (annual 2016–2026) with full vote data → 24,528 svd_vector rows
+- [x] Fixed `store_fused_embedding` double-counting bug: added DELETE before INSERT (see the sketch below)
+- [x] Cleaned and re-ran fusion → 10,613 fused rows, zero duplicates
+- [x] Re-ran similarity cache top_k=20 for all 9 active windows → 212,206 rows
+- [x] Test suite: **34 passed, 2 skipped** ✅
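+
+In outline, the fused-embedding fix is the following write pattern. This is a
+sketch only; the real implementation lives in `database.py` around line 686,
+and the column list here is an assumption:
+
+```python
+import duckdb
+
+def store_fused_embedding(db_path: str, motion_id: str, vector_json: str) -> None:
+    """Idempotent write: re-running fusion can never create duplicate rows."""
+    conn = duckdb.connect(db_path)
+    try:
+        # Drop any earlier vector for this motion before inserting the new one.
+        conn.execute("DELETE FROM fused_embeddings WHERE motion_id = ?", [motion_id])
+        conn.execute(
+            "INSERT INTO fused_embeddings (motion_id, vector) VALUES (?, ?)",
+            [motion_id, vector_json],
+        )
+    finally:
+        conn.close()
+```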
 
 ## Key Decisions
-- **Session name = "stemwijzer"**: Chosen from repository context (UNCONFIRMED if a different canonical session name is preferred).
-- **Do not auto-commit ledger changes**: Commits will only be made when the user explicitly requests it (follows Git Safety Protocol).
-
-## Next Steps
-1. Continue updating this ledger when tasks, files, or decisions change
-2. Add entries for new branches or major feature work (mark as UNCONFIRMED when unsure)
-3. Ask user before creating any git commits that include this ledger
-
-## File Operations
-### Read
-- `README.md`
-- `pyproject.toml`
-- `thoughts/shared/plans/2026-03-19-stemwijzer-plan.md`
-- `thoughts/shared/designs/2026-03-19-stemwijzer-design.md`
-
-### Modified
-- `thoughts/ledgers/CONTINUITY_stemwijzer.md` (new)
-
-## Critical Context
-- Repository branch observed: `main`
-- Found project metadata in `pyproject.toml` indicating Python tooling preference
-- Existing notes/plans located under `thoughts/shared/` (plans and designs from 2026-03-19)
-- No existing continuity ledger was found prior to this creation
-
-## Working Set
-- Branch: `main`
-- Key files: `README.md`, `pyproject.toml`, `thoughts/shared/plans/2026-03-19-stemwijzer-plan.md`, `thoughts/shared/designs/2026-03-19-stemwijzer-design.md`, `thoughts/ledgers/CONTINUITY_stemwijzer.md`
+- `store_fused_embedding` (`database.py` line 686): now does DELETE+INSERT instead of a plain INSERT to prevent duplicates on re-runs.
+- Annual windows chosen for the historical political compass (2016–2026).
+- top_k=20 for the similarity cache.
+- Party-level votes (e.g. `{"PVV": "voor"}`) handled in `extract_mp_votes`: an actor name without a comma is treated as a party (`party = actor_name`).
+
+## Open Items (not blocking; data coverage gaps)
+1. **2020–2021 data gap**: no motions in the DB at all. Run the downloader with `--start-date 2019-01-01 --end-date 2021-12-31` if the data exists in the API.
+2. **2024 gap (~3,020 motions)**: the OData API has ~3,968 motions for 2024, but only 948 are in the DB. Root cause unclear — needs investigation of the URL-based dedup in `insert_motion`.
+3. **"Verworpen." dedup**: short-text motions (title = "Verworpen.") get spurious similarity = 1.0. The UI/query layer should filter `score < 0.999 OR title != 'Verworpen.'` (see the appendix sketch at the end of this ledger).
+4. **`svd_vectors` has duplicates**: 2025 has 7,430 rows for 3,715 motions (2×). Does not affect `fused_embeddings` (DELETE+INSERT handles it) but wastes space. Low priority.
+
+## Key File Paths
+- DB: `data/motions.db`
+- Venv: `.venv/bin/python`
+- Pipeline entry: `pipeline/run_pipeline.py`
+- Fusion: `pipeline/fusion.py`
+- SVD: `pipeline/svd_pipeline.py`
+- Text embeddings: `pipeline/text_pipeline.py`
+- MP votes extraction: `pipeline/extract_mp_votes.py`
+- Database layer: `database.py`
+- Similarity compute: `similarity/compute.py`
+- Similarity lookup: `similarity/lookup.py`
+- Tests: `tests/` (pytest, offline)
+
+## Branch
+`main`
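+
+## Appendix: Similarity Filter Sketch
+
+A minimal sketch of the query-layer filter from open item 3. Table and column
+names (`similarity_cache.motion_id` / `other_motion_id` / `score`,
+`motions.id` / `title`) are assumptions based on this ledger, not verified
+against the schema in `database.py`:
+
+```python
+import duckdb
+
+def similar_motions(db_path: str, motion_id: str, limit: int = 10):
+    """Top similar motions, skipping spurious 1.0 hits between 'Verworpen.' stubs."""
+    query = f"""
+        SELECT s.other_motion_id, s.score
+        FROM similarity_cache AS s
+        JOIN motions AS m ON m.id = s.other_motion_id
+        WHERE s.motion_id = ?
+          AND (s.score < 0.999 OR m.title != 'Verworpen.')
+        ORDER BY s.score DESC
+        LIMIT {int(limit)}
+    """
+    conn = duckdb.connect(db_path)
+    try:
+        return conn.execute(query, [motion_id]).fetchall()
+    finally:
+        conn.close()
+```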