You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/scripts/recompute_svd.py

172 lines
5.8 KiB

"""Recompute per-window SVD into a fresh DB copy and re-run 2D axes.

This script copies the ``--db`` file (default data/motions.db) to a sibling
file with a ``_recompute.db`` suffix (e.g. data/motions_recompute.db), clears
any existing svd_vectors rows for the target windows in that copy, runs SVD
on each window, then computes 2D axes and writes a political compass plot,
MP trajectories and a diagnostic JSON file into the ``--out`` directory
(default outputs_recomputed/) for inspection. Only annual windows are
processed; quarterly windows (ids containing "-Q") are skipped.

Usage:
    uv run python3 scripts/recompute_svd.py --db data/motions.db --out outputs_recomputed
"""
from __future__ import annotations
import argparse
import calendar
import logging
import os
import shutil
import sys
from datetime import date
from typing import List, Tuple
# Repository root = parent of this script's directory. It is put first on
# sys.path so the project packages imported lazily in main() (database,
# pipeline, analysis) resolve when the script is run directly.
ROOT = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# Script-wide logging: INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("recompute_svd")
def year_bounds(window_id: str) -> Tuple[str, str]:
    """Map an annual window id (e.g. ``'2024'``) to its ISO date range.

    Returns:
        ``(start_date, end_date)`` as ISO-8601 strings spanning the whole
        calendar year, 1 January through 31 December inclusive.

    Raises:
        ValueError: if ``window_id`` is a quarterly id (contains ``'-Q'``);
            this script only recomputes annual windows. A non-numeric id
            also raises ``ValueError`` (from ``int()``).
    """
    is_quarterly = "-Q" in window_id
    if is_quarterly:
        raise ValueError(
            f"Quarterly window '{window_id}' is not supported. "
            "Only annual windows should be recomputed."
        )

    year = int(window_id)
    first_day, last_day = date(year, 1, 1), date(year, 12, 31)
    return first_day.isoformat(), last_day.isoformat()
def main(argv: List[str] | None = None) -> int:
    """Recompute annual-window SVD into a copy of the DB and regenerate plots.

    Steps:
      1. Copy ``--db`` to ``<db path without extension>_recompute.db`` so the
         original file is never modified.
      2. Delete existing ``svd_vectors`` rows for every annual window in the copy.
      3. Re-run ``run_svd_for_window`` for each annual window with rank ``--k``.
      4. Compute 2D axes from the copy and write a political compass, a
         trajectory plot and ``recompute_diag.json`` into ``--out``.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 1 if the source DB file does not
        exist, 2 if a required import fails, 3 if the source DB has no annual
        windows, 5 if ``compute_2d_axes`` returns no positions.
    """
    import json
    from collections import Counter

    p = argparse.ArgumentParser(
        description="Recompute per-window SVD into a DB copy and re-run 2D axes."
    )
    p.add_argument("--db", default="data/motions.db", help="Source DB file (left untouched).")
    p.add_argument("--out", default="outputs_recomputed", help="Directory for plots and diagnostics.")
    p.add_argument("--k", type=int, default=50, help="Number of SVD components per window.")
    args = p.parse_args(argv)

    os.makedirs(args.out, exist_ok=True)

    src = args.db
    if not os.path.isfile(src):
        logger.error("Source DB %s does not exist", src)
        return 1

    # Lazy imports: the project packages rely on ROOT being on sys.path.
    # duckdb/numpy are included so a missing dependency is reported here,
    # before any DB copy or SVD work, instead of crashing half-way through.
    try:
        import duckdb
        import numpy as _np
        from database import MotionDatabase
        from pipeline.svd_pipeline import run_svd_for_window
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import (
            plot_political_compass,
            plot_2d_trajectories,
            _load_party_map,
        )
        from analysis import trajectory as traj
    except Exception as e:
        logger.exception("Import failed: %s", e)
        return 2

    # Copy DB to a new file so we don't clobber the original.
    # NOTE(review): only the main DB file is copied; if the source has an
    # un-checkpointed write-ahead log next to it, that data is not carried over.
    dst = os.path.splitext(src)[0] + "_recompute.db"
    logger.info("Copying %s -> %s", src, dst)
    shutil.copyfile(src, dst)

    # Build MotionDatabase pointing to the copy.
    db = MotionDatabase(dst)

    # Window ids are read from the original DB (identical to the fresh copy).
    all_window_ids = traj._load_window_ids(src)
    # Only process annual windows — quarterly windows are excluded from all
    # PCA/SVD computation.
    window_ids = [w for w in all_window_ids if "-Q" not in w]
    if not window_ids:
        logger.error("No annual windows found in source DB %s", src)
        return 3
    logger.info("Will recompute SVD for annual windows: %s", window_ids)

    # Clear existing svd_vectors rows for these windows in the copy. The ids
    # are bound as query parameters rather than spliced into the SQL text, so
    # an id containing a quote cannot break or alter the statement.
    placeholders = ", ".join("?" for _ in window_ids)
    conn = duckdb.connect(dst)
    try:
        conn.execute(
            f"DELETE FROM svd_vectors WHERE window_id IN ({placeholders})",
            window_ids,
        )
        conn.commit()
        logger.info("Cleared existing svd_vectors rows for windows in %s", dst)
    finally:
        conn.close()

    # Run SVD per window over its calendar-year date range.
    for wid in window_ids:
        start, end = year_bounds(wid)
        logger.info("Running SVD for %s (%s -> %s) k=%d", wid, start, end, args.k)
        res = run_svd_for_window(
            db=db, window_id=wid, start_date=start, end_date=end, k=args.k
        )
        logger.info("SVD result for %s: %s", wid, res)

    # Recompute 2D axes and plots from the recomputed DB.
    logger.info("Computing 2D axes (pca_residual=True) from recomputed DB")
    positions_by_window, axes = compute_2d_axes(
        dst, method="pca", pca_residual=True, normalize_vectors=True
    )
    if not positions_by_window:
        logger.error("No positions returned from compute_2d_axes on recomputed DB")
        return 5

    # Greatest window id in sort order (e.g. the most recent year).
    latest = max(positions_by_window)
    party_map = _load_party_map(dst)
    compass_out = os.path.join(args.out, f"political_compass_recomputed_{latest}.html")
    traj_out = os.path.join(args.out, f"trajectories_recomputed_{latest}_top50.html")

    plot_political_compass(
        positions_by_window,
        window_id=latest,
        party_of=party_map,
        axis_def=axes,
        output_path=compass_out,
    )
    logger.info("Wrote recomputed compass to %s", compass_out)

    # Only MPs that appear in at least two windows get a trajectory. Counting
    # over the windows in sorted order keeps first-seen order, which decides
    # which (up to) 50 MPs are plotted.
    windows_per_mp = Counter(
        mp for wid in sorted(positions_by_window) for mp in positions_by_window[wid]
    )
    names = [mp for mp, count in windows_per_mp.items() if count >= 2]
    plot_2d_trajectories(positions_by_window, mp_names=names[:50], output_path=traj_out)
    logger.info("Wrote recomputed trajectories to %s", traj_out)

    def _to_serializable(o):
        """``json.dump`` fallback that converts numpy values to plain Python."""
        if isinstance(o, _np.ndarray):
            return o.tolist()
        # .item() keeps numpy integers as JSON ints (not floats) and also
        # covers numpy floats and bools.
        if isinstance(o, _np.generic):
            return o.item()
        raise TypeError(f"Object of type {type(o)} is not JSON serializable")

    # Write a short diagnostic JSON with the processed windows and the axes.
    diag = {"windows": window_ids, "axes": axes}
    diag_path = os.path.join(args.out, "recompute_diag.json")
    with open(diag_path, "w", encoding="utf-8") as f:
        json.dump(diag, f, indent=2, default=_to_serializable)

    logger.info("Recompute complete; outputs in %s and DB copy at %s", args.out, dst)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())