"""Recompute per-window SVD into a fresh DB copy and re-run 2D axes.

This script copies the source DuckDB database (default ``data/motions.db``)
to a sibling file with a ``_recompute`` suffix (e.g.
``data/motions_recompute.db``), clears any existing ``svd_vectors`` rows for
the annual windows in that copy, runs SVD on each annual window, then
computes 2D axes and writes a political compass, MP trajectories and a
diagnostic JSON into the output directory (default ``outputs_recomputed/``)
for inspection. Quarterly windows (IDs containing ``-Q``) are skipped.

Usage:
    uv run python3 scripts/recompute_svd.py --db data/motions.db --out outputs_recomputed
"""

from __future__ import annotations

import argparse
import calendar
import json
import logging
import os
import shutil
import sys
from collections import Counter
from datetime import date
from typing import Any, List, Tuple

ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("recompute_svd")


def year_bounds(window_id: str) -> Tuple[str, str]:
    """Return (start_date, end_date) for an annual window_id like '2024'.

    Dates are ISO-8601 strings covering Jan 1 through Dec 31 of that year.

    Raises:
        ValueError: for quarterly window IDs (containing '-Q'), which this
            script does not process, or for IDs that are not an integer year.
    """
    if "-Q" in window_id:
        raise ValueError(
            f"Quarterly window '{window_id}' is not supported. "
            "Only annual windows should be recomputed."
        )
    y = int(window_id)
    start = date(y, 1, 1).isoformat()
    end = date(y, 12, 31).isoformat()
    return start, end


def _copy_duckdb_file(src: str, dst: str) -> None:
    """Copy a DuckDB database file to *dst*, keeping its WAL file consistent.

    DuckDB keeps uncheckpointed changes in ``<db>.wal`` next to the main
    file, so that file is copied too when present. When the source has no
    WAL, any stale ``dst.wal`` left by an earlier run is removed so it
    cannot be replayed onto the freshly copied database.
    """
    shutil.copyfile(src, dst)
    src_wal = src + ".wal"
    dst_wal = dst + ".wal"
    if os.path.exists(src_wal):
        shutil.copyfile(src_wal, dst_wal)
    elif os.path.exists(dst_wal):
        os.remove(dst_wal)


def _json_default(o: Any) -> Any:
    """``json.dump`` fallback converting NumPy values to JSON-native types.

    Arrays become (nested) lists; NumPy scalars become the matching Python
    scalar (ints stay ints, bools stay bools).

    Raises:
        TypeError: for any other non-serializable object.
    """
    import numpy as np  # lazy: only needed when writing the diagnostic JSON

    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.generic):
        return o.item()
    raise TypeError(f"Object of type {type(o)} is not JSON serializable")


def main(argv: List[str] | None = None) -> int:
    """Run the recompute pipeline and return a process exit code.

    Exit codes: 0 success, 1 source DB missing, 2 import failure,
    3 no annual windows in the source DB, 5 compute_2d_axes returned
    no positions.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--db", default="data/motions.db", help="source DuckDB file")
    p.add_argument("--out", default="outputs_recomputed", help="output directory")
    p.add_argument("--k", type=int, default=50, help="SVD rank passed to run_svd_for_window")
    args = p.parse_args(argv)

    src = args.db
    if not os.path.isfile(src):
        logger.error("Source DB %s does not exist", src)
        return 1

    os.makedirs(args.out, exist_ok=True)

    # Copy DB to a new file so we don't clobber originals
    dst = os.path.splitext(src)[0] + "_recompute.db"
    logger.info("Copying %s -> %s", src, dst)
    _copy_duckdb_file(src, dst)

    # Lazy imports: project modules resolve via ROOT on sys.path.
    try:
        import duckdb

        from database import MotionDatabase
        from pipeline.svd_pipeline import run_svd_for_window
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import (
            plot_political_compass,
            plot_2d_trajectories,
            _load_party_map,
        )
        from analysis import trajectory as traj
    except Exception as e:
        logger.exception("Import failed: %s", e)
        return 2

    # build MotionDatabase pointing to new file
    db = MotionDatabase(dst)

    # find windows from original DB via trajectory helper
    all_window_ids = traj._load_window_ids(src)
    # Only process annual windows — quarterly windows are excluded from all PCA/SVD computation
    window_ids = [w for w in all_window_ids if "-Q" not in w]
    if not window_ids:
        logger.error("No annual windows found in source DB %s", src)
        return 3

    logger.info("Will recompute SVD for annual windows: %s", window_ids)

    # Clear existing svd_vectors rows for these windows in the dst DB.
    # Window IDs are bound as parameters rather than spliced into the SQL text.
    placeholders = ", ".join("?" for _ in window_ids)
    conn = duckdb.connect(dst)
    try:
        conn.execute(
            f"DELETE FROM svd_vectors WHERE window_id IN ({placeholders})",
            window_ids,
        )
        conn.commit()
        logger.info("Cleared existing svd_vectors rows for windows in %s", dst)
    finally:
        conn.close()

    # Run SVD per window
    for wid in window_ids:
        start, end = year_bounds(wid)
        logger.info("Running SVD for %s (%s -> %s) k=%d", wid, start, end, args.k)
        res = run_svd_for_window(
            db=db, window_id=wid, start_date=start, end_date=end, k=args.k
        )
        logger.info("SVD result for %s: %s", wid, res)

    # Recompute 2D axes and plots from the recomputed DB
    logger.info("Computing 2D axes (pca_residual=True) from recomputed DB")
    positions_by_window, axes = compute_2d_axes(
        dst, method="pca", pca_residual=True, normalize_vectors=True
    )
    if not positions_by_window:
        logger.error("No positions returned from compute_2d_axes on recomputed DB")
        return 5

    latest = max(positions_by_window)
    party_map = _load_party_map(dst)
    compass_out = os.path.join(args.out, f"political_compass_recomputed_{latest}.html")
    traj_out = os.path.join(args.out, f"trajectories_recomputed_{latest}_top50.html")

    plot_political_compass(
        positions_by_window,
        window_id=latest,
        party_of=party_map,
        axis_def=axes,
        output_path=compass_out,
    )
    logger.info("Wrote recomputed compass to %s", compass_out)

    # Trajectories need at least two points, so keep MPs seen in >= 2 windows.
    # Counter keeps first-seen order (windows visited in sorted order).
    windows_per_mp: Counter = Counter()
    for wid in sorted(positions_by_window):
        windows_per_mp.update(positions_by_window[wid].keys())
    names = [mp for mp, count in windows_per_mp.items() if count >= 2]
    plot_2d_trajectories(positions_by_window, mp_names=names[:50], output_path=traj_out)
    logger.info("Wrote recomputed trajectories to %s", traj_out)

    # write a short diagnostic JSON (NumPy values converted by _json_default)
    diag = {"windows": window_ids, "axes": axes}
    with open(
        os.path.join(args.out, "recompute_diag.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(diag, f, indent=2, default=_json_default)

    logger.info("Recompute complete; outputs in %s and DB copy at %s", args.out, dst)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())