You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
motief/agent_tools/analysis.py

170 lines
5.9 KiB

"""Analysis primitives for agent operation.
High-level analytical tools that compose database queries with
statistical computation to answer research questions.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
from agent_tools.database import query_party_positions, query_svd_vectors
# Module-level logger; each analysis function below reports unexpected
# failures through logger.exception before returning an error dict.
logger = logging.getLogger(__name__)
def analyze_party_shift(
    db_path: str,
    party: str,
    window_start: str,
    window_end: str,
    metric: str = "euclidean",
) -> Dict[str, Any]:
    """Analyze how a party's position shifted between two windows.

    Positions for each window come from ``query_party_positions``. The party's
    row is matched on its ``"party"`` field. The shift is measured on the first
    two axes (``axis_1`` and ``axis_2``), and a missing axis key counts as 0.0.

    Args:
        db_path: Database path forwarded to ``query_party_positions``.
        party: Party name to look up in both windows.
        window_start: Window identifier for the starting position.
        window_end: Window identifier for the ending position.
        metric: Distance used for ``"shift"``: ``"euclidean"`` (default) or
            ``"manhattan"``.

    Returns:
        On success, a dict with these keys:

        - ``party``, ``window_start``, ``window_end`` and ``metric``.
        - ``shift``, rounded to 4 places.
        - ``start_position`` and ``end_position``, holding the raw
          ``axis_1``/``axis_2`` values.
        - ``direction``, holding ``dx``/``dy`` rounded to 4 places.

        On failure (unsupported metric, party absent from a window, or any
        exception), a dict carrying an ``"error"`` message.
    """
    # Distance functions over the (dx, dy) displacement on the first two axes.
    distances = {
        "euclidean": lambda dx, dy: (dx ** 2 + dy ** 2) ** 0.5,
        "manhattan": lambda dx, dy: abs(dx) + abs(dy),
    }
    try:
        distance = distances.get(metric)
        if distance is None:
            # Reject unknown metrics instead of silently falling back to a
            # Euclidean shift that the caller did not ask for.
            return {
                "party": party,
                "window_start": window_start,
                "window_end": window_end,
                "error": (
                    f"Unsupported metric '{metric}'; "
                    f"expected one of {sorted(distances)}"
                ),
            }

        start_pos = query_party_positions(db_path, window_start)
        end_pos = query_party_positions(db_path, window_end)
        start = next((p for p in start_pos if p.get("party") == party), None)
        end = next((p for p in end_pos if p.get("party") == party), None)
        if start is None or end is None:
            return {
                "party": party,
                "window_start": window_start,
                "window_end": window_end,
                "error": f"Party '{party}' not found in one or both windows",
            }

        dx = end.get("axis_1", 0.0) - start.get("axis_1", 0.0)
        dy = end.get("axis_2", 0.0) - start.get("axis_2", 0.0)
        shift = distance(dx, dy)
        return {
            "party": party,
            "window_start": window_start,
            "window_end": window_end,
            "metric": metric,
            "shift": round(shift, 4),
            "start_position": {"axis_1": start.get("axis_1"), "axis_2": start.get("axis_2")},
            "end_position": {"axis_1": end.get("axis_1"), "axis_2": end.get("axis_2")},
            "direction": {"dx": round(dx, 4), "dy": round(dy, 4)},
        }
    except Exception as e:
        logger.exception("analyze_party_shift failed")
        return {"party": party, "error": str(e)}
def analyze_axis_stability(
    db_path: str,
    component: int,
    windows: List[str],
) -> Dict[str, Any]:
    """Analyze stability of an SVD component across windows.

    For every window, the function collects the ``component``-th entry of each
    motion vector returned by ``query_svd_vectors``; vectors stored as JSON
    strings are decoded first. Score lists of consecutive windows, in sorted
    window-name order, are compared with the Pearson correlation coefficient
    (``numpy.corrcoef``), not cosine similarity. Scores are paired by row
    position, so a pair of windows is compared only when both yield the same
    number (at least 2) of scores.

    Args:
        db_path: Database path forwarded to ``query_svd_vectors``.
        component: SVD component number, 1-indexed.
        windows: Window identifiers to load. Windows without vectors are ignored.

    Returns:
        On success, a dict with these keys:

        - ``component`` and ``windows``.
        - ``stability``: mean of the computed pairwise correlations, or 0.0
          when none could be computed, rounded to 4 places.
        - ``pairwise``: ``from_window``/``to_window``/``correlation`` entries.
        - ``skipped``: ``from_window``/``to_window``/``reason`` entries for
          consecutive pairs that could not be compared.

        On failure, a dict carrying an ``"error"`` message.
    """
    try:
        # Components are 1-indexed. A value < 1 would become a negative list
        # index and silently read scores from the end of every vector.
        if component < 1:
            return {
                "component": component,
                "windows": windows,
                "error": "component must be >= 1 (components are 1-indexed)",
            }
        idx = component - 1

        # Extract the component score of every motion vector, per window.
        window_scores: Dict[str, List[float]] = {}
        for window in windows:
            rows = query_svd_vectors(db_path, window, entity_type="motion")
            if not rows:
                continue
            scores = []
            for row in rows:
                vec = row.get("vector")
                if isinstance(vec, str):
                    vec = json.loads(vec)
                if isinstance(vec, (list, tuple)) and idx < len(vec):
                    scores.append(vec[idx])
            window_scores[window] = scores

        if len(window_scores) < 2:
            return {
                "component": component,
                "windows": windows,
                "error": "Need at least 2 windows with SVD vectors",
            }

        import numpy as np

        stability_scores = []
        skipped = []
        ordered = sorted(window_scores)
        for w1, w2 in zip(ordered, ordered[1:]):
            s1, s2 = window_scores[w1], window_scores[w2]
            if len(s1) != len(s2) or len(s1) < 2:
                skipped.append({
                    "from_window": w1,
                    "to_window": w2,
                    "reason": "score lists differ in length or have fewer than 2 entries",
                })
                continue
            # A constant score list has zero variance; corrcoef then returns
            # NaN (with a RuntimeWarning), which would poison the average.
            with np.errstate(divide="ignore", invalid="ignore"):
                corr = float(np.corrcoef(s1, s2)[0, 1])
            if not np.isfinite(corr):
                skipped.append({
                    "from_window": w1,
                    "to_window": w2,
                    "reason": "correlation undefined (zero variance)",
                })
                continue
            stability_scores.append({
                "from_window": w1,
                "to_window": w2,
                "correlation": round(corr, 4),
            })

        avg_stability = (
            sum(s["correlation"] for s in stability_scores) / len(stability_scores)
            if stability_scores else 0.0
        )
        return {
            "component": component,
            "windows": windows,
            "stability": round(avg_stability, 4),
            "pairwise": stability_scores,
            "skipped": skipped,
        }
    except Exception as e:
        logger.exception("analyze_axis_stability failed")
        return {"component": component, "error": str(e)}
def validate_svd_labels(
    db_path: str,
    component: int,
) -> Dict[str, Any]:
    """Show which parties sit at each pole of an SVD component, next to its label.

    The component's theme ``label`` and ``description`` are read from
    ``analysis.config.SVD_THEMES``; they default to ``"Unknown"`` and ``""``
    when the component has no entry. Party positions for the
    ``"current_parliament"`` window are ranked by their ``axis_1`` score. The
    lowest- and highest-scoring parties are returned as the negative and
    positive poles so the label can be judged against them. No automatic
    agreement check is made: ``"valid"`` is True whenever positions were found.

    Args:
        db_path: Database path forwarded to ``query_party_positions``.
        component: SVD component number, used as the ``SVD_THEMES`` key.

    Returns:
        On success, a dict with these keys:

        - ``component``, ``label``, ``description`` and ``valid``.
        - ``negative_pole`` and ``positive_pole``: lists of
          ``{"party", "score"}`` entries, with the score rounded to 4 places.

        Each pole holds up to 3 parties but never more than half of them, so
        the poles overlap only when there is a single party. On failure,
        ``valid`` is False and ``"error"`` is set.
    """
    try:
        from analysis.config import SVD_THEMES

        theme = SVD_THEMES.get(component, {})
        label = theme.get("label", "Unknown")
        description = theme.get("description", "")

        # Get current parliament positions for all parties.
        positions = query_party_positions(db_path, "current_parliament")
        if not positions:
            return {
                "component": component,
                "label": label,
                "valid": False,
                "error": "No party positions found",
            }

        # NOTE(review): ranking always uses axis_1, whatever `component` is.
        # Confirm whether party positions carry a per-component axis.
        ranked = sorted(positions, key=lambda p: p.get("axis_1", 0.0))
        # Cap each pole at half the parties so no party lands in both poles.
        # A fixed size of 3 overlaps for 3-5 parties.
        pole_size = max(1, min(3, len(ranked) // 2))

        def pole_entries(parties):
            return [
                {"party": p["party"], "score": round(p.get("axis_1", 0.0), 4)}
                for p in parties
            ]

        return {
            "component": component,
            "label": label,
            "description": description,
            "valid": True,
            "negative_pole": pole_entries(ranked[:pole_size]),
            "positive_pole": pole_entries(ranked[-pole_size:]),
        }
    except Exception as e:
        logger.exception("validate_svd_labels failed")
        return {"component": component, "valid": False, "error": str(e)}