You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/scripts/svd_diagnostics.py

214 lines
6.5 KiB

"""SVD and PCA diagnostics for the political compass pipeline.
Produces a small text report and JSON summary in the outputs/ directory.
Usage:
uv run python3 scripts/svd_diagnostics.py --db data/motions.db --out outputs
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import sys
from statistics import mean
from typing import Dict, List, Optional, Tuple
# Repository root: the parent of the directory containing this script. It is
# put at the front of sys.path so that the `analysis.*` imports performed in
# `main` resolve when the script is executed directly from a checkout.
ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if ROOT not in sys.path:
    sys.path[:0] = [ROOT]

# Script-wide logging: INFO and above, timestamped, on the root logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("svd_diagnostics")
def find_by_substring(names: List[str], query: str) -> List[str]:
    """Return the entries of *names* containing *query*, ignoring case.

    The input order is preserved. An empty *query* matches every name.
    """
    needle = query.lower()
    return list(filter(lambda candidate: needle in candidate.lower(), names))
def _value_range(values: List[float]) -> Tuple[float, float]:
    """Return ``(min, max)`` of *values*, or ``(0.0, 0.0)`` for an empty list."""
    if not values:
        return 0.0, 0.0
    return min(values), max(values)


def _population_std(values: List[float]) -> float:
    """Return the population standard deviation (ddof=0) of *values*.

    This matches ``numpy.std``'s default, but it needs no third-party import.
    It therefore cannot silently fall back to a fake 0.0 when numpy is
    unavailable. An empty list yields 0.0 instead of NaN, which ``json.dump``
    would otherwise write as the non-standard ``NaN`` token.
    """
    from statistics import pstdev

    if not values:
        return 0.0
    # Reduce any numpy scalar types to plain floats before handing them to
    # the statistics module.
    return pstdev([float(v) for v in values])


def _to_json_value(value: object) -> object:
    """Return *value* in a form ``json.dump`` can serialize.

    Objects exposing ``tolist()`` (e.g. numpy arrays and scalars) are
    converted to plain Python lists or numbers. Anything else is returned
    unchanged.
    """
    to_list = getattr(value, "tolist", None)
    return to_list() if callable(to_list) else value


def _party_centroids(
    pos: Dict[str, Tuple[float, float]], party_map: Dict[str, str]
) -> Dict[str, Tuple[float, float]]:
    """Average the 2-D positions of each party's members.

    *party_map* is looked up with ``.get(entity_name)``. Entities without a
    truthy party are skipped. Centroid coordinates are plain ``float`` so the
    result is JSON-serializable. Parties appear in first-seen order.
    """
    grouped: Dict[str, List[Tuple[float, float]]] = {}
    for mp, coord in pos.items():
        party = party_map.get(mp)
        if party:
            grouped.setdefault(party, []).append(coord)
    return {
        party: (
            float(mean(float(c[0]) for c in coords)),
            float(mean(float(c[1]) for c in coords)),
        )
        for party, coords in grouped.items()
    }


def main(argv: Optional[List[str]] = None) -> int:
    """Run the PCA diagnostics and write a text report plus a JSON summary.

    Command-line options:
      --db    path to the motions database (default ``data/motions.db``)
      --out   output directory, created if missing (default ``outputs``)
      --find  case-insensitive name substrings to look up in the latest
              window (default ``Ouwehand Mona``)
      --top   number of extreme entities listed per direction (default 8)

    Writes ``svd_diagnostics.txt`` (the report) and ``svd_diagnostics.json``
    (a summary) into ``--out``. Every report line is also logged at INFO.

    Returns:
        0 on success, or 1 when no SVD windows are found in the database.

    Raises:
        Exception: any error raised while importing the ``analysis`` modules
            is logged and re-raised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--db", default="data/motions.db")
    parser.add_argument("--out", default="outputs")
    parser.add_argument(
        "--find",
        nargs="+",
        default=["Ouwehand", "Mona"],
        help="case-insensitive name substrings to look up in the latest window",
    )
    parser.add_argument(
        "--top",
        type=int,
        default=8,
        help="number of extreme entities to list per direction",
    )
    args = parser.parse_args(argv)
    os.makedirs(args.out, exist_ok=True)
    top = max(args.top, 0)

    try:
        from analysis import trajectory as traj
        from analysis.political_axis import compute_2d_axes
        from analysis.visualize import _load_party_map
    except Exception as e:  # pragma: no cover - runtime
        logger.exception("Could not import analysis modules: %s", e)
        raise

    window_ids = traj._load_window_ids(args.db)
    if not window_ids:
        logger.error("No SVD windows found in DB %s", args.db)
        return 1
    logger.info("Found windows: %s", window_ids)

    # Global PCA axes with and without the residual option, for comparison.
    positions_residual, axes_residual = compute_2d_axes(
        args.db,
        window_ids=window_ids,
        method="pca",
        normalize_vectors=True,
        pca_residual=True,
    )
    _, axes_plain = compute_2d_axes(
        args.db,
        window_ids=window_ids,
        method="pca",
        normalize_vectors=True,
        pca_residual=False,
    )

    report: List[str] = []

    def add(line: str) -> None:
        # Keep the line for the text report and echo it to the log.
        report.append(line)
        logger.info(line)

    add("PCA diagnostics report")
    add(f"DB: {args.db}")
    add(f"Windows: {window_ids}")
    add("")
    evr_res = axes_residual.get("explained_variance_ratio") if axes_residual else None
    evr_plain = axes_plain.get("explained_variance_ratio") if axes_plain else None
    add(f"Residual PCA EVR: {evr_res}")
    add(f"Plain PCA EVR: {evr_plain}")

    # NOTE(review): "latest" is the largest id under the ids' natural
    # ordering; this assumes window ids sort chronologically. Confirm.
    latest = max(window_ids)
    add("")
    add(f"Inspecting latest window: {latest}")
    # From the indexing below, the positions appear to map
    # window id -> {entity name: (x, y)}.
    pos = (positions_residual or {}).get(latest) or {}
    names = list(pos)
    xs = [coord[0] for coord in pos.values()]
    ys = [coord[1] for coord in pos.values()]

    add(f"Entities in latest window: {len(names)}")
    add(f"X range (left-right): {_value_range(xs)}")
    add(f"Y range (prog-cons): {_value_range(ys)}")
    x_std = _population_std(xs)
    y_std = _population_std(ys)
    add(
        f"Std dev X: {x_std:.6f}, Std dev Y: {y_std:.6f} (ratio Y/X = {y_std / (x_std + 1e-12):.3f})"
    )

    def add_entities(title: str, items: List[Tuple[str, Tuple[float, float]]]) -> None:
        # One titled section of "name x=.. y=.." lines, preceded by a blank line.
        add("")
        add(title)
        for name, (x, y) in items:
            add(f" {name:40s} x={x:.4f} y={y:.4f}")

    by_x = sorted(pos.items(), key=lambda kv: kv[1][0])
    by_y = sorted(pos.items(), key=lambda kv: kv[1][1])
    # Explicit start index so that top=0 yields an empty tail (a slice of
    # [-0:] would return everything).
    add_entities("Left-most (by X):", by_x[:top])
    add_entities("Right-most (by X):", by_x[max(len(by_x) - top, 0):])
    add_entities("Top (conservative) (by Y):", by_y[max(len(by_y) - top, 0):])
    add_entities("Bottom (progressive) (by Y):", by_y[:top])

    # Look up specific MPs requested on the command line.
    add("")
    for query in args.find:
        matches = find_by_substring(names, query)
        add(f"Matches for '{query}': {matches}")
        for match in matches:
            x, y = pos[match]
            add(f" {match} -> x={x:.4f} y={y:.4f}")

    party_centroids = _party_centroids(pos, _load_party_map(args.db))
    add("")
    add(f"Computed {len(party_centroids)} party centroids (from mp_metadata majority)")
    add("Party centroids left→right:")
    for party, (x, y) in sorted(party_centroids.items(), key=lambda kv: kv[1][0]):
        add(f" {party:20s} x={x:.4f} y={y:.4f}")
    add("")
    add("Party centroids prog→cons:")
    for party, (x, y) in sorted(party_centroids.items(), key=lambda kv: kv[1][1]):
        add(f" {party:20s} x={x:.4f} y={y:.4f}")

    report_path = os.path.join(args.out, "svd_diagnostics.txt")
    summary_path = os.path.join(args.out, "svd_diagnostics.json")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(report) + "\n")

    # numpy arrays or scalars are not JSON-serializable, so convert them first.
    summary = {
        "db": args.db,
        "windows": _to_json_value(window_ids),
        "latest_window": _to_json_value(latest),
        "evr_residual": _to_json_value(evr_res),
        "evr_plain": _to_json_value(evr_plain),
        "n_entities_latest": len(names),
        "x_std": x_std,
        "y_std": y_std,
        "party_centroids": party_centroids,
    }
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    logger.info("Diagnostic report written to %s and %s", report_path, summary_path)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())