You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/analysis/visualize.py

434 lines
14 KiB

"""visualize.py — Plotly interactive plots for parliamentary embeddings.
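Produces standalone HTML files (plotly.js is referenced from the CDN, so
viewing them requires network access).
Functions:
    plot_umap_scatter — 2D scatter of fused motion embeddings, coloured by cluster
    plot_mp_trajectory — Line plot of MP drift across windows
    plot_political_axis — Bar chart of MP scores on the ideological axis
    plot_political_compass — 2D scatter of MP positions for a single window
    plot_2d_trajectories — MP trajectories across windows on the 2D compass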
"""
import logging
from typing import Dict, List, Optional, Tuple
import numpy as np
from typing import Any
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
def _require_plotly():
try:
import plotly.graph_objects as go
import plotly.express as px
return go, px
except ImportError:
raise ImportError("plotly is not installed. Install it with: uv add plotly")
def _load_party_map(db_path: str = "data/motions.db") -> Dict[str, str]:
    """Build a mapping of mp_name -> party from the DuckDB database.

    Two sources are combined:

    * ``mp_votes``: for each MP, the party that appears on most of their vote
      rows (ties are broken by party name so the result is deterministic).
    * ``mp_metadata``: an explicit party per MP; this takes precedence over
      the vote-based guess whenever both exist.

    Args:
        db_path: Path to the DuckDB database file.

    Returns:
        Dict of mp_name -> party. Empty when duckdb cannot be imported.

    Raises:
        Any duckdb error raised while connecting or querying (for example a
        missing file or table) propagates to the caller.
    """
    try:
        import duckdb
    except Exception:
        _logger.debug("duckdb not available when building party map")
        return {}

    # Read-only: this function only queries, and a read-write connect would
    # create an empty database file when db_path does not exist.
    conn = duckdb.connect(db_path, read_only=True)
    try:
        # metadata-based mapping
        meta_map: Dict[str, str] = dict(
            conn.execute(
                "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL"
            ).fetchall()
        )

        # majority-party heuristic from mp_votes; rows arrive ordered so the
        # first row seen for each MP is their most frequent party.
        vote_rows = conn.execute(
            """
            SELECT mp_name, party, COUNT(*) AS n
            FROM mp_votes
            WHERE party IS NOT NULL
            GROUP BY mp_name, party
            ORDER BY mp_name, n DESC, party
            """
        ).fetchall()
        maj_map: Dict[str, str] = {}
        for mp_name, party, _count in vote_rows:
            maj_map.setdefault(mp_name, party)

        # prefer metadata mapping when available
        merged = dict(maj_map)
        merged.update(meta_map)
        _logger.info(
            "Built party map: %d from mp_votes majority, %d from mp_metadata",
            len(maj_map),
            len(meta_map),
        )
        return merged
    finally:
        try:
            conn.close()
        except Exception:
            # Closing is best-effort; never let it mask the result or the
            # original error, but leave a trace for debugging.
            _logger.debug("Failed to close duckdb connection", exc_info=True)
def plot_umap_scatter(
    motion_ids: List[int],
    coords: List[List[float]],
    labels: Optional[List[int]] = None,
    window_id: Optional[str] = None,
    output_path: str = "analysis_umap.html",
) -> str:
    """Produce a 2D scatter plot of UMAP-reduced fused embeddings.

    Args:
        motion_ids: Motion IDs (used as hover labels).
        coords: List of [x, y] coordinates, one per motion.
        labels: Optional cluster labels (one per motion). When omitted every
            point is put in a single cluster "0".
        window_id: Window label appended to the plot title.
        output_path: Where to write the HTML file (plotly.js is referenced
            from the CDN rather than embedded).

    Returns:
        The output_path on success.

    Raises:
        ImportError: If plotly is not installed.
        ValueError: If ``coords`` or ``labels`` do not have one entry per
            motion id.
    """
    _, px = _require_plotly()

    # Fail early with a message that names the offending argument instead of
    # relying on plotly's generic length-mismatch error.
    n_points = len(motion_ids)
    if len(coords) != n_points:
        raise ValueError(
            f"coords has {len(coords)} entries but motion_ids has {n_points}"
        )
    if labels is not None and len(labels) != n_points:
        raise ValueError(
            f"labels has {len(labels)} entries but motion_ids has {n_points}"
        )

    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    clusters = labels if labels is not None else [0] * n_points
    title = "UMAP — fused motion embeddings" + (
        f" ({window_id})" if window_id else ""
    )

    fig = px.scatter(
        x=xs,
        y=ys,
        # Stringify so plotly treats clusters as categories, not a continuum.
        color=[str(c) for c in clusters],
        hover_name=[str(mid) for mid in motion_ids],
        title=title,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("UMAP scatter written to %s", output_path)
    return output_path
def plot_mp_trajectory(
    trajectories: Dict[str, Dict],
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectory.html",
) -> str:
    """Draw one cumulative-drift line per MP across time windows.

    Args:
        trajectories: Output of analysis.trajectory.compute_trajectories();
            each entry is read via its "windows" and "drift" keys.
        mp_names: Subset of MPs to plot (default: every key of trajectories).
            Names missing from trajectories are skipped.
        output_path: Output HTML file path.

    Returns:
        The output_path on success.
    """
    go, px = _require_plotly()

    selected = list(trajectories.keys()) if mp_names is None else mp_names

    fig = go.Figure()
    for name in selected:
        if name in trajectories:
            record = trajectories[name]
            window_labels = record["windows"]
            # Running total of drift, starting from zero at the first window.
            running_total = [0.0]
            running_total.extend(np.cumsum(record["drift"]))
            trace = go.Scatter(
                x=window_labels[: len(running_total)],
                y=running_total,
                mode="lines+markers",
                name=name,
            )
            fig.add_trace(trace)

    fig.update_layout(
        title="MP Political Drift Over Time (Cumulative)",
        xaxis_title="Window",
        yaxis_title="Cumulative Drift",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Trajectory plot written to %s", output_path)
    return output_path
def plot_political_axis(
    scores: Dict[str, float],
    party_of: Optional[Dict[str, str]] = None,
    window_id: Optional[str] = None,
    n_top: int = 30,
    output_path: str = "analysis_political_axis.html",
) -> str:
    """Horizontal bar chart of MP scores on the ideological axis.

    Args:
        scores: {mp_name: score} from political_axis module.
        party_of: Optional {mp_name: party} for colour-coding; MPs without an
            entry are coloured as "Unknown".
        window_id: Window label appended to the title.
        n_top: When there are more than ``2 * n_top`` MPs, show only the
            ``n_top`` lowest and ``n_top`` highest scores. A value of zero or
            less disables the truncation and shows every MP.
        output_path: Output HTML path.

    Returns:
        The output_path on success.
    """
    _, px = _require_plotly()

    # Sort by score
    ranked = sorted(scores.items(), key=lambda kv: kv[1])

    # Take n_top from each end if the list is large. The n_top > 0 guard
    # matters: with a negative value the slices below would overlap and
    # duplicate MPs instead of trimming the list.
    if n_top > 0 and len(ranked) > 2 * n_top:
        ranked = ranked[:n_top] + ranked[-n_top:]

    names = [name for name, _ in ranked]
    vals = [score for _, score in ranked]
    party_lookup = party_of if party_of else {}
    colors = [party_lookup.get(name, "Unknown") for name in names]

    title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")
    fig = px.bar(
        x=vals,
        y=names,
        color=colors,
        orientation="h",
        title=title,
        labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
    )
    # Bars are split into one trace per party; re-order the category axis by
    # value so the chart still reads as a ranking.
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political axis chart written to %s", output_path)
    return output_path
def _explained_variance_pair(
    axis_def: Optional[Dict],
) -> Optional[Tuple[float, float]]:
    """Return the first two explained-variance ratios in axis_def, or None.

    Accepts lists, tuples, or numpy arrays under the
    "explained_variance_ratio" key. Scalars and sequences with fewer than two
    entries yield None.
    """
    if not axis_def:
        return None
    evr = axis_def.get("explained_variance_ratio")
    if evr is None:
        return None
    try:
        # list() avoids ambiguous truth checks on numpy arrays.
        values = list(evr)
    except TypeError:
        # A scalar cannot describe two components.
        return None
    if len(values) < 2:
        return None
    return float(values[0]), float(values[1])


def plot_political_compass(
    positions_by_window: Dict,
    window_id: str,
    party_of: Optional[Dict] = None,
    axis_def: Optional[Dict] = None,
    y_scale: Optional[float] = None,
    output_path: str = "analysis_compass.html",
) -> str:
    """Plot 2D political compass scatter for a single window.

    Args:
        positions_by_window: {window_id: {mp_name: (x, y)}}.
        window_id: Which window to plot; an unknown id yields an empty plot.
        party_of: Optional mapping mp_name -> party for colouring. When None,
            a mapping is loaded on a best-effort basis from the mp_metadata
            table of data/motions.db; MPs without a party show as "Unknown".
        axis_def: Optional axis description. Its "explained_variance_ratio"
            entry drives automatic Y scaling, and when its "method" is "pca"
            the ratios are also added to the title.
        y_scale: Explicit multiplier for the Y coordinates. When given it is
            always applied and automatic scaling is skipped.
        output_path: HTML output path.

    Returns:
        The output_path on success.
    """
    _, px = _require_plotly()

    pos = positions_by_window.get(window_id, {})
    names = list(pos.keys())
    xs = [xy[0] for xy in pos.values()]
    ys = [xy[1] for xy in pos.values()]

    # If no party mapping provided, try to load from data/motions.db (duckdb).
    # This is deliberately best-effort: any failure only costs the colouring.
    if party_of is None:
        try:
            import duckdb  # type: ignore

            conn = duckdb.connect(database="data/motions.db", read_only=True)
            try:
                # fetchall() keeps this independent of pandas; NULL parties
                # are skipped so those MPs fall back to "Unknown".
                rows = conn.execute(
                    "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL"
                ).fetchall()
                party_of = dict(rows)
            finally:
                try:
                    conn.close()
                except Exception:
                    _logger.debug("Failed to close duckdb connection", exc_info=True)
            _logger.info(
                "Loaded party mapping for %d MPs from data/motions.db",
                len(party_of),
            )
        except ImportError:
            _logger.debug("duckdb not installed; proceeding without party mapping")
        except Exception as e:
            _logger.debug("Could not load party mapping from data/motions.db: %s", e)

    party_lookup = party_of if party_of else {}
    parties = [party_lookup.get(name, "Unknown") for name in names]

    # Y scaling: an explicit y_scale always wins (even without axis_def);
    # otherwise stretch Y when the second component explains far less
    # variance than the first, so the vertical spread stays visible.
    scaled_ys = ys
    if y_scale is not None:
        factor = float(y_scale)
        scaled_ys = [y * factor for y in ys]
    else:
        evr_pair = _explained_variance_pair(axis_def)
        if evr_pair is not None:
            evr1, evr2 = evr_pair
            if evr2 < 1e-6:
                # Second axis carries essentially no variance; don't amplify noise.
                scale_guess = 1.0
            else:
                # sqrt of the variance ratio, clamped to the range [1, 8].
                scale_guess = min(max(1.0, float(evr1 / (evr2 + 1e-9)) ** 0.5), 8.0)
            scaled_ys = [y * scale_guess for y in ys]
            _logger.info(
                "Auto-scaling Y by %.2f for visibility (evr1=%.3f evr2=%.3f)",
                scale_guess,
                evr1,
                evr2,
            )

    # mark unknowns differently: use descriptive labels so the legend doesn't
    # show numeric symbol values like "PVV, 0" when color and symbol combine.
    unknown_labels = [
        "Unknown" if party == "Unknown" else "Known" for party in parties
    ]

    fig = px.scatter(
        x=xs,
        y=scaled_ys,
        color=parties,
        symbol=unknown_labels,
        hover_name=names,
        title=f"Political Compass ({window_id})",
        labels={
            "x": "Left ← — → Right",
            "y": "Progressive ← — → Conservative",
            "color": "Party",
            "symbol": "Known?",
        },
    )
    fig.update_traces(marker=dict(size=8, opacity=0.85))

    # annotate explained variance if available
    if axis_def and axis_def.get("method") == "pca":
        evr_pair = _explained_variance_pair(axis_def)
        if evr_pair is not None:
            fig.update_layout(
                title=f"Political Compass ({window_id}) — PCA EVR PC1={evr_pair[0] * 100:.1f}%, PC2={evr_pair[1] * 100:.1f}%"
            )

    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political compass written to %s", output_path)
    return output_path
def plot_2d_trajectories(
    positions_by_window: Dict,
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectories_compass.html",
) -> str:
    """Plot MP trajectories across windows on the 2D compass.

    Windows are taken in the iteration order of positions_by_window, so the
    caller controls the time ordering through the order of that mapping.

    Args:
        positions_by_window: {window_id: {mp_name: (x, y)}}.
        mp_names: List of MPs to plot (default: all found in positions).
            Names that never appear in any window are skipped.
        output_path: Output HTML path.

    Returns:
        The output_path on success.
    """
    go, _ = _require_plotly()

    # mp_coords maps mp_name -> list of (window_id, (x, y)). Because windows
    # are visited in order, each MP's list is already time-ordered and needs
    # no further sorting.
    mp_coords: Dict[str, List[Tuple[str, Tuple[float, float]]]] = {}
    for wid, positions in positions_by_window.items():
        for mp, coord in positions.items():
            mp_coords.setdefault(mp, []).append((wid, coord))

    if mp_names is None:
        mp_names = list(mp_coords.keys())

    fig = go.Figure()
    for mp in mp_names:
        if mp not in mp_coords:
            continue
        path = mp_coords[mp]
        xs = [coord[0] for _, coord in path]
        ys = [coord[1] for _, coord in path]
        text = [f"{mp} ({wid})" for wid, _ in path]
        fig.add_trace(
            go.Scatter(
                x=xs, y=ys, mode="lines+markers", name=mp, text=text, hoverinfo="text"
            )
        )

        # Add an arrow indicating the final direction (only one arrow per MP to
        # avoid clutter). Use an annotation with an arrowhead from the penultimate
        # to the last point and label the endpoint with the MP name.
        # Decoration is best-effort: a failure is logged and the plot continues.
        try:
            if len(xs) >= 2:
                x0, y0 = xs[-2], ys[-2]
                x1, y1 = xs[-1], ys[-1]
                # small style choices — subtle arrow and a short label
                fig.add_annotation(
                    x=x1,
                    y=y1,
                    ax=x0,
                    ay=y0,
                    xref="x",
                    yref="y",
                    axref="x",
                    ayref="y",
                    showarrow=True,
                    arrowhead=3,
                    arrowsize=1.0,
                    arrowwidth=1.2,
                    arrowcolor="rgba(0,0,0,0.6)",
                    opacity=0.8,
                )
                # endpoint label slightly offset to reduce overlap with marker
                fig.add_annotation(
                    x=x1,
                    y=y1,
                    xref="x",
                    yref="y",
                    text=mp,
                    showarrow=False,
                    xanchor="left",
                    yanchor="bottom",
                    font=dict(size=10, color="rgba(0,0,0,0.8)"),
                )
        except Exception:
            _logger.exception("Failed to add arrow/label for MP %s", mp)

    fig.update_layout(
        title="MP Trajectories on Political Compass",
        xaxis_title="Left ← — → Right",
        yaxis_title="Progressive ← — → Conservative",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("2D trajectories compass written to %s", output_path)
    return output_path