feat: add motion-loading helpers to axis_classifier

main
Sven Geboers 1 month ago
parent 1e52a8a8cc
commit 96224be6ee
  1. 97
      analysis/axis_classifier.py

@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import re
import json
_logger = logging.getLogger(__name__)
@ -145,6 +146,102 @@ def _classify_from_titles(titles: List[str]) -> Tuple[Optional[str], float]:
return best_cats[0], confidence
def _load_motion_vectors(db_path: str, window_id: str) -> Dict[int, np.ndarray]:
"""Load SVD motion vectors for a given window from DuckDB.
Returns {motion_id: vector_array}. Returns {} on any error.
"""
try:
import duckdb
conn = duckdb.connect(db_path, read_only=True)
rows = conn.execute(
"SELECT entity_id, vector FROM svd_vectors "
"WHERE entity_type = 'motion' AND window_id = ?",
[window_id],
).fetchall()
conn.close()
result: Dict[int, np.ndarray] = {}
for entity_id, vector_raw in rows:
try:
mid = int(entity_id)
vec = np.array(json.loads(vector_raw), dtype=float)
result[mid] = vec
except Exception:
continue
return result
except Exception as exc:
_logger.debug("Failed to load motion vectors for window %s: %s", window_id, exc)
return {}
def _project_motions(
motion_vecs: Dict[int, np.ndarray],
x_axis: np.ndarray,
y_axis: np.ndarray,
global_mean: np.ndarray,
) -> Dict[int, Tuple[float, float]]:
"""Project motion vectors onto the PCA axes after centering by global_mean.
Returns {motion_id: (x_score, y_score)}.
"""
projections: Dict[int, Tuple[float, float]] = {}
for mid, vec in motion_vecs.items():
try:
centered = vec - global_mean
x_score = float(np.dot(centered, x_axis))
y_score = float(np.dot(centered, y_axis))
projections[mid] = (x_score, y_score)
except Exception:
continue
return projections
def _top_motion_ids(
projections: Dict[int, Tuple[float, float]],
axis: str,
n: int = 5,
) -> Dict[str, List[int]]:
"""Return the top-n motion IDs at each pole of the given axis.
axis: 'x' or 'y'
Returns {'+': [motion_ids], '-': [motion_ids]} (highest positive first,
most negative first in the '-' list).
"""
idx = 0 if axis == "x" else 1
sorted_ids = sorted(projections, key=lambda mid: projections[mid][idx])
neg_ids = sorted_ids[:n] # most negative
pos_ids = sorted_ids[-n:][::-1] # most positive
return {"+": pos_ids, "-": neg_ids}
def _fetch_motion_titles(
db_path: str,
motion_ids: List[int],
) -> Dict[int, Tuple[str, str]]:
"""Fetch (title, date) for a list of motion IDs from DuckDB.
Returns {motion_id: (title, date_str)}. Missing IDs are omitted.
Returns {} on any DB error.
"""
if not motion_ids:
return {}
try:
import duckdb
placeholders = ", ".join("?" * len(motion_ids))
conn = duckdb.connect(db_path, read_only=True)
rows = conn.execute(
f"SELECT id, title, date FROM motions WHERE id IN ({placeholders})",
motion_ids,
).fetchall()
conn.close()
return {int(row[0]): (str(row[1]), str(row[2])) for row in rows}
except Exception as exc:
_logger.debug("Failed to fetch motion titles: %s", exc)
return {}
def _load_ideology(csv_path: Path) -> Dict[str, Dict[str, float]]:
"""Load party ideology scores from CSV.

Loading…
Cancel
Save