You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
285 lines
8.9 KiB
285 lines
8.9 KiB
"""visualize.py — Plotly interactive plots for parliamentary embeddings.
|
|
|
|
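Produces standalone HTML files that load plotly.js from the CDN
(``include_plotlyjs="cdn"``), so viewing them requires network access.

Functions:
    plot_umap_scatter      — 2D scatter of fused motion embeddings, coloured by cluster
    plot_mp_trajectory     — Line plot of MP drift across windows
    plot_political_axis    — Bar chart of MP scores on the ideological axis
    plot_political_compass — 2D compass scatter of MP positions for a single window
    plot_2d_trajectories   — MP trajectories across windows on the 2D compass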
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
from typing import Any
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _require_plotly():
|
|
try:
|
|
import plotly.graph_objects as go
|
|
import plotly.express as px
|
|
|
|
return go, px
|
|
except ImportError:
|
|
raise ImportError("plotly is not installed. Install it with: uv add plotly")
|
|
|
|
|
|
def plot_umap_scatter(
    motion_ids: List[int],
    coords: List[List[float]],
    labels: Optional[List[int]] = None,
    window_id: Optional[str] = None,
    output_path: str = "analysis_umap.html",
) -> str:
    """Write an interactive 2D scatter of UMAP-reduced fused embeddings.

    Args:
        motion_ids: Motion IDs; each is stringified and shown as the hover name.
        coords: One ``[x, y]`` pair per motion (only indices 0 and 1 are read).
        labels: Optional cluster label per motion. When omitted, every motion
            is assigned cluster ``0``. Labels are stringified before plotting.
        window_id: Optional window label appended to the plot title.
        output_path: Path of the HTML file to write. The file references
            plotly.js from the CDN rather than embedding it.

    Returns:
        ``output_path``, unchanged, once the file has been written.
    """
    _, express = _require_plotly()

    x_values = [point[0] for point in coords]
    y_values = [point[1] for point in coords]

    if labels is None:
        cluster_ids = [0] * len(motion_ids)
    else:
        cluster_ids = labels

    heading = "UMAP — fused motion embeddings"
    if window_id:
        heading += f" ({window_id})"

    figure = express.scatter(
        x=x_values,
        y=y_values,
        color=[str(cluster) for cluster in cluster_ids],
        hover_name=[str(motion_id) for motion_id in motion_ids],
        title=heading,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
    )
    figure.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("UMAP scatter written to %s", output_path)
    return output_path
|
|
|
|
|
|
def plot_mp_trajectory(
    trajectories: Dict[str, Dict],
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectory.html",
) -> str:
    """Write a line plot of each MP's cumulative drift across windows.

    Args:
        trajectories: Mapping of MP name to a record; this function reads the
            record's ``"windows"`` and ``"drift"`` entries. Documented upstream
            as the output of ``analysis.trajectory.compute_trajectories()``.
        mp_names: MPs to plot. Defaults to every key of ``trajectories``.
            Names that are not present in ``trajectories`` are skipped.
        output_path: Path of the HTML file to write (plotly.js is referenced
            from the CDN).

    Returns:
        ``output_path``, unchanged, once the file has been written.
    """
    graph_objects, _ = _require_plotly()

    selected = list(trajectories.keys()) if mp_names is None else mp_names

    figure = graph_objects.Figure()
    for name in (candidate for candidate in selected if candidate in trajectories):
        record = trajectories[name]
        window_labels = record["windows"]

        # The series starts at zero drift and then accumulates one step per
        # entry of record["drift"].
        cumulative = [0.0]
        cumulative.extend(np.cumsum(record["drift"]))

        trace = graph_objects.Scatter(
            # Never pass more x labels than there are cumulative points.
            x=window_labels[: len(cumulative)],
            y=cumulative,
            mode="lines+markers",
            name=name,
        )
        figure.add_trace(trace)

    figure.update_layout(
        title="MP Political Drift Over Time (Cumulative)",
        xaxis_title="Window",
        yaxis_title="Cumulative Drift",
    )
    figure.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Trajectory plot written to %s", output_path)
    return output_path
|
|
|
|
|
|
def plot_political_axis(
    scores: Dict[str, float],
    party_of: Optional[Dict[str, str]] = None,
    window_id: Optional[str] = None,
    n_top: int = 30,
    output_path: str = "analysis_political_axis.html",
) -> str:
    """Write a horizontal bar chart of MP scores on the ideological axis.

    Args:
        scores: ``{mp_name: score}`` mapping to plot.
        party_of: Optional ``{mp_name: party}`` mapping used for bar colour.
            MPs missing from it (or all MPs, when it is omitted or empty) are
            coloured as ``"Unknown"``.
        window_id: Optional window label appended to the chart title.
        n_top: When more than ``2 * n_top`` MPs are supplied, keep only the
            ``n_top`` lowest-scoring and ``n_top`` highest-scoring MPs.
            ``0`` keeps none. Must be non-negative.
        output_path: Path of the HTML file to write (plotly.js is referenced
            from the CDN).

    Returns:
        ``output_path``, unchanged, once the file has been written.

    Raises:
        ValueError: If ``n_top`` is negative.
    """
    if n_top < 0:
        raise ValueError(f"n_top must be non-negative, got {n_top}")

    _, px = _require_plotly()

    # Ascending by score, so the two ends of the list are the two extremes.
    ranked = sorted(scores.items(), key=lambda kv: kv[1])

    if len(ranked) > 2 * n_top:
        # Slice the tail from an explicit start index: ranked[-n_top:] would
        # return the *entire* list when n_top == 0.
        ranked = ranked[:n_top] + ranked[len(ranked) - n_top:]

    names = [name for name, _ in ranked]
    vals = [score for _, score in ranked]
    if party_of:
        colors = [party_of.get(name, "Unknown") for name in names]
    else:
        colors = ["Unknown"] * len(names)

    title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")

    fig = px.bar(
        x=vals,
        y=names,
        color=colors,
        orientation="h",
        title=title,
        labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
    )
    # Colouring by party splits the bars into one trace per party; this keeps
    # the y categories ordered by value across all traces.
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political axis chart written to %s", output_path)
    return output_path
|
|
|
|
|
|
def plot_political_compass(
    positions_by_window: Dict,
    window_id: str,
    party_of: Optional[Dict] = None,
    output_path: str = "analysis_compass.html",
) -> str:
    """Write a 2D political-compass scatter for a single window.

    Args:
        positions_by_window: ``{window_id: {mp_name: (x, y)}}``.
        window_id: Which window to plot. If it is absent (or has no MPs) an
            empty chart is written and a warning is logged.
        party_of: Optional ``{mp_name: party}`` mapping used for marker
            colour. When ``None``, a best-effort attempt is made to load it
            from the ``mp_metadata`` table of ``data/motions.db`` via duckdb;
            if that fails every MP is coloured as ``"Unknown"``.
        output_path: Path of the HTML file to write (plotly.js is referenced
            from the CDN).

    Returns:
        ``output_path``, unchanged, once the file has been written.
    """
    _, px = _require_plotly()

    pos = positions_by_window.get(window_id, {})
    if not pos:
        _logger.warning(
            "No positions found for window %s; compass will be empty", window_id
        )

    names = list(pos.keys())
    xs = [pos[name][0] for name in names]
    ys = [pos[name][1] for name in names]

    if party_of is None:
        party_of = _load_party_mapping_from_db()

    if party_of:
        parties = [party_of.get(name, "Unknown") for name in names]
    else:
        parties = ["Unknown"] * len(names)

    fig = px.scatter(
        x=xs,
        y=ys,
        color=parties,
        hover_name=names,
        title=f"Political Compass ({window_id})",
        labels={
            "x": "Left ← — → Right",
            "y": "Progressive ← — → Conservative",
            "color": "Party",
        },
    )
    fig.update_traces(marker=dict(size=8, opacity=0.8))
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political compass written to %s", output_path)
    return output_path


def _load_party_mapping_from_db(
    db_path: str = "data/motions.db",
) -> Optional[Dict[str, str]]:
    """Best-effort load of ``{mp_name: party}`` from a DuckDB database.

    Reads the ``mp_name`` and ``party`` columns of the ``mp_metadata`` table.
    Any failure (duckdb missing, file missing, table missing, ...) is logged
    at debug level and reported as ``None`` rather than raised, because the
    mapping is only used for optional colour-coding.

    Args:
        db_path: Path to the DuckDB database file, opened read-only.

    Returns:
        The mapping, or ``None`` if it could not be loaded.
    """
    try:
        import duckdb  # type: ignore
    except ImportError:
        _logger.debug("duckdb not installed; proceeding without party mapping")
        return None

    try:
        # Connect *before* entering the try/finally so that a failed connect
        # never reaches close() on an unbound name.
        conn = duckdb.connect(database=db_path, read_only=True)
        try:
            # fetchall() returns plain tuples and, unlike fetchdf(), does not
            # require pandas to be installed.
            rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
        finally:
            conn.close()
    except Exception as exc:  # deliberate best-effort: never fail the plot
        _logger.debug("Could not load party mapping from %s: %s", db_path, exc)
        return None

    mapping = {mp_name: party for mp_name, party in rows}
    _logger.info("Loaded party mapping for %d MPs from %s", len(mapping), db_path)
    return mapping
|
|
|
|
|
|
def plot_2d_trajectories(
    positions_by_window: Dict,
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectories_compass.html",
) -> str:
    """Write a plot of MP trajectories across windows on the 2D compass.

    Each MP becomes one line+marker trace connecting its positions in window
    order, where window order is the iteration order of
    ``positions_by_window``.

    Args:
        positions_by_window: ``{window_id: {mp_name: (x, y)}}``.
        mp_names: MPs to plot. Defaults to every MP found in any window.
            Names with no recorded position are skipped.
        output_path: Path of the HTML file to write (plotly.js is referenced
            from the CDN).

    Returns:
        ``output_path``, unchanged, once the file has been written.
    """
    go, _ = _require_plotly()

    # Group (window_id, (x, y)) pairs per MP. Windows are visited in the
    # mapping's own order, so every per-MP list is already in window order
    # and needs no further sorting.
    mp_coords: Dict[str, List[Tuple[str, Tuple[float, float]]]] = {}
    for wid, positions in positions_by_window.items():
        for mp, coord in positions.items():
            mp_coords.setdefault(mp, []).append((wid, coord))

    if mp_names is None:
        mp_names = list(mp_coords.keys())

    fig = go.Figure()
    for mp in mp_names:
        points = mp_coords.get(mp)
        if points is None:
            continue
        fig.add_trace(
            go.Scatter(
                x=[xy[0] for _, xy in points],
                y=[xy[1] for _, xy in points],
                mode="lines+markers",
                name=mp,
                # Hover shows only this text: the MP name and the window.
                text=[f"{mp} ({wid})" for wid, _ in points],
                hoverinfo="text",
            )
        )

    fig.update_layout(
        title="MP Trajectories on Political Compass",
        xaxis_title="Left ← — → Right",
        yaxis_title="Progressive ← — → Conservative",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("2D trajectories compass written to %s", output_path)
    return output_path
|
|
|