You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
5.9 KiB
170 lines
5.9 KiB
"""Analysis primitives for agent operation.
|
|
|
|
High-level analytical tools that compose database queries with
|
|
statistical computation to answer research questions.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from agent_tools.database import query_party_positions, query_svd_vectors
|
|
|
|
# Module-level logger; each analysis function reports unexpected failures
# through logger.exception before returning an error dict.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Distance functions selectable through ``analyze_party_shift(metric=...)``.
# Each takes the per-axis deltas (dx, dy) and returns a non-negative float.
_SHIFT_METRICS = {
    "euclidean": lambda dx, dy: (dx ** 2 + dy ** 2) ** 0.5,
    "manhattan": lambda dx, dy: abs(dx) + abs(dy),
    "chebyshev": lambda dx, dy: max(abs(dx), abs(dy)),
}


def analyze_party_shift(
    db_path: str,
    party: str,
    window_start: str,
    window_end: str,
    metric: str = "euclidean",
) -> Dict[str, Any]:
    """Analyze how a party's position shifted between two windows.

    The party's row is looked up in the positions returned by
    ``query_party_positions`` for each window, and the shift is measured on
    the first two axes (``axis_1``, ``axis_2``). A missing axis key is
    treated as 0.0.

    Args:
        db_path: Database path passed through to ``query_party_positions``.
        party: Party identifier, matched against each row's ``"party"`` key.
        window_start: Window identifier for the starting position.
        window_end: Window identifier for the ending position.
        metric: Distance used for ``"shift"``: ``"euclidean"`` (default),
            ``"manhattan"`` or ``"chebyshev"``.

    Returns:
        On success: ``party``, ``window_start``, ``window_end``, ``shift``
        (rounded to 4 decimals), ``start_position``, ``end_position``,
        ``direction`` (``dx``/``dy``) and ``metric``.
        On failure: ``party``, ``window_start``, ``window_end`` and
        ``error``. The function does not raise.
    """
    # Echoed in every result so callers can tell which request a result belongs to.
    context = {
        "party": party,
        "window_start": window_start,
        "window_end": window_end,
    }

    # Reject unknown metrics before touching the database instead of silently
    # computing a Euclidean distance for them.
    distance = _SHIFT_METRICS.get(metric)
    if distance is None:
        supported = ", ".join(sorted(_SHIFT_METRICS))
        return {
            **context,
            "error": f"Unsupported metric '{metric}'; expected one of: {supported}",
        }

    try:
        start_pos = query_party_positions(db_path, window_start)
        end_pos = query_party_positions(db_path, window_end)

        start = next((p for p in start_pos if p.get("party") == party), None)
        end = next((p for p in end_pos if p.get("party") == party), None)

        if not start or not end:
            return {
                **context,
                "error": f"Party '{party}' not found in one or both windows",
            }

        # Only the first 2 axes are compared; a missing axis key counts as 0.0.
        dx = end.get("axis_1", 0.0) - start.get("axis_1", 0.0)
        dy = end.get("axis_2", 0.0) - start.get("axis_2", 0.0)
        shift = distance(dx, dy)

        return {
            **context,
            "shift": round(shift, 4),
            "start_position": {"axis_1": start.get("axis_1"), "axis_2": start.get("axis_2")},
            "end_position": {"axis_1": end.get("axis_1"), "axis_2": end.get("axis_2")},
            "direction": {"dx": round(dx, 4), "dy": round(dy, 4)},
            "metric": metric,
        }

    except Exception as e:
        logger.exception("analyze_party_shift failed")
        return {**context, "error": str(e)}
|
|
|
|
|
|
def _extract_component_scores(rows: List[Dict[str, Any]], idx: int) -> List[float]:
    """Return entry ``idx`` of each row's ``"vector"``, preserving row order.

    A ``"vector"`` stored as a JSON string is decoded first. Rows whose vector
    is not a list, or is too short to contain ``idx``, are skipped.
    """
    scores = []
    for row in rows:
        vec = row.get("vector")
        if isinstance(vec, str):
            vec = json.loads(vec)
        if isinstance(vec, list) and idx < len(vec):
            scores.append(vec[idx])
    return scores


def analyze_axis_stability(
    db_path: str,
    component: int,
    windows: List[str],
) -> Dict[str, Any]:
    """Analyze stability of an SVD component across windows.

    For each window, the ``component``-th entry (1-indexed) of every motion
    vector returned by ``query_svd_vectors`` is collected. Windows are then
    visited in sorted order of their identifiers, and each consecutive pair
    is scored with the Pearson correlation (``numpy.corrcoef``) of its two
    score lists. A pair is skipped when the lists differ in length, have
    fewer than 2 entries, or give an undefined (NaN) correlation, e.g.
    because one list is constant.

    Args:
        db_path: Database path passed through to ``query_svd_vectors``.
        component: 1-indexed component number; must be >= 1.
        windows: Window identifiers to compare. Windows without motion
            vectors are ignored.

    Returns:
        ``component``, ``windows``, ``stability`` (mean of the pairwise
        correlations rounded to 4 decimals; 0.0 if no pair was scored) and
        ``pairwise`` (list of ``from_window``/``to_window``/``correlation``).
        On failure: ``component``, ``windows`` and ``error``. The function
        does not raise.
    """
    try:
        # component is 1-indexed for callers. Without this guard, 0 or a
        # negative value becomes a negative index and silently reads entries
        # from the end of each vector.
        if component < 1:
            return {
                "component": component,
                "windows": windows,
                "error": f"component must be >= 1 (1-indexed), got {component}",
            }

        vectors_by_window = {}
        for window in windows:
            rows = query_svd_vectors(db_path, window, entity_type="motion")
            if rows:
                vectors_by_window[window] = rows

        if len(vectors_by_window) < 2:
            return {
                "component": component,
                "windows": windows,
                "error": "Need at least 2 windows with SVD vectors",
            }

        idx = component - 1
        window_scores = {
            window: _extract_component_scores(rows, idx)
            for window, rows in vectors_by_window.items()
        }

        import numpy as np

        # NOTE(review): windows are ordered lexicographically, not in the order
        # given by ``windows``; this is chronological only if identifiers sort so.
        # NOTE(review): scores are paired by position within each window's rows,
        # guarded only by a length check. That is meaningful only if
        # query_svd_vectors returns the same entities in the same order for every
        # window — confirm.
        # NOTE(review): SVD components are defined only up to sign, so a flipped
        # component shows up as correlation -1; consider abs() if that should
        # count as stable.
        window_list = sorted(window_scores)
        stability_scores = []
        for w1, w2 in zip(window_list, window_list[1:]):
            s1, s2 = window_scores[w1], window_scores[w2]
            if len(s1) != len(s2) or len(s1) < 2:
                continue
            # corrcoef divides by each series' standard deviation, so a constant
            # series yields NaN (with a RuntimeWarning). Silence the warning and
            # drop the pair so it cannot turn the average into NaN.
            with np.errstate(divide="ignore", invalid="ignore"):
                corr = float(np.corrcoef(s1, s2)[0, 1])
            if not np.isfinite(corr):
                continue
            stability_scores.append({
                "from_window": w1,
                "to_window": w2,
                "correlation": round(corr, 4),
            })

        avg_stability = (
            sum(s["correlation"] for s in stability_scores) / len(stability_scores)
            if stability_scores else 0.0
        )

        return {
            "component": component,
            "windows": windows,
            "stability": round(avg_stability, 4),
            "pairwise": stability_scores,
        }

    except Exception as e:
        logger.exception("analyze_axis_stability failed")
        return {"component": component, "windows": windows, "error": str(e)}
|
|
|
|
|
|
def validate_svd_labels(
    db_path: str,
    component: int,
    top_n: int = 3,
) -> Dict[str, Any]:
    """Report the parties at both poles of an SVD component next to its theme label.

    The theme ``label``/``description`` come from ``analysis.config.SVD_THEMES``
    (``"Unknown"``/``""`` when the component has no entry). Party positions for
    the ``"current_parliament"`` window are sorted by ``axis_<component>``, and
    the lowest/highest scoring parties form the negative/positive poles so the
    caller can judge whether the label fits.

    Args:
        db_path: Database path passed through to ``query_party_positions``.
        component: 1-indexed component number; must be >= 1.
        top_n: Maximum number of parties per pole (default 3). Capped at half
            the number of parties so the poles never overlap, but at least 1.

    Returns:
        ``component``, ``label``, ``description``, ``valid`` and
        ``negative_pole``/``positive_pole`` lists of ``{"party", "score"}``
        (scores rounded to 4 decimals). On failure, ``valid`` is False and
        ``error`` describes the problem. The function does not raise.
    """
    try:
        if component < 1:
            return {
                "component": component,
                "valid": False,
                "error": f"component must be >= 1 (1-indexed), got {component}",
            }

        from analysis.config import SVD_THEMES

        theme = SVD_THEMES.get(component, {})
        label = theme.get("label", "Unknown")
        description = theme.get("description", "")

        # Get current parliament positions for all parties
        positions = query_party_positions(db_path, "current_parliament")
        if not positions:
            return {
                "component": component,
                "label": label,
                "valid": False,
                "error": "No party positions found",
            }

        # Rank parties on this component's own axis; ranking every component
        # by axis_1 would report identical poles for all components.
        # NOTE(review): assumes axis_<k> holds the party score on component k —
        # confirm against query_party_positions.
        axis_key = f"axis_{component}"
        if not any(axis_key in p for p in positions):
            return {
                "component": component,
                "label": label,
                "valid": False,
                "error": f"No party positions have '{axis_key}'",
            }

        sorted_parties = sorted(positions, key=lambda p: p.get(axis_key, 0.0))
        # Cap each pole at half the parties so no party lands on both poles
        # (a fixed 3-per-pole split fully overlaps with 3 parties); a single
        # party still yields a one-party pole on each side.
        k = max(1, min(top_n, len(sorted_parties) // 2))
        negative_pole = sorted_parties[:k]
        positive_pole = sorted_parties[-k:]

        # NOTE(review): "valid" is always True on this path; nothing here
        # compares the poles against ``label``, so the caller must judge that.
        return {
            "component": component,
            "label": label,
            "description": description,
            "valid": True,
            "negative_pole": [
                {"party": p["party"], "score": round(p.get(axis_key, 0.0), 4)}
                for p in negative_pole
            ],
            "positive_pole": [
                {"party": p["party"], "score": round(p.get(axis_key, 0.0), 4)}
                for p in positive_pole
            ],
        }

    except Exception as e:
        logger.exception("validate_svd_labels failed")
        return {"component": component, "valid": False, "error": str(e)}
|
|
|