"""Analysis primitives for agent operation. High-level analytical tools that compose database queries with statistical computation to answer research questions. """ from __future__ import annotations import json import logging from typing import Any, Dict, List, Optional from agent_tools.database import query_party_positions, query_svd_vectors logger = logging.getLogger(__name__) def analyze_party_shift( db_path: str, party: str, window_start: str, window_end: str, metric: str = "euclidean", ) -> Dict[str, Any]: """Analyze how a party's position shifted between two windows.""" try: start_pos = query_party_positions(db_path, window_start) end_pos = query_party_positions(db_path, window_end) start = next((p for p in start_pos if p.get("party") == party), None) end = next((p for p in end_pos if p.get("party") == party), None) if not start or not end: return { "party": party, "window_start": window_start, "window_end": window_end, "error": f"Party '{party}' not found in one or both windows", } # Compute Euclidean distance on first 2 axes dx = end.get("axis_1", 0.0) - start.get("axis_1", 0.0) dy = end.get("axis_2", 0.0) - start.get("axis_2", 0.0) shift = (dx ** 2 + dy ** 2) ** 0.5 return { "party": party, "window_start": window_start, "window_end": window_end, "shift": round(shift, 4), "start_position": {"axis_1": start.get("axis_1"), "axis_2": start.get("axis_2")}, "end_position": {"axis_1": end.get("axis_1"), "axis_2": end.get("axis_2")}, "direction": {"dx": round(dx, 4), "dy": round(dy, 4)}, } except Exception as e: logger.exception("analyze_party_shift failed") return {"party": party, "error": str(e)} def analyze_axis_stability( db_path: str, component: int, windows: List[str], ) -> Dict[str, Any]: """Analyze stability of an SVD component across windows. Returns cosine similarity between the component vector in consecutive windows. """ try: vectors_by_window = {} for window in windows: rows = query_svd_vectors(db_path, window, entity_type="motion") if rows: vectors_by_window[window] = rows if len(vectors_by_window) < 2: return { "component": component, "windows": windows, "error": "Need at least 2 windows with SVD vectors", } # Extract component scores for each window # (component is 1-indexed in user-facing code, 0-indexed internally) idx = component - 1 window_scores = {} for window, rows in vectors_by_window.items(): scores = [] for row in rows: vec = row.get("vector") if isinstance(vec, str): vec = json.loads(vec) if isinstance(vec, list) and idx < len(vec): scores.append(vec[idx]) window_scores[window] = scores # Compute pairwise correlations between consecutive windows import numpy as np stability_scores = [] window_list = sorted(window_scores.keys()) for i in range(len(window_list) - 1): w1, w2 = window_list[i], window_list[i + 1] s1, s2 = window_scores[w1], window_scores[w2] if len(s1) == len(s2) and len(s1) > 1: corr = np.corrcoef(s1, s2)[0, 1] stability_scores.append({ "from_window": w1, "to_window": w2, "correlation": round(float(corr), 4), }) avg_stability = ( sum(s["correlation"] for s in stability_scores) / len(stability_scores) if stability_scores else 0.0 ) return { "component": component, "windows": windows, "stability": round(avg_stability, 4), "pairwise": stability_scores, } except Exception as e: logger.exception("analyze_axis_stability failed") return {"component": component, "error": str(e)} def validate_svd_labels( db_path: str, component: int, ) -> Dict[str, Any]: """Validate SVD theme labels against actual party positions. Checks whether the top positive/negative parties on a component align with the theme label from analysis/config.py. """ try: from analysis.config import SVD_THEMES theme = SVD_THEMES.get(component, {}) label = theme.get("label", "Unknown") description = theme.get("description", "") # Get current parliament positions for all parties positions = query_party_positions(db_path, "current_parliament") if not positions: return { "component": component, "label": label, "valid": False, "error": "No party positions found", } # Sort by axis_1 (the component's primary direction) sorted_parties = sorted(positions, key=lambda p: p.get("axis_1", 0.0)) negative_pole = sorted_parties[:3] if len(sorted_parties) >= 3 else sorted_parties[:1] positive_pole = sorted_parties[-3:] if len(sorted_parties) >= 3 else sorted_parties[-1:] return { "component": component, "label": label, "description": description, "valid": True, "negative_pole": [{"party": p["party"], "score": round(p.get("axis_1", 0.0), 4)} for p in negative_pole], "positive_pole": [{"party": p["party"], "score": round(p.get("axis_1", 0.0), 4)} for p in positive_pole], } except Exception as e: logger.exception("validate_svd_labels failed") return {"component": component, "valid": False, "error": str(e)}