"""visualize.py — Plotly interactive plots for parliamentary embeddings.

Each plotting function writes a standalone HTML file (plotly.js is loaded
from the CDN) and returns the path it wrote.

Functions:
    plot_umap_scatter      — 2D scatter of fused motion embeddings, coloured by cluster
    plot_mp_trajectory     — Line plot of cumulative MP drift across windows
    plot_political_axis    — Bar chart of MP scores on the ideological axis
    plot_political_compass — 2D compass scatter of MP positions for one window
    plot_2d_trajectories   — MP paths across windows on the 2D compass
"""

import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np

_logger = logging.getLogger(__name__)

# Default location of the DuckDB database used to look up MP parties.
_DEFAULT_DB_PATH = "data/motions.db"

# Axis titles shared by the two compass plots.
_COMPASS_X_TITLE = "Left ← — → Right"
_COMPASS_Y_TITLE = "Progressive ← — → Conservative"


def _require_plotly() -> Tuple[Any, Any]:
    """Import plotly lazily and return ``(plotly.graph_objects, plotly.express)``.

    Raises:
        ImportError: If plotly is not installed.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except ImportError as err:
        raise ImportError(
            "plotly is not installed. Install it with: uv add plotly"
        ) from err
    return go, px


def _load_party_map(db_path: str = _DEFAULT_DB_PATH) -> Dict[str, str]:
    """Build a party mapping ``mp_name -> party`` from the DuckDB database.

    Parties recorded in ``mp_metadata`` take precedence. MPs that only appear
    in ``mp_votes`` get the party they voted under most often; ties are broken
    by picking the alphabetically first party so the result is deterministic.

    Args:
        db_path: Path to the DuckDB database file.

    Returns:
        ``{mp_name: party}``. Empty when duckdb is not installed; database
        errors (e.g. a missing table) propagate to the caller.
    """
    try:
        import duckdb
    except ImportError:
        _logger.debug("duckdb not available when building party map")
        return {}

    conn = duckdb.connect(db_path)
    try:
        meta_rows = conn.execute(
            "SELECT mp_name, party FROM mp_metadata WHERE party IS NOT NULL"
        ).fetchall()
        vote_rows = conn.execute(
            """
            SELECT mp_name, party, COUNT(*) as n
            FROM mp_votes
            WHERE party IS NOT NULL
            GROUP BY mp_name, party
            """
        ).fetchall()
    finally:
        try:
            conn.close()
        except Exception:
            _logger.debug(
                "Ignoring error while closing duckdb connection", exc_info=True
            )

    meta_map = {mp_name: party for mp_name, party in meta_rows}

    # Majority-party heuristic: keep the (count, party) with the highest count
    # per MP. Comparing (-count, party) prefers the alphabetically first party
    # on ties, because GROUP BY row order is not guaranteed.
    best: Dict[str, Tuple[int, str]] = {}
    for mp_name, party, n in vote_rows:
        current = best.get(mp_name)
        if current is None or (-n, party) < (-current[0], current[1]):
            best[mp_name] = (n, party)
    maj_map = {mp_name: party for mp_name, (_, party) in best.items()}

    # Metadata entries override the vote-based guess.
    merged = {**maj_map, **meta_map}
    _logger.info(
        "Built party map: %d from mp_votes majority, %d from mp_metadata",
        len(maj_map),
        len(meta_map),
    )
    return merged


def _load_metadata_parties(db_path: str) -> Optional[Dict[str, Any]]:
    """Best-effort read of ``(mp_name, party)`` rows from ``mp_metadata``.

    Returns None (logging at debug level) when duckdb is not installed or the
    database cannot be read, so callers can fall back to "Unknown" parties.
    """
    try:
        import duckdb  # type: ignore
    except ImportError:
        _logger.debug("duckdb not installed; proceeding without party mapping")
        return None

    conn = None
    try:
        conn = duckdb.connect(database=db_path, read_only=True)
        rows = conn.execute("SELECT mp_name, party FROM mp_metadata").fetchall()
    except Exception as e:
        _logger.debug("Could not load party mapping from %s: %s", db_path, e)
        return None
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                _logger.debug(
                    "Ignoring error while closing duckdb connection", exc_info=True
                )

    party_of = dict(rows)  # each row is an (mp_name, party) pair
    _logger.info("Loaded party mapping for %d MPs from %s", len(party_of), db_path)
    return party_of


def _party_labels(
    names: Sequence[str], party_of: Optional[Dict[str, Any]]
) -> List[Any]:
    """Return one party label per name, using "Unknown" for missing parties.

    A mapping value of ``None`` or ``""`` (e.g. a NULL ``party`` column) is
    treated like a missing entry, so it is labelled "Unknown" instead of being
    plotted as a separate ``None`` category.
    """
    if not party_of:
        return ["Unknown"] * len(names)
    return [party_of.get(name) or "Unknown" for name in names]


def _evr_list(axis_def: Optional[Dict[str, Any]]) -> Optional[List[float]]:
    """Extract ``axis_def["explained_variance_ratio"]`` as a list of floats.

    Accepts a list/tuple/numpy array or a single scalar. Returns None when
    ``axis_def`` is falsy, the key is absent, or the values are not numeric.
    """
    if not axis_def:
        return None
    evr = axis_def.get("explained_variance_ratio")
    if evr is None:
        return None
    try:
        values = list(evr)
    except TypeError:  # scalar, not iterable
        values = [evr]
    try:
        return [float(v) for v in values]
    except (TypeError, ValueError):
        return None


def _auto_y_scale(evr: Sequence[float], max_scale: float = 8.0) -> float:
    """Pick a Y stretch factor from the first two explained-variance ratios.

    The factor is ``sqrt(evr[0] / evr[1])`` clamped to ``[1, max_scale]``.
    A near-zero second component yields 1.0 (no scaling).
    """
    evr1, evr2 = float(evr[0]), float(evr[1])
    if evr2 < 1e-6:
        return 1.0
    ratio = evr1 / (evr2 + 1e-9)
    # ratio <= 1 would give a factor below 1 (or a complex root if negative).
    if ratio <= 1.0:
        return 1.0
    return min(ratio ** 0.5, max_scale)


def _select_extremes(
    items: Sequence[Tuple[str, float]], n_top: int
) -> List[Tuple[str, float]]:
    """Keep the first and last ``n_top`` of an already sorted list.

    If ``n_top <= 0`` or the list holds at most ``2 * n_top`` items, all
    items are returned unchanged (never duplicated).
    """
    items = list(items)
    if n_top <= 0 or len(items) <= 2 * n_top:
        return items
    return items[:n_top] + items[-n_top:]


def _cumulative_drift(
    windows: Sequence[Any], drift: Sequence[float]
) -> Tuple[List[Any], List[float]]:
    """Return ``(x_labels, cumulative_drift)`` truncated to a common length.

    The cumulative series starts at 0.0 for the first window. Both outputs are
    cut to the shorter of ``len(windows)`` and ``len(drift) + 1`` so the
    plotted x and y always have equal length.
    """
    cumulative = [0.0] + [float(v) for v in np.cumsum(drift)]
    n = min(len(windows), len(cumulative))
    return list(windows[:n]), cumulative[:n]


def plot_umap_scatter(
    motion_ids: List[int],
    coords: List[List[float]],
    labels: Optional[List[int]] = None,
    window_id: Optional[str] = None,
    output_path: str = "analysis_umap.html",
) -> str:
    """Produce a 2D scatter plot of UMAP-reduced fused embeddings.

    Args:
        motion_ids: Motion IDs (used as hover labels).
        coords: One ``[x, y]`` pair per motion.
        labels: Optional cluster label per motion; all points share cluster 0
            when omitted.
        window_id: Window label appended to the plot title.
        output_path: Where to write the HTML file.

    Returns:
        ``output_path`` on success.

    Raises:
        ImportError: If plotly is not installed.
        ValueError: If ``coords`` or ``labels`` do not match ``motion_ids`` in length.
    """
    _, px = _require_plotly()
    if len(coords) != len(motion_ids):
        raise ValueError(
            f"coords has {len(coords)} entries but motion_ids has {len(motion_ids)}"
        )
    if labels is not None and len(labels) != len(motion_ids):
        raise ValueError(
            f"labels has {len(labels)} entries but motion_ids has {len(motion_ids)}"
        )

    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    clusters = labels if labels is not None else [0] * len(motion_ids)
    title = "UMAP — fused motion embeddings" + (f" ({window_id})" if window_id else "")

    fig = px.scatter(
        x=xs,
        y=ys,
        # String labels make plotly treat clusters as discrete categories.
        color=[str(c) for c in clusters],
        hover_name=[str(mid) for mid in motion_ids],
        title=title,
        labels={"x": "UMAP-1", "y": "UMAP-2", "color": "Cluster"},
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("UMAP scatter written to %s", output_path)
    return output_path


def plot_mp_trajectory(
    trajectories: Dict[str, Dict],
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectory.html",
) -> str:
    """Line plot of cumulative MP drift across time windows.

    Args:
        trajectories: Output of ``analysis.trajectory.compute_trajectories()``;
            each value is read for its ``"windows"`` and ``"drift"`` entries.
        mp_names: Subset of MPs to plot (default: all). Unknown names are skipped.
        output_path: Output HTML file path.

    Returns:
        ``output_path`` on success.

    Raises:
        ImportError: If plotly is not installed.
    """
    go, _ = _require_plotly()
    if mp_names is None:
        mp_names = list(trajectories)

    fig = go.Figure()
    for mp in mp_names:
        if mp not in trajectories:
            continue
        data = trajectories[mp]
        x_labels, cumulative = _cumulative_drift(data["windows"], data["drift"])
        fig.add_trace(
            go.Scatter(x=x_labels, y=cumulative, mode="lines+markers", name=mp)
        )

    fig.update_layout(
        title="MP Political Drift Over Time (Cumulative)",
        xaxis_title="Window",
        yaxis_title="Cumulative Drift",
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Trajectory plot written to %s", output_path)
    return output_path


def plot_political_axis(
    scores: Dict[str, float],
    party_of: Optional[Dict[str, str]] = None,
    window_id: Optional[str] = None,
    n_top: int = 30,
    output_path: str = "analysis_political_axis.html",
) -> str:
    """Horizontal bar chart of MP scores on the ideological axis.

    Args:
        scores: ``{mp_name: score}`` from the political_axis module.
        party_of: Optional ``{mp_name: party}`` for colour-coding; missing or
            empty parties are shown as "Unknown".
        window_id: Window label appended to the title.
        n_top: Show only the ``n_top`` lowest and ``n_top`` highest scoring
            MPs; ``n_top <= 0`` shows every MP.
        output_path: Output HTML path.

    Returns:
        ``output_path`` on success.

    Raises:
        ImportError: If plotly is not installed.
    """
    _, px = _require_plotly()

    ranked = sorted(scores.items(), key=lambda kv: kv[1])
    ranked = _select_extremes(ranked, n_top)
    names = [name for name, _ in ranked]
    vals = [score for _, score in ranked]
    colors = _party_labels(names, party_of)

    title = "MP Ideological Axis Score" + (f" ({window_id})" if window_id else "")
    fig = px.bar(
        x=vals,
        y=names,
        color=colors,
        orientation="h",
        title=title,
        labels={"x": "Score (← left — right →)", "y": "MP", "color": "Party"},
    )
    fig.update_layout(yaxis={"categoryorder": "total ascending"})
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political axis chart written to %s", output_path)
    return output_path


def plot_political_compass(
    positions_by_window: Dict,
    window_id: str,
    party_of: Optional[Dict] = None,
    axis_def: Optional[Dict] = None,
    y_scale: Optional[float] = None,
    output_path: str = "analysis_compass.html",
    db_path: str = _DEFAULT_DB_PATH,
) -> str:
    """Plot a 2D political compass scatter for a single window.

    Args:
        positions_by_window: ``{window_id: {mp_name: (x, y)}}``.
        window_id: Which window to plot (an absent window gives an empty plot).
        party_of: Optional ``{mp_name: party}`` for colouring. When None, the
            mapping is read best-effort from ``mp_metadata`` in ``db_path``.
        axis_def: Optional axis definition. ``"explained_variance_ratio"`` is
            used for Y auto-scaling, and ``"method" == "pca"`` adds the EVR to
            the title.
        y_scale: Explicit Y multiplier. When None, Y is auto-stretched by
            ``sqrt(evr1 / evr2)`` (clamped to [1, 8]) if ``axis_def`` provides
            at least two ratios. Any applied factor is shown in the Y axis title.
        output_path: HTML output path.
        db_path: DuckDB file used when ``party_of`` is None.

    Returns:
        ``output_path`` on success.

    Raises:
        ImportError: If plotly is not installed.
    """
    _, px = _require_plotly()
    pos = positions_by_window.get(window_id, {})
    names = list(pos)
    xs = [pos[name][0] for name in names]
    ys = [pos[name][1] for name in names]

    if party_of is None:
        party_of = _load_metadata_parties(db_path)
    parties = _party_labels(names, party_of)

    # Resolve the Y stretch: an explicit y_scale wins, otherwise derive it
    # from the explained-variance ratios so a weak PC2 stays visible.
    if y_scale is not None:
        y_factor = float(y_scale)
    else:
        y_factor = 1.0
        evr = _evr_list(axis_def)
        if evr is not None and len(evr) >= 2:
            y_factor = _auto_y_scale(evr)
            _logger.info(
                "Auto-scaling Y by %.2f for visibility (evr1=%.3f evr2=%.3f)",
                y_factor,
                evr[0],
                evr[1],
            )
    scaled_ys = [y * y_factor for y in ys] if y_factor != 1.0 else ys
    y_title = _COMPASS_Y_TITLE + (f" (×{y_factor:.2f})" if y_factor != 1.0 else "")

    # Descriptive symbol labels keep the legend readable (no "PVV, 0" entries
    # when colour and symbol combine).
    known_labels = ["Unknown" if party == "Unknown" else "Known" for party in parties]

    title = f"Political Compass ({window_id})"
    if axis_def and axis_def.get("method") == "pca":
        evr = _evr_list(axis_def)
        if evr is not None and len(evr) >= 2:
            title += f" — PCA EVR PC1={evr[0] * 100:.1f}%, PC2={evr[1] * 100:.1f}%"

    fig = px.scatter(
        x=xs,
        y=scaled_ys,
        color=parties,
        symbol=known_labels,
        hover_name=names,
        title=title,
        labels={
            "x": _COMPASS_X_TITLE,
            "y": y_title,
            "color": "Party",
            "symbol": "Known?",
        },
    )
    fig.update_traces(marker=dict(size=8, opacity=0.85))
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("Political compass written to %s", output_path)
    return output_path


def _annotate_final_step(
    fig: Any, mp: str, xs: Sequence[float], ys: Sequence[float]
) -> None:
    """Draw an arrow over an MP's last step and label its final position."""
    x0, y0 = xs[-2], ys[-2]
    x1, y1 = xs[-1], ys[-1]
    # Subtle arrow from the penultimate to the last point.
    fig.add_annotation(
        x=x1,
        y=y1,
        ax=x0,
        ay=y0,
        xref="x",
        yref="y",
        axref="x",
        ayref="y",
        showarrow=True,
        arrowhead=3,
        arrowsize=1.0,
        arrowwidth=1.2,
        arrowcolor="rgba(0,0,0,0.6)",
        opacity=0.8,
    )
    # Endpoint label anchored bottom-left so it sits beside the marker.
    fig.add_annotation(
        x=x1,
        y=y1,
        xref="x",
        yref="y",
        text=mp,
        showarrow=False,
        xanchor="left",
        yanchor="bottom",
        font=dict(size=10, color="rgba(0,0,0,0.8)"),
    )


