"""Recompute per-window SVD into a fresh DB copy and re-run 2D axes.
|
|
|
|
This script copies the current data/motions.db to a new file (data/motions_recompute.db),
|
|
clears any existing svd_vectors rows for the target windows in the new DB, runs
|
|
SVD on each window, then computes 2D axes and writes compass + trajectories into
|
|
outputs_recomputed/ for inspection.
|
|
|
|
Usage:
|
|
uv run python3 scripts/recompute_svd.py --db data/motions.db --out outputs_recomputed
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import calendar
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
from datetime import date
|
|
from typing import List, Tuple
|
|
|
|
# Make the project root importable when this script is run from scripts/.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(_SCRIPT_DIR)
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

# Module-wide logger with timestamped INFO-level output.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("recompute_svd")
|
def year_bounds(window_id: str) -> Tuple[str, str]:
    """Return (start_date, end_date) ISO strings for an annual window_id like '2024'.

    Quarterly window IDs (containing '-Q') are not supported — this script
    only processes annual windows.

    Raises:
        ValueError: if the window id is quarterly, or not parseable as a year.
    """
    # Guard clause: reject quarterly ids ('YYYY-Qn') up front.
    if "-Q" in window_id:
        msg = (
            f"Quarterly window '{window_id}' is not supported. "
            "Only annual windows should be recomputed."
        )
        raise ValueError(msg)
    year = int(window_id)
    first_day = date(year, 1, 1)
    last_day = date(year, 12, 31)
    return first_day.isoformat(), last_day.isoformat()
def main(argv: List[str] | None = None) -> int:
    """Copy the motions DB, recompute per-window SVD, and render 2D plots.

    Args:
        argv: Optional CLI argument list (defaults to sys.argv[1:] via argparse).

    Returns:
        Process exit code: 0 on success, non-zero on any failure
        (1 missing source DB, 2 import failure, 3 no annual windows,
        5 no positions from compute_2d_axes).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--db", default="data/motions.db", help="Source DuckDB file.")
    p.add_argument("--out", default="outputs_recomputed", help="Output directory.")
    p.add_argument("--k", type=int, default=50, help="Number of SVD components.")
    args = p.parse_args(argv)

    os.makedirs(args.out, exist_ok=True)

    # Copy DB to a new file so we don't clobber originals
    src = args.db
    if not os.path.exists(src):
        # Fail cleanly instead of letting shutil.copyfile raise FileNotFoundError.
        logger.error("Source DB %s does not exist", src)
        return 1
    dst = os.path.splitext(src)[0] + "_recompute.db"
    logger.info("Copying %s -> %s", src, dst)
    shutil.copyfile(src, dst)

    # Lazy imports: defer heavy project dependencies until after argument
    # parsing and the DB copy, so CLI errors stay fast and clear.
    try:
        from database import MotionDatabase
        from pipeline.svd_pipeline import run_svd_for_window
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import (
            plot_political_compass,
            plot_2d_trajectories,
            _load_party_map,
        )
        from analysis import trajectory as traj
    except Exception as e:
        logger.exception("Import failed: %s", e)
        return 2

    # build MotionDatabase pointing to new file
    db = MotionDatabase(dst)

    # find windows from original DB via trajectory helper
    all_window_ids = traj._load_window_ids(src)
    # Only process annual windows — quarterly windows are excluded from all PCA/SVD computation
    window_ids = [w for w in all_window_ids if "-Q" not in w]
    if not window_ids:
        logger.error("No annual windows found in source DB %s", src)
        return 3

    logger.info("Will recompute SVD for annual windows: %s", window_ids)

    # clear existing svd_vectors rows for these windows in dst DB
    import duckdb

    conn = duckdb.connect(dst)
    try:
        # Parameterized IN-list: avoids quoting/injection problems that
        # string-formatting window ids into the SQL would create.
        placeholders = ",".join("?" for _ in window_ids)
        conn.execute(
            f"DELETE FROM svd_vectors WHERE window_id IN ({placeholders})",
            window_ids,
        )
        conn.commit()
        logger.info("Cleared existing svd_vectors rows for windows in %s", dst)
    finally:
        conn.close()

    # Run SVD per window
    for wid in window_ids:
        start, end = year_bounds(wid)
        logger.info("Running SVD for %s (%s -> %s) k=%d", wid, start, end, args.k)
        res = run_svd_for_window(
            db=db, window_id=wid, start_date=start, end_date=end, k=args.k
        )
        logger.info("SVD result for %s: %s", wid, res)

    # Recompute 2D axes and plots from the recomputed DB
    logger.info("Computing 2D axes (pca_residual=True) from recomputed DB")
    positions_by_window, axes = compute_2d_axes(
        dst, method="pca", pca_residual=True, normalize_vectors=True
    )
    if not positions_by_window:
        logger.error("No positions returned from compute_2d_axes on recomputed DB")
        return 5

    latest = sorted(positions_by_window.keys())[-1]
    party_map = _load_party_map(dst)

    compass_out = os.path.join(args.out, f"political_compass_recomputed_{latest}.html")
    traj_out = os.path.join(args.out, f"trajectories_recomputed_{latest}_top50.html")

    plot_political_compass(
        positions_by_window,
        window_id=latest,
        party_of=party_map,
        axis_def=axes,
        output_path=compass_out,
    )
    logger.info("Wrote recomputed compass to %s", compass_out)

    # compute simple trajectories from positions_by_window:
    # build per-MP coordinate histories so we can pick MPs that appear in
    # at least two windows (a single point has no trajectory).
    mp_coords = {}
    for wid in sorted(positions_by_window.keys()):
        for mp, coord in positions_by_window[wid].items():
            mp_coords.setdefault(mp, []).append((wid, coord))

    names = [n for n, v in mp_coords.items() if len(v) >= 2]
    plot_2d_trajectories(positions_by_window, mp_names=names[:50], output_path=traj_out)
    logger.info("Wrote recomputed trajectories to %s", traj_out)

    # write a short diagnostic JSON (convert numpy arrays to lists)
    import json
    import numpy as _np

    def _to_serializable(o):
        # json.dump calls this only for objects it cannot serialize natively.
        if isinstance(o, _np.ndarray):
            return o.tolist()
        if isinstance(o, (_np.floating, _np.integer)):
            return float(o)
        raise TypeError(f"Object of type {type(o)} is not JSON serializable")

    diag = {"windows": window_ids, "axes": axes}
    with open(
        os.path.join(args.out, "recompute_diag.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(diag, f, indent=2, default=_to_serializable)

    logger.info("Recompute complete; outputs in %s and DB copy at %s", args.out, dst)
    return 0
|
if __name__ == "__main__":
|
|
raise SystemExit(main())
|
|
|