You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/analysis/visualize.py

163 lines
4.9 KiB

"""visualize.py — Plotly interactive plots for parliamentary embeddings.
Produces standalone HTML files; plotly.js is loaded from the CDN, so viewing them requires network access.
Functions:
plot_umap_scatter — 2D scatter of fused motion embeddings, coloured by cluster
plot_mp_trajectory — Line plot of MP drift across windows
plot_political_axis — Bar chart of MP scores on the ideological axis
"""
import logging
from typing import Dict, List, Optional
import numpy as np
# Module-level logger; each plot function logs the path of the HTML file it writes at INFO level.
_logger = logging.getLogger(__name__)
def _require_plotly():
try:
import plotly.graph_objects as go
import plotly.express as px
return go, px
except ImportError:
raise ImportError("plotly is not installed. Install it with: uv add plotly")
def plot_umap_scatter(
motion_ids: List[int],
coords: List[List[float]],
labels: Optional[List[int]] = None,
window_id: Optional[str] = None,
output_path: str = "analysis_umap.html",
) -> str:
"""Produce a 2D scatter plot of UMAP-reduced fused embeddings.
Args:
motion_ids: Motion IDs (used as hover labels)
coords: List of [x, y] coordinates
labels: Optional cluster labels (integer per motion)
window_id: Window label for the plot title
output_path: Where to write the self-contained HTML
Returns the output_path on success.
"""
go, px = _require_plotly()
xs = [c[0] for c in coords]
ys = [c[1] for c in coords]
color = labels if labels is not None else [0] * len(motion_ids)
title = f"UMAP — fused motion embeddings" + (f" ({window_id})" if window_id else "")
fig = px.scatter(
x=xs,
y=ys,
color=[str(c) for c in color],
hover_name=[str(mid) for mid in motion_ids],
title=title,
labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
)
fig.write_html(output_path, include_plotlyjs="cdn")
_logger.info("UMAP scatter written to %s", output_path)
return output_path
def plot_mp_trajectory(
trajectories: Dict[str, Dict],
mp_names: Optional[List[str]] = None,
output_path: str = "analysis_trajectory.html",
) -> str:
"""Line plot of MP drift across time windows.
Args:
trajectories: Output of analysis.trajectory.compute_trajectories()
mp_names: Subset of MPs to plot (default: all)
output_path: Output HTML file path
Returns the output_path on success.
"""
go, px = _require_plotly()
if mp_names is None:
mp_names = list(trajectories.keys())
fig = go.Figure()
for mp in mp_names:
if mp not in trajectories:
continue
data = trajectories[mp]
windows = data["windows"]
drifts_cumulative = [0.0] + list(np.cumsum(data["drift"]))
# Plot cumulative drift per window transition
x_labels = windows[: len(drifts_cumulative)]
fig.add_trace(
go.Scatter(
x=x_labels,
y=drifts_cumulative,
mode="lines+markers",
name=mp,
)
)
fig.update_layout(
title="MP Political Drift Over Time (Cumulative)",
xaxis_title="Window",
yaxis_title="Cumulative Drift",
)
fig.write_html(output_path, include_plotlyjs="cdn")
_logger.info("Trajectory plot written to %s", output_path)
return output_path
def plot_political_axis(
scores: Dict[str, float],
party_of: Optional[Dict[str, str]] = None,
window_id: Optional[str] = None,
n_top: int = 30,
output_path: str = "analysis_political_axis.html",
) -> str:
"""Horizontal bar chart of MP scores on the ideological axis.
Args:
scores: {mp_name: score} from political_axis module
party_of: Optional {mp_name: party} for colour-coding
window_id: Window label for the title
n_top: Show only the top/bottom n MPs by score
output_path: Output HTML path
Returns the output_path on success.
"""
go, px = _require_plotly()
# Sort by score
sorted_items = sorted(scores.items(), key=lambda kv: kv[1])
# Take n_top from each end if list is large
if len(sorted_items) > 2 * n_top:
sorted_items = sorted_items[:n_top] + sorted_items[-n_top:]
names = [item[0] for item in sorted_items]
vals = [item[1] for item in sorted_items]
colors = (
[party_of.get(n, "Unknown") for n in names]
if party_of
else ["Unknown"] * len(names)
)
title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")
fig = px.bar(
x=vals,
y=names,
color=colors,
orientation="h",
title=title,
labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
)
fig.update_layout(yaxis={"categoryorder": "total ascending"})
fig.write_html(output_path, include_plotlyjs="cdn")
_logger.info("Political axis chart written to %s", output_path)
return output_path