diff --git a/analysis/axis_classifier.py b/analysis/axis_classifier.py index d71b932..4fd11c6 100644 --- a/analysis/axis_classifier.py +++ b/analysis/axis_classifier.py @@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Tuple import numpy as np import re +import json _logger = logging.getLogger(__name__) @@ -145,6 +146,102 @@ def _classify_from_titles(titles: List[str]) -> Tuple[Optional[str], float]: return best_cats[0], confidence +def _load_motion_vectors(db_path: str, window_id: str) -> Dict[int, np.ndarray]: + """Load SVD motion vectors for a given window from DuckDB. + + Returns {motion_id: vector_array}. Returns {} on any error. + """ + try: + import duckdb + + conn = duckdb.connect(db_path, read_only=True) + rows = conn.execute( + "SELECT entity_id, vector FROM svd_vectors " + "WHERE entity_type = 'motion' AND window_id = ?", + [window_id], + ).fetchall() + conn.close() + result: Dict[int, np.ndarray] = {} + for entity_id, vector_raw in rows: + try: + mid = int(entity_id) + vec = np.array(json.loads(vector_raw), dtype=float) + result[mid] = vec + except Exception: + continue + return result + except Exception as exc: + _logger.debug("Failed to load motion vectors for window %s: %s", window_id, exc) + return {} + + +def _project_motions( + motion_vecs: Dict[int, np.ndarray], + x_axis: np.ndarray, + y_axis: np.ndarray, + global_mean: np.ndarray, +) -> Dict[int, Tuple[float, float]]: + """Project motion vectors onto the PCA axes after centering by global_mean. + + Returns {motion_id: (x_score, y_score)}. + """ + projections: Dict[int, Tuple[float, float]] = {} + for mid, vec in motion_vecs.items(): + try: + centered = vec - global_mean + x_score = float(np.dot(centered, x_axis)) + y_score = float(np.dot(centered, y_axis)) + projections[mid] = (x_score, y_score) + except Exception: + continue + return projections + + +def _top_motion_ids( + projections: Dict[int, Tuple[float, float]], + axis: str, + n: int = 5, +) -> Dict[str, List[int]]: + """Return the top-n motion IDs at each pole of the given axis. + + axis: 'x' or 'y' + Returns {'+': [motion_ids], '-': [motion_ids]} (highest positive first, + most negative first in the '-' list). + """ + idx = 0 if axis == "x" else 1 + sorted_ids = sorted(projections, key=lambda mid: projections[mid][idx]) + neg_ids = sorted_ids[:n] # most negative + pos_ids = sorted_ids[-n:][::-1] # most positive + return {"+": pos_ids, "-": neg_ids} + + +def _fetch_motion_titles( + db_path: str, + motion_ids: List[int], +) -> Dict[int, Tuple[str, str]]: + """Fetch (title, date) for a list of motion IDs from DuckDB. + + Returns {motion_id: (title, date_str)}. Missing IDs are omitted. + Returns {} on any DB error. + """ + if not motion_ids: + return {} + try: + import duckdb + + placeholders = ", ".join("?" * len(motion_ids)) + conn = duckdb.connect(db_path, read_only=True) + rows = conn.execute( + f"SELECT id, title, date FROM motions WHERE id IN ({placeholders})", + motion_ids, + ).fetchall() + conn.close() + return {int(row[0]): (str(row[1]), str(row[2])) for row in rows} + except Exception as exc: + _logger.debug("Failed to fetch motion titles: %s", exc) + return {} + + def _load_ideology(csv_path: Path) -> Dict[str, Dict[str, float]]: """Load party ideology scores from CSV.