def plot_2d_trajectories(
    positions_by_window: Dict,
    mp_names: Optional[List[str]] = None,
    output_path: str = "analysis_trajectories_compass.html",
) -> str:
    """Plot MP trajectories across windows on the 2D compass.

    Windows are visited in the iteration order of ``positions_by_window``.
    MPs with two or more positions get an arrow on their final step and a
    text label at their last point.

    Args:
        positions_by_window: ``{window_id: {mp_name: (x, y)}}``.
        mp_names: MPs to plot (default: all found in positions); unknown names
            are skipped.
        output_path: Output HTML path.

    Returns:
        ``output_path`` on success.

    Raises:
        ImportError: If plotly is not installed.
    """
    go, _ = _require_plotly()

    # mp_name -> [(window_id, (x, y)), ...]. Because windows are iterated in
    # order, each list is already time-ordered, so no sort is needed.
    mp_coords: Dict[str, List[Tuple[str, Tuple[float, float]]]] = {}
    for wid, positions in positions_by_window.items():
        for mp, coord in (positions or {}).items():
            mp_coords.setdefault(mp, []).append((wid, coord))

    if mp_names is None:
        mp_names = list(mp_coords)

    fig = go.Figure()
    for mp in mp_names:
        items = mp_coords.get(mp)
        if not items:
            continue
        xs = [coord[0] for _, coord in items]
        ys = [coord[1] for _, coord in items]
        text = [f"{mp} ({wid})" for wid, _ in items]
        fig.add_trace(
            go.Scatter(
                x=xs, y=ys, mode="lines+markers", name=mp, text=text, hoverinfo="text"
            )
        )
        if len(xs) >= 2:
            # Decoration is best-effort: a failure here should not lose the plot.
            try:
                _annotate_final_step(fig, mp, xs, ys)
            except Exception:
                _logger.exception("Failed to add arrow/label for MP %s", mp)

    fig.update_layout(
        title="MP Trajectories on Political Compass",
        xaxis_title=_COMPASS_X_TITLE,
        yaxis_title=_COMPASS_Y_TITLE,
    )
    fig.write_html(output_path, include_plotlyjs="cdn")
    _logger.info("2D trajectories compass written to %s", output_path)
    return output_path