"""Compare PCA axes with and without party-level vectors present.
|
|
|
|
Generates diagnostics and HTML plots (when plotly available) into outputs/.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from typing import Dict, List
|
|
|
|
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
if ROOT not in sys.path:
|
|
sys.path.insert(0, ROOT)
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
|
logger = logging.getLogger("compare_svd_exclude_parties")
|
|
|
|
|
|
def main(argv: List[str] | None = None) -> int:
    """Run the party-excluded PCA comparison and write diagnostics.

    Loads per-window SVD vectors from the DuckDB database, removes entities
    whose id matches a party name, Procrustes-aligns the remaining windows,
    recomputes the PCA axes, and writes positions (JSON) plus optional HTML
    plots into ``--out``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success, 1 when no SVD windows exist, 2 when filtering leaves
        no vectors, 3 when the SVD itself fails.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--db", default="data/motions.db")
    p.add_argument("--out", default="outputs")
    args = p.parse_args(argv)

    os.makedirs(args.out, exist_ok=True)

    # Heavy/optional dependencies are imported lazily so the module itself
    # stays importable; failures are logged and re-raised.
    try:
        from analysis import trajectory as traj
        from analysis.visualize import (
            _load_party_map,
            plot_political_compass,
            plot_2d_trajectories,
        )
        import numpy as np
    except Exception as e:
        logger.exception("Failed to import analysis modules: %s", e)
        raise

    window_ids = traj._load_window_ids(args.db)
    if not window_ids:
        logger.error("No SVD windows found")
        return 1
    latest = sorted(window_ids)[-1]

    # Build the set of party names from mp_metadata; the connection is closed
    # even on failure, and close errors are deliberately ignored.
    conn = None
    try:
        import duckdb

        conn = duckdb.connect(args.db)
        rows = conn.execute(
            "SELECT DISTINCT party FROM mp_metadata WHERE party IS NOT NULL"
        ).fetchall()
        party_names = {r[0] for r in rows if r[0]}
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass

    # Group entities of the latest window by their (rounded) vector so exact
    # duplicates surface in the diagnostics below.
    raw = traj._load_mp_vectors_for_window(args.db, latest)
    groups: Dict[str, List[str]] = {}
    for ent, vec in raw.items():
        key = tuple(round(float(x), 8) for x in vec.tolist())
        groups.setdefault(str(key), []).append(ent)

    group_list = sorted(groups.items(), key=lambda kv: len(kv[1]), reverse=True)

    top_groups = [(len(v), v[:8]) for k, v in group_list[:20]]
    logger.info("Top duplicate groups (count, sample entities): %s", top_groups)

    # Entities whose id coincides with a party name (to be excluded).
    party_entities = [ent for ent in raw.keys() if ent in party_names]
    logger.info(
        "Found %d party-like entities in svd_vectors for %s",
        len(party_entities),
        latest,
    )

    # Load every window's vectors, then drop party-level entity ids.
    raw_window_vecs = {
        wid: traj._load_mp_vectors_for_window(args.db, wid) for wid in window_ids
    }
    filtered_window_vecs = {
        wid: {ent: vec for ent, vec in d.items() if ent not in party_names}
        for wid, d in raw_window_vecs.items()
    }

    aligned_filtered = traj._procrustes_align_windows(filtered_window_vecs)

    # Stack unit-normalised vectors (norm guard avoids division by ~0) and
    # remember which (window, entity) each row came from.
    all_vecs = []
    entity_index = []
    for wid, d in aligned_filtered.items():
        for ent, v in d.items():
            n = np.linalg.norm(v)
            all_vecs.append(v / n if n > 1e-10 else v)
            entity_index.append((wid, ent))

    if not all_vecs:
        logger.error("No vectors left after excluding parties — aborting")
        return 2

    # PCA via SVD of the mean-centred matrix.
    M = np.vstack(all_vecs)
    Mc = M - M.mean(axis=0)
    try:
        U, s, Vt = np.linalg.svd(Mc, full_matrices=False)
    except Exception:
        logger.exception("SVD failed on filtered data")
        return 3

    sv2 = s**2
    evr = sv2 / (sv2.sum() + 1e-20)  # explained-variance ratio
    logger.info("Filtered PCA EVR top2: %s", evr[:2].tolist())

    comp1 = Vt[0]
    comp1_hat = comp1 / (np.linalg.norm(comp1) + 1e-12)
    # Degenerate 1-component case: use a zero second axis.
    comp2 = Vt[1] if Vt.shape[0] > 1 else np.zeros_like(comp1)
    comp2_hat = comp2 / (np.linalg.norm(comp2) + 1e-12)

    # Project the latest window's entities onto the two principal axes.
    filtered_positions = {}
    global_mean = M.mean(axis=0)
    for (wid, ent), vec in zip(entity_index, M):
        if wid != latest:
            continue
        v_centered = vec - global_mean
        x = float(np.dot(v_centered, comp1_hat))
        y = float(np.dot(v_centered, comp2_hat))
        filtered_positions[ent] = (x, y)

    # Persist positions + top-2 EVR as JSON.
    out_json = os.path.join(args.out, "svd_filtered_positions.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(
            {
                "latest": latest,
                "positions": filtered_positions,
                "evr": evr[:2].tolist(),
            },
            f,
            indent=2,
        )
    logger.info("Wrote filtered positions to %s", out_json)

    # Plot generation is best-effort: any failure (e.g. plotly missing) is
    # logged and skipped rather than failing the run.
    try:
        party_map = _load_party_map(args.db)
        # positions_by_window format expected by plot functions — include only latest
        positions_by_window = {latest: filtered_positions}
        pcomp_out = os.path.join(args.out, f"political_compass_filtered_{latest}.html")
        plot_political_compass(
            positions_by_window,
            window_id=latest,
            party_of=party_map,
            axis_def={"method": "pca", "explained_variance_ratio": evr[:2]},
            output_path=pcomp_out,
        )
        logger.info("Wrote filtered compass to %s", pcomp_out)

        # simple trajectory plotting for filtered set — top movers by count
        traj_out = os.path.join(args.out, f"trajectories_filtered_{latest}.html")
        # Build simple per-MP coords across windows for filtered set
        mp_coords = {}
        for wid in window_ids:
            for ent, coord in aligned_filtered.get(wid, {}).items():
                if ent not in mp_coords:
                    mp_coords[ent] = []
                mp_coords[ent].append((wid, tuple(coord.tolist())))
        # pick MPs with at least 2 windows
        names = [n for n, v in mp_coords.items() if len(v) >= 2]
        plot_2d_trajectories(
            {
                wid: {
                    n: mp_coords[n][i][1]
                    for n in names
                    for i, (w, _) in enumerate(mp_coords[n])
                    if w == wid
                }
                for wid in window_ids
            },
            mp_names=names[:50],
            output_path=traj_out,
        )
        logger.info("Wrote filtered trajectories to %s", traj_out)
    except Exception:
        logger.exception("Plotting filtered results failed — plots skipped")

    # console summary
    print("Top duplicate groups (count, sample):")
    for k, v in group_list[:20]:
        print(len(v), v[:6])

    return 0
if __name__ == "__main__":
|
|
raise SystemExit(main())
|
|
|