diff --git a/explorer.py b/explorer.py index c676a21..c68666f 100644 --- a/explorer.py +++ b/explorer.py @@ -18,14 +18,369 @@ import json import logging import os import re +import traceback from typing import Dict, List, Optional, Tuple -import duckdb +try: + import duckdb + + _DUCKDB_AVAILABLE = True +except Exception: + duckdb = None + _DUCKDB_AVAILABLE = False import numpy as np import pandas as pd -import plotly.express as px -import plotly.graph_objects as go -import streamlit as st + +try: + import plotly.express as px + import plotly.graph_objects as go +except Exception: + # Plotly may be unavailable in lightweight test environments. Provide a tiny + # local fallback that exposes a Figure-like object with `.data` and + # `add_trace()` so unit tests can run without installing plotly. + px = None + import types + + class _DummyTrace: + def __init__(self, **kwargs): + # Preserve commonly-used attributes accessed by tests + self.name = kwargs.get("name") + self.x = kwargs.get("x") + self.y = kwargs.get("y") + self.text = kwargs.get("text") + self.customdata = kwargs.get("customdata") + + class _DummyFigure: + def __init__(self): + self.data = [] + + def add_trace(self, trace): + # plotly passes a Scatter object; our tests only inspect `.data` + # elements for `.name` and `.customdata`. Accept both our + # _DummyTrace and dict-like kwargs. + if isinstance(trace, _DummyTrace): + self.data.append(trace) + else: + # Some code may call go.Scatter(...) which returns an object; + # if a mapping is passed here instead, coerce to _DummyTrace. + try: + # attempt attribute access + name = getattr(trace, "name", None) + x = getattr(trace, "x", None) + y = getattr(trace, "y", None) + text = getattr(trace, "text", None) + customdata = getattr(trace, "customdata", None) + except Exception: + # Last resort: treat as mapping + name = trace.get("name") if hasattr(trace, "get") else None + x = trace.get("x") if hasattr(trace, "get") else None + y = trace.get("y") if hasattr(trace, "get") else None + text = trace.get("text") if hasattr(trace, "get") else None + customdata = ( + trace.get("customdata") if hasattr(trace, "get") else None + ) + self.data.append( + _DummyTrace(name=name, x=x, y=y, text=text, customdata=customdata) + ) + + def add_annotation(self, *args, **kwargs): + # noop for tests that don't import full plotly + return None + + go = types.SimpleNamespace( + Figure=_DummyFigure, Scatter=lambda **kwargs: _DummyTrace(**kwargs) + ) +try: + import streamlit as st +except Exception: + # Minimal dummy replacement for Streamlit used during tests / import-time. + # We only need a tiny subset so unit tests can import explorer without + # installing streamlit. All functions here are no-ops or simple fallbacks. + class _DummySt: + def cache_data(self, *args, **kwargs): + def _decorator(func): + return func + + return _decorator + + def markdown(self, *args, **kwargs): + return None + + def subheader(self, *args, **kwargs): + return None + + def plotly_chart(self, *args, **kwargs): + return None + + def caption(self, *args, **kwargs): + return None + + def text_area(self, *args, **kwargs): + return None + + def json(self, *args, **kwargs): + return None + + def checkbox(self, *args, **kwargs): + # default to False unless value provided + return kwargs.get("value", False) + + def warning(self, *args, **kwargs): + return None + + def info(self, *args, **kwargs): + return None + + def selectbox(self, *args, **kwargs): + # return first option if options provided + opts = ( + kwargs.get("options") + if kwargs.get("options") is not None + else (args[1] if len(args) > 1 else []) + ) + return opts[0] if opts else None + + def multiselect(self, *args, **kwargs): + opts = ( + kwargs.get("options") + if kwargs.get("options") is not None + else (args[1] if len(args) > 1 else []) + ) + default = kwargs.get("default") + if default is not None: + return default + return opts[:6] if opts else [] + + def number_input(self, *args, **kwargs): + return kwargs.get("value") if "value" in kwargs else 1 + + def slider(self, *args, **kwargs): + return kwargs.get("value") if "value" in kwargs else 0.35 + + def expander(self, *args, **kwargs): + class _Ctx: + def __enter__(self_inner): + return self_inner + + def __exit__(self_inner, exc_type, exc, tb): + return False + + return _Ctx() + + def columns(self, *args, **kwargs): + # Return a tuple of simple objects with the methods used in the UI + class _Col: + def markdown(self, *a, **k): + return None + + def metric(self, *a, **k): + return None + + def dataframe(self, *a, **k): + return None + + n = len(args[0]) if args else 1 + return tuple(_Col() for _ in range(n)) + + st = _DummySt() +# Temporary diagnostics for Trajectories plotting — set by instrumentation when +# EXPLORER_DEBUG_TRAJECTORIES is enabled. This is intended to be small, opt-in and +# reversible once root cause is found. +_last_trajectories_diagnostics: dict = {} +# Backwards/alternate name used by instrumentation: keep a second module-level +# reference so callers/tests can look for either name. +_last_diagnostics = _last_trajectories_diagnostics + + +def get_debug_trajectories_enabled() -> bool: + """Return True when EXPLORER_DEBUG_TRAJECTORIES env var indicates debug mode. + + Accepts '1', 'true', 'True'. Used as default for a per-tab checkbox. + """ + v = os.getenv("EXPLORER_DEBUG_TRAJECTORIES") + return str(v) in ("1", "true", "True") + + +from explorer_helpers import ( + inspect_positions_for_issues, + compute_party_centroids, +) + + +def select_trajectory_plot_data( + positions_by_window: Dict[str, Dict[str, Tuple[float, float]]], + party_map: Dict[str, str], + windows: List[str], + selected_parties: List[str], + smooth_alpha: float = 0.35, + mp_fallback_count: Optional[int] = None, +) -> Tuple[go.Figure, int, Optional[str]]: + """Return (fig, trace_count, banner_text). + + Helper used by build_trajectories_tab. Does not call Streamlit. + """ + # Use env var default if not provided + if mp_fallback_count is None: + try: + mp_fallback_count = int(os.getenv("EXPLORER_MP_FALLBACK_COUNT", "20")) + except Exception: + mp_fallback_count = 20 + + # Compute per-party centroids aligned to windows + party_centroids, meta = compute_party_centroids( + positions_by_window, party_map, windows + ) + + # Use inspector to collect diagnostics (import-safe, pure helper). Keep this + # call local to the helper to ensure the inspector is exercised and the + # diagnostics are available for logging/debugging. Do not call Streamlit + # from here so the function remains import-safe for tests. + try: + inspector_summary = inspect_positions_for_issues(positions_by_window, party_map) + except Exception: + # Capture traceback diagnostics so callers (and tests) can inspect what went wrong. + tb = traceback.format_exc() + inspector_summary = {} + try: + # Attach diagnostics to the helper function for callers that want to inspect + # the last error directly on the function object. + select_trajectory_plot_data._last_diagnostics = { + "stage": "inspector_exception", + "exception": tb, + } + except Exception: + # best-effort only + pass + try: + # Also update the module-level trajectories diagnostics so the UI can show + # a compact summary when debugging is enabled. + _last_trajectories_diagnostics.update( + {"stage": "inspector_exception", "exception": tb} + ) + except Exception: + pass + logger.debug("select_trajectory_plot_data inspector summary: %s", inspector_summary) + + # Determine which parties have at least one non-nan centroid + plottable_parties = [] + for p, vals in party_centroids.items(): + has_valid = any(not (np.isnan(x) and np.isnan(y)) for x, y in vals) + if has_valid: + plottable_parties.append(p) + + fig = go.Figure() + trace_count = 0 + banner_text: Optional[str] = None + + def _ema_smooth(values: List[float], alpha: float) -> List[float]: + if not values or alpha >= 1.0: + return values + smoothed: List[float] = [] + prev = None + for v in values: + if v is None or (isinstance(v, float) and np.isnan(v)): + smoothed.append(float(np.nan)) + continue + v = float(v) + if prev is None: + prev = v + else: + prev = alpha * v + (1 - alpha) * prev + smoothed.append(float(prev)) + return smoothed + + # If no plottable parties, fallback to MP trajectories + if not plottable_parties: + # Build mp_positions across windows + mp_positions: Dict[str, Dict[str, Tuple[float, float]]] = {} + for wid in windows: + pos = positions_by_window.get(wid, {}) + for mp_name, xy in pos.items(): + try: + x, y = float(xy[0]), float(xy[1]) + except Exception: + continue + mp_positions.setdefault(mp_name, {})[wid] = (x, y) + + # Rank MPs by activity (number of windows with positions) + mp_activity = sorted( + [(mp, len(wdict)) for mp, wdict in mp_positions.items()], + key=lambda t: t[1], + reverse=True, + ) + top_mps = [mp for mp, _ in mp_activity[:mp_fallback_count]] + + for mp in top_mps: + wids_sorted = sorted(mp_positions.get(mp, {}).keys()) + if not wids_sorted: + continue + xs_raw = [mp_positions[mp][w][0] for w in wids_sorted] + ys_raw = [mp_positions[mp][w][1] for w in wids_sorted] + xs = _ema_smooth(xs_raw, smooth_alpha) + ys = _ema_smooth(ys_raw, smooth_alpha) + custom_raw = [ + ( + float(rx) if rx is not None else float(np.nan), + float(ry) if ry is not None else float(np.nan), + ) + for rx, ry in zip(xs_raw, ys_raw) + ] + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + mode="lines+markers", + name=mp, + text=wids_sorted, + customdata=custom_raw, + line=dict(color="#888888", shape="spline", smoothing=1.3), + marker=dict(color="#888888", size=6), + ) + ) + trace_count += 1 + + banner_text = "Partijcentroiden niet beschikbaar — tonen individuele MP-trajecten als fallback." + return fig, trace_count, banner_text + + # Otherwise plot party centroids for selected parties intersecting plottable + to_plot = [p for p in selected_parties if p in plottable_parties] + # If none selected, default to all plottable + if not to_plot: + to_plot = plottable_parties + + for party in to_plot: + vals = party_centroids.get(party, []) + if not vals: + continue + xs_raw = [v[0] for v in vals] + ys_raw = [v[1] for v in vals] + xs = _ema_smooth(xs_raw, smooth_alpha) + ys = _ema_smooth(ys_raw, smooth_alpha) + # Ensure customdata preserves NaNs + custom_raw = [ + ( + float(x) if (x is not None and not np.isnan(x)) else float(np.nan), + float(y) if (y is not None and not np.isnan(y)) else float(np.nan), + ) + for x, y in zip(xs_raw, ys_raw) + ] + colour = PARTY_COLOURS.get(party, "#9E9E9E") + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + mode="lines+markers", + name=party, + text=windows, + customdata=custom_raw, + line=dict(color=colour, shape="spline", smoothing=1.3), + marker=dict(color=colour, size=8), + ) + ) + trace_count += 1 + + return fig, trace_count, None + logger = logging.getLogger(__name__) @@ -266,10 +621,11 @@ def load_positions( ) # Axis orientation is guaranteed by compute_2d_axes via canonical party anchors - # (Procrustes alignment + sign-fixing). Lock labels to their known semantic meaning - # instead of relying on the keyword classifier which can fall back to generic labels. - axis_def["x_label"] = "Links\u2013Rechts" - axis_def["y_label"] = "Progressief\u2013Conservatief" + # (Procrustes alignment + sign-fixing). We do NOT forcibly override axis labels + # here so the classifier output (if available) can be surfaced conditionally in + # the UI based on per-window confidence. Label selection is performed at render + # time in the tabs so we can show fallback labels while still surfacing the + # classifier interpretation and confidence when informative. # Filter displayed windows by window_size AFTER PCA computation. if window_size == "annual": @@ -642,65 +998,91 @@ def _render_scree_plot(importances: List[float], n_show: int = 15) -> None: def _build_party_axis_figure( - party_scores: Dict[str, List[float]], + party_coords: Dict[str, Tuple[float, float]], comp_sel: int, theme: dict, bootstrap_data: Optional[Dict[str, Dict]] = None, ) -> Optional[go.Figure]: """Build a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. - Pure function that returns a go.Figure (no Streamlit calls). + Accepts explicit per-party 2D coordinates (x,y) and uses the component selection to + pick the value (comp_sel==1 -> x, comp_sel==2 -> y). This makes the API explicit and + avoids indexing into long SVD vectors. - Args: - party_scores: {party_name: [float*k]} — mean SVD vectors per party. - comp_sel: 1-indexed SVD axis number. - theme: dict with keys label, explanation, positive_pole, negative_pole, flip. - bootstrap_data: optional output from compute_party_bootstrap_cis — - {party: {centroid, ci_lower, ci_upper, std, n_mps}}. - When provided, 95% CI is shown in hover text and N=1 parties get a diamond - marker. Error bars are intentionally not drawn — use hover to see the interval. - - Returns: - go.Figure, or None if no data available. + Returns go.Figure or None if no data available. """ - if not party_scores: + if not party_coords: return None - axis_idx = comp_sel - 1 # 0-based index into the 50-dim vector + if comp_sel not in (1, 2): + raise ValueError( + "_build_party_axis_figure only supports comp_sel 1 or 2 when using explicit coords" + ) + + axis_idx = comp_sel - 1 flip = theme.get("flip", False) - data: list[dict] = [] - for party, vec in party_scores.items(): - if axis_idx < len(vec): - score = vec[axis_idx] + + parties = [] + scores = [] + colours = [] + + # Support two shapes for party_coords: + # - explicit 2D coords: (x, y) + # - full SVD vectors (len>2) where we should pick the axis_idx element + for party, val in party_coords.items(): + try: + # explicit (x, y) + if hasattr(val, "__len__") and len(val) == 2: + x, y = val + score = float(x if axis_idx == 0 else y) + else: + # treat as sequence/array-like of full SVD vector + score = float(val[axis_idx]) + if flip: score = -score - data.append({"party": party, "score": score}) + except Exception: + # skip malformed entries silently + continue - if not data: - return None + parties.append(party) + scores.append(score) + colours.append(PARTY_COLOURS.get(party, "#9E9E9E")) - scores = [d["score"] for d in data] - parties = [d["party"] for d in data] - colours = [PARTY_COLOURS.get(p, "#9E9E9E") for p in parties] + if not scores: + return None # Build hover text: include N when bootstrap data available + hover = [] + symbols = [] if bootstrap_data: - hover = [] for p, s in zip(parties, scores): bd = bootstrap_data.get(p) - n_mps = bd["n_mps"] if bd else "?" - hover.append(f"{p}: {s:.3f} (N={n_mps})") + if bd: + n_mps = bd.get("n_mps", "?") + ci_low = None + ci_high = None + try: + ci_low = float(bd["ci_lower"][axis_idx]) + ci_high = float(bd["ci_upper"][axis_idx]) + except Exception: + pass + if ci_low is not None and ci_high is not None: + hover.append( + f"{p}: {s:.3f} (N={n_mps}, 95%-BI: [{ci_low:.3f}, {ci_high:.3f}])" + ) + else: + hover.append(f"{p}: {s:.3f} (N={n_mps})") + symbols.append("diamond" if n_mps == 1 else "circle") + else: + hover.append(f"{p}: {s:.3f}") + symbols.append("circle") + marker_kwargs = {"size": 14, "color": colours, "symbol": symbols} else: hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)] - - # Determine axis labels: left = progressive pole, right = conservative pole - pos_pole = theme.get("positive_pole", "") - neg_pole = theme.get("negative_pole", "") - left_label = pos_pole if flip else neg_pole - right_label = neg_pole if flip else pos_pole + marker_kwargs = {"size": 14, "color": colours} fig = go.Figure() - # Baseline x_min, x_max = min(scores) * 1.15, max(scores) * 1.15 if x_min == x_max: x_min, x_max = x_min - 1, x_max + 1 @@ -715,34 +1097,7 @@ def _build_party_axis_figure( ) ) - # Build marker kwargs and hover text. - # When bootstrap data is available, 95% CI is embedded in the hover tooltip and - # N=1 parties get a diamond marker to signal low-reliability estimates. - # Error bars are intentionally omitted — they clutter the 1D chart. - marker_kwargs: dict = {"size": 14, "color": colours} - - if bootstrap_data: - hover = [] - symbols = [] - for p, s in zip(parties, scores): - bd = bootstrap_data.get(p) - if bd: - n_mps = bd["n_mps"] - ci_low = float(bd["ci_lower"][axis_idx]) - ci_high = float(bd["ci_upper"][axis_idx]) - hover.append( - f"{p}: {s:.3f} (N={n_mps}, 95%-BI: [{ci_low:.3f}, {ci_high:.3f}])" - ) - symbols.append("diamond" if n_mps == 1 else "circle") - else: - hover.append(f"{p}: {s:.3f}") - symbols.append("circle") - marker_kwargs["symbol"] = symbols - else: - hover = [f"{p}: {s:.3f}" for p, s in zip(parties, scores)] - - # Party markers - scatter_kwargs: dict = { + scatter_kwargs = { "x": scores, "y": [0] * len(scores), "mode": "markers+text", @@ -755,6 +1110,11 @@ def _build_party_axis_figure( } fig.add_trace(go.Scatter(**scatter_kwargs)) + pos_pole = theme.get("positive_pole", "") + neg_pole = theme.get("negative_pole", "") + left_label = pos_pole if flip else neg_pole + right_label = neg_pole if flip else pos_pole + fig.update_layout( height=160, margin={"l": 10, "r": 10, "t": 10, "b": 30}, @@ -773,17 +1133,16 @@ def _build_party_axis_figure( def _render_party_axis_chart( - party_scores: Dict[str, List[float]], + party_coords: Dict[str, Tuple[float, float]], comp_sel: int, theme: dict, bootstrap_data: Optional[Dict[str, Dict]] = None, ) -> None: """Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`. - Delegates figure construction to _build_party_axis_figure, then renders via - st.plotly_chart. + Expects explicit per-party coords mapping (party -> (x,y)) for components 1 & 2. """ - fig = _build_party_axis_figure(party_scores, comp_sel, theme, bootstrap_data) + fig = _build_party_axis_figure(party_coords, comp_sel, theme, bootstrap_data) if fig is None: st.caption("_Partijdata niet beschikbaar voor deze as._") return @@ -936,6 +1295,11 @@ def build_compass_tab(db_path: str, window_size: str) -> None: # Compass always uses annual windows regardless of the sidebar window_size setting. positions_by_window, axis_def = load_positions(db_path, "annual") + # load_positions may return None for axis_def when resources are missing + # (e.g. classifier fallback or failed enrichment). Guard so UI rendering + # code doesn't crash on axis_def.get calls. + if axis_def is None: + axis_def = {} if not positions_by_window: st.warning( "Geen positiedata beschikbaar. Controleer of de pipeline is gedraaid." @@ -1026,8 +1390,34 @@ def build_compass_tab(db_path: str, window_size: str) -> None: st.info("Geen partijen met genoeg Kamerleden voor dit venster.") return - _x_label = axis_def.get("x_label", "Links\u2013Rechts") - _y_label = axis_def.get("y_label", "Progressief\u2013Conservatief") + # The first two SVD axes are clear, interpretable axes for our dataset. + # Show the classifier-provided full labels on the compass unconditionally + # so users see the canonical interpretation. We keep the confidence-based + # captions/interpretations in the expander but do not hide the axis titles + # for the compass. Note: the vertical axis title is rotated by Plotly — + # this can make "Progressief–Conservatief" look reversed because the word + # "Progressief" appears at the top when rendered; we therefore add explicit + # directional annotations to make the polarity unambiguous. + # Prefer classifier-provided labels for the first two axes. However, the + # classifier sometimes returns the concise numeric fallbacks "As 1"/"As 2" + # when it couldn't find an interpretable label. For the compass we prefer + # conventional semantic defaults instead of the generic "As N" strings so + # the chart remains readable. + _raw_x = axis_def.get("x_label") + _raw_y = axis_def.get("y_label") + + # Use the classifier helper to map internal/modal labels (e.g. "As 1") to + # user-facing labels. Import at function-time to avoid module import cycles + # and keep explorer lightweight. If the helper is unavailable fall back to + # conventional semantic defaults so the UI remains readable. + try: + from analysis.axis_classifier import display_label_for_modal + + _x_label = display_label_for_modal(_raw_x, "x") + _y_label = display_label_for_modal(_raw_y, "y") + except Exception: + _x_label = _raw_x or "Links\u2013Rechts" + _y_label = _raw_y or "Progressief\u2013Conservatief" if level == "Partijen": # Aggregate to party centroids @@ -1213,8 +1603,36 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: st.markdown("Hoe bewegen partijen over de tijdsvensters heen?") positions_by_window, axis_def = load_positions(db_path, window_size) + if axis_def is None: + axis_def = {} if not positions_by_window: - st.warning("Geen positiedata beschikbaar.") + # Instrumentation: record why trajectories tab aborted early + try: + _last_trajectories_diagnostics.update( + { + "stage": "load_positions_empty", + "positions_by_window_len": len(positions_by_window), + } + ) + except Exception: + pass + try: + st.warning("Geen positiedata beschikbaar.") + except Exception: + pass + # If debug enabled, show diagnostics in UI (best-effort) + try: + if get_debug_trajectories_enabled(): + try: + st.text_area( + "Trajectories diagnostics", + json.dumps(_last_trajectories_diagnostics, default=str), + height=160, + ) + except Exception: + pass + except Exception: + pass return party_map = load_party_map(db_path) @@ -1223,11 +1641,22 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: # Compute party centroids per window centroids: Dict[str, Dict[str, Tuple[float, float]]] = {} all_parties: set = set() + + # Helper to normalise MP names (strip parenthetical first names) to match + # entries in the party_map. This mirrors the behaviour used in the compass + # tab so both tabs resolve parties the same way. + def _strip_paren(name: str) -> str: + return re.sub(r"\s*\([^)]*\)", "", name).strip() + for wid in windows: pos = positions_by_window.get(wid, {}) per_party: Dict[str, List[Tuple[float, float]]] = {} for mp_name, (x, y) in pos.items(): - party = party_map.get(mp_name, "Unknown") + # Try exact match first, then stripped-name match to handle + # variants like "Dijk, J.P. (Jimmy)" -> "Dijk, J.P." used in mp_metadata + party = party_map.get(mp_name) or party_map.get( + _strip_paren(mp_name), "Unknown" + ) if party == "Unknown": continue per_party.setdefault(party, []).append((x, y)) @@ -1242,6 +1671,21 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: all_parties_sorted = sorted(all_parties) + # If no parties were found after mapping MPs to parties, show a helpful + # message instead of rendering an empty chart. This commonly happens when + # the party map failed to load (DB error) or the min_mps threshold filtered + # out all parties. + if not all_parties_sorted: + st.info( + "Geen partijen beschikbaar om trajecten te tekenen. Controleer of de party mapping is geladen (mp_metadata) en of de minimum Kamerleden-instelling te hoog staat." + ) + try: + st.caption(f"Bekende partijen in party_map: {len(party_map)}") + except Exception: + pass + # Do not return here: allow per-MP fallback plotting below when no + # party-level centroids are available so the user still sees trajectories. + # Default: show CDA, D66, VVD — the three parties that span the political centre default_parties = [p for p in ["CDA", "D66", "VVD"] if p in all_parties] if not default_parties: @@ -1255,19 +1699,223 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: default=default_parties, ) - # Smoothing slider — EMA alpha controls noise reduction - smooth_alpha = st.slider( - "Glad maken (EMA-\u03b1)", - min_value=0.1, - max_value=1.0, - value=0.35, - step=0.05, - help=( - "\u03b1=1.0 toont de ruwe data; lagere waarden maken de lijn gladder. " - "Standaard 0.35 voor een goed evenwicht tussen detail en ruis." - ), + # Ensure EMA smoothing helper is available for per-MP fallback plotting which + # appears earlier in the function. Define here so calls above won't fail. + def _ema_smooth(values: List[float], alpha: float) -> List[float]: + if not values or alpha >= 1.0: + return values + smoothed = [values[0]] + for v in values[1:]: + smoothed.append(alpha * v + (1 - alpha) * smoothed[-1]) + return smoothed + + # default smoothing alpha used for inline per-MP plotting; may be overridden + # by the smoothing controls shown later in the UI. + smooth_alpha = 0.35 + + # If no party-level centroids were computed, fall back to per-MP trajectories + # so the user still sees a plot even when the party_map is missing or empty. + if not centroids: + # Build per-MP time series from positions_by_window + mp_positions: Dict[str, Dict[str, Tuple[float, float]]] = {} + for wid in windows: + pos = positions_by_window.get(wid, {}) + for mp_name, xy in pos.items(): + # Defensive conversion: skip malformed coordinates instead of raising + try: + x, y = float(xy[0]), float(xy[1]) + except Exception: + # skip malformed entries silently (diagnostics will show counts) + continue + mp_positions.setdefault(mp_name, {})[wid] = (x, y) + + if not mp_positions: + try: + _last_trajectories_diagnostics.update( + { + "stage": "no_mp_positions", + "mp_positions_count": len(mp_positions), + } + ) + except Exception: + pass + try: + st.info("Geen positiedata beschikbaar voor trajectplotten.") + except Exception: + pass + # show diagnostics when debug enabled + try: + if get_debug_trajectories_enabled(): + try: + st.text_area( + "Trajectories diagnostics", + json.dumps(_last_trajectories_diagnostics, default=str), + height=160, + ) + except Exception: + pass + except Exception: + pass + return + + mp_list = sorted(mp_positions.keys()) + default_mps = mp_list[:6] + selected_mps = st.multiselect( + "Selecteer Kamerleden (fallback)", options=mp_list, default=default_mps + ) + + # Plot per-MP trajectories + fig = go.Figure() + trace_count = 0 + for mp in selected_mps: + wids_sorted = sorted(mp_positions[mp].keys()) + xs_raw = [mp_positions[mp][w][0] for w in wids_sorted] + ys_raw = [mp_positions[mp][w][1] for w in wids_sorted] + xs = _ema_smooth(xs_raw, smooth_alpha) + ys = _ema_smooth(ys_raw, smooth_alpha) + custom_raw = [(float(rx), float(ry)) for rx, ry in zip(xs_raw, ys_raw)] + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + mode="lines+markers", + name=mp, + text=wids_sorted, + customdata=custom_raw, + line=dict(color="#888888", shape="spline", smoothing=1.3), + marker=dict(color="#888888", size=6), + hovertemplate=( + f"{mp}
" + "venster: %{text}
" + "x (smoothed): %{x:.3f}
" + "x (raw): %{customdata[0]:.3f}
" + "y (smoothed): %{y:.3f}
" + "y (raw): %{customdata[1]:.3f}" + ), + ) + ) + trace_count += 1 + + _add_y_direction_annotations(fig) + if trace_count == 0: + st.info( + "Geen trajecten getekend: geen geselecteerde Kamerleden met voldoende data." + ) + else: + st.plotly_chart(fig, use_container_width=True) + return + + # Developer override: if EXPLORER_FORCE_SHOW_TRAJECTORIES=1 in the + # environment, bypass party filtering and show the first MPs' trajectories + # directly (helps diagnose production environments where party mapping + # or filtering prevents any traces from appearing). This is safe to keep + # in main because it only triggers when explicitly enabled. + if os.getenv("EXPLORER_FORCE_SHOW_TRAJECTORIES") in ("1", "true", "True"): + # Build per-MP time series from positions_by_window and plot first 6 MPs + mp_positions: Dict[str, Dict[str, Tuple[float, float]]] = {} + for wid in windows: + pos = positions_by_window.get(wid, {}) + for mp_name, (x, y) in pos.items(): + mp_positions.setdefault(mp_name, {})[wid] = (float(x), float(y)) + + mp_list = sorted(mp_positions.keys()) + if not mp_list: + st.info("Geen MP-positiegegevens beschikbaar om te tonen.") + return + + sample_mps = mp_list[:6] + fig = go.Figure() + for mp in sample_mps: + wids_sorted = sorted(mp_positions[mp].keys()) + xs_raw = [mp_positions[mp][w][0] for w in wids_sorted] + ys_raw = [mp_positions[mp][w][1] for w in wids_sorted] + xs = _ema_smooth(xs_raw, 0.35) + ys = _ema_smooth(ys_raw, 0.35) + custom_raw = [(float(rx), float(ry)) for rx, ry in zip(xs_raw, ys_raw)] + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + mode="lines+markers", + name=mp, + text=wids_sorted, + customdata=custom_raw, + line=dict(color="#444444", shape="spline", smoothing=1.3), + marker=dict(color="#444444", size=6), + hovertemplate=( + f"{mp}
" + "venster: %{text}
" + "x (smoothed): %{x:.3f}
" + "x (raw): %{customdata[0]:.3f}
" + "y (smoothed): %{y:.3f}
" + "y (raw): %{customdata[1]:.3f}" + ), + ) + ) + _add_y_direction_annotations(fig) + st.plotly_chart(fig, use_container_width=True) + return + + # Debug expander: show data used to build trajectories so we can diagnose + # why no traces are appearing. Leave this collapsed by default in normal + # runs; when troubleshooting it will show counts and small samples. + try: + # Add a little opt-in checkbox in the UI to enable debug diagnostic output + debug_checkbox = False + try: + debug_checkbox = st.checkbox( + "Enable trajectories diagnostics (show extra info)", + value=get_debug_trajectories_enabled(), + ) + except Exception: + debug_checkbox = get_debug_trajectories_enabled() + if debug_checkbox: + try: + with st.expander( + "DEBUG: Trajectories data (showing diagnostics)", expanded=False + ): + st.write("windows (count):", len(windows)) + st.write("windows sample:", windows[:10]) + st.write("party_map entries:", len(party_map)) + st.write("parties with centroids:", len(all_parties_sorted)) + st.write("default_parties:", default_parties) + st.write("selected_parties:", selected_parties) + st.write("min_mps setting:", min_mps) + # sample centroid counts per party + sample = { + p: len(centroids.get(p, {})) + for p in list(all_parties_sorted)[:8] + } + st.write("sample centroid window counts per party:", sample) + except Exception: + pass + except Exception: + # Don't crash UI if st isn't available or expander fails + pass + + # Smoothing controls + smoothing_method = st.selectbox( + "Smoothing methode", + options=["EMA", "Spline", "None"], + index=0, + help="EMA = exponential moving average; Spline = low-degree polynomial spline fit; None = raw centroids", ) + # EMA alpha only shown/used when EMA is selected + smooth_alpha = 1.0 + if smoothing_method == "EMA": + smooth_alpha = st.slider( + "Glad maken (EMA-\u03b1)", + min_value=0.1, + max_value=1.0, + value=0.35, + step=0.05, + help=( + "\u03b1=1.0 toont de ruwe data; lagere waarden maken de lijn gladder. " + "Standaard 0.35 voor een goed evenwicht tussen detail en ruis." + ), + ) + def _ema_smooth(values: List[float], alpha: float) -> List[float]: """Apply exponential moving average; alpha=1.0 means no smoothing.""" if not values or alpha >= 1.0: @@ -1277,7 +1925,65 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: smoothed.append(alpha * v + (1 - alpha) * smoothed[-1]) return smoothed + def _spline_smooth(values: List[float]) -> List[float]: + """Perform a basic low-degree polynomial fit over index -> value and evaluate at indices. + + This provides a simple spline-like smoothing without adding scipy as a dependency. + For very small N this returns the raw values. + """ + n = len(values) + if n <= 2: + return values + deg = min(3, n - 1) + try: + idx = np.arange(n, dtype=float) + coeffs = np.polyfit(idx, np.array(values, dtype=float), deg=deg) + smooth = np.polyval(coeffs, idx) + return [float(v) for v in smooth] + except Exception: + return values + fig = go.Figure() + trace_count = 0 + # New: delegate plotting selection to helper for testability + # Note: select_trajectory_plot_data returns (fig, trace_count, banner_text) + try: + fig2, trace_count2, banner_text = select_trajectory_plot_data( + positions_by_window, party_map, windows, selected_parties, smooth_alpha + ) + # If helper returned a figure, replace + if fig2 is not None: + fig = fig2 + trace_count = trace_count2 + if banner_text: + try: + st.caption(banner_text) + except Exception: + pass + try: + _last_trajectories_diagnostics.update({"banner_text": banner_text}) + except Exception: + pass + except Exception as e: + tb = traceback.format_exc() + # attach diagnostics to the helper and module + try: + select_trajectory_plot_data._last_diagnostics = {"exception": tb} + except Exception: + pass + try: + _last_trajectories_diagnostics.update( + {"stage": "select_helper_exception", "exception": tb} + ) + except Exception: + pass + logger.exception("select_trajectory_plot_data failed") + debug_enabled = get_debug_trajectories_enabled() + if debug_enabled: + try: + st.text_area("select_trajectory_plot_data traceback", tb, height=240) + except Exception: + pass for party in selected_parties: if party not in centroids: continue @@ -1286,6 +1992,8 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: ys_raw = [centroids[party][w][1] for w in wids_sorted] xs = _ema_smooth(xs_raw, smooth_alpha) ys = _ema_smooth(ys_raw, smooth_alpha) + # Preserve raw (unsmoothed) values per-point so hover can show both raw and smoothed + custom_raw = [(float(rx), float(ry)) for rx, ry in zip(xs_raw, ys_raw)] colour = PARTY_COLOURS.get(party, "#9E9E9E") fig.add_trace( go.Scatter( @@ -1294,25 +2002,102 @@ def build_trajectories_tab(db_path: str, window_size: str) -> None: mode="lines+markers", name=party, text=wids_sorted, # full window ID for hover + customdata=custom_raw, line=dict(color=colour, shape="spline", smoothing=1.3), marker=dict(color=colour, size=8), hovertemplate=( f"{party}
" "venster: %{text}
" - "x: %{x:.3f}
y: %{y:.3f}" + "x (smoothed): %{x:.3f}
" + "x (raw): %{customdata[0]:.3f}
" + "y (smoothed): %{y:.3f}
" + "y (raw): %{customdata[1]:.3f}" ), ) ) + trace_count += 1 + + # For trajectories, the chart spans multiple windows. Use the classifier's + # per-window confidences aggregated (mean) to decide whether to use the + # classifier label or fall back to the conventional short label. + _THRESHOLD = 0.65 + x_conf_map = axis_def.get("x_label_confidence", {}) or {} + y_conf_map = axis_def.get("y_label_confidence", {}) or {} + + def _mean_conf(m: dict) -> Optional[float]: + vals = [v for v in m.values() if v is not None] + if not vals: + return None + return float(sum(vals) / len(vals)) + + x_mean = _mean_conf(x_conf_map) + y_mean = _mean_conf(y_conf_map) + + +def choose_trajectory_title(axis_def: dict, axis: str, threshold: float = 0.65) -> str: + """Choose a short trajectory axis title based on aggregated confidence. + + axis: 'x' or 'y'. Returns axis_def label when its mean confidence >= threshold, + otherwise returns the compact fallback 'As 1' / 'As 2'. Matches previous logic. + """ + _TH = threshold + conf_map = axis_def.get(f"{axis}_label_confidence", {}) or {} + vals = [v for v in conf_map.values() if v is not None] + mean = float(sum(vals) / len(vals)) if vals else None + label = axis_def.get(f"{axis}_label") + if mean is not None and mean >= _TH and label: + return label + # Prefer the user-facing semantic fallback via the classifier helper + try: + from analysis.axis_classifier import display_label_for_modal + + fallback_modal = "As 1" if axis == "x" else "As 2" + return display_label_for_modal(fallback_modal, axis) + except Exception: + return "As 1" if axis == "x" else "As 2" + + x_title = choose_trajectory_title(axis_def, "x", threshold=_THRESHOLD) + y_title = choose_trajectory_title(axis_def, "y", threshold=_THRESHOLD) fig.update_layout( title="Partij trajectories", - xaxis_title=axis_def.get("x_label", "Links\u2013Rechts"), - yaxis_title=axis_def.get("y_label", "Progressief\u2013Conservatief"), + xaxis_title=x_title, + yaxis_title=y_title, height=600, legend_title_text="Partij", ) _add_y_direction_annotations(fig) - st.plotly_chart(fig, use_container_width=True) + # If no traces were added to the figure, show a diagnostic message so the + # user knows why the plot is empty. + try: + _last_trajectories_diagnostics.update({"trace_count": trace_count}) + except Exception: + pass + debug_enabled = get_debug_trajectories_enabled() + if trace_count == 0: + try: + st.info( + "Geen trajecten getekend: geen geselecteerde partijen met voldoende data. Controleer de partijselectie en de 'Min. Kamerleden per partij' instelling." + ) + except Exception: + pass + if debug_enabled: + try: + st.text_area( + "Trajectories diagnostics", + json.dumps(_last_trajectories_diagnostics, default=str), + height=240, + ) + except Exception: + try: + st.json(_last_trajectories_diagnostics) + except Exception: + pass + else: + try: + st.plotly_chart(fig, use_container_width=True) + except Exception: + pass # --------------------------------------------------------------------------- @@ -1756,13 +2541,71 @@ def build_svd_components_tab(db_path: str) -> None: motions = comp_map.get(comp_sel, []) # Party axis chart - party_scores = load_party_axis_scores(db_path) + # Default party scores (single-window mean vectors) as a fallback + party_scores_default = load_party_axis_scores(db_path) party_mp_vectors = load_party_mp_vectors(db_path) bootstrap_data = ( _cached_bootstrap_cis(party_mp_vectors) if party_mp_vectors else None ) + + # For components 1 and 2, prefer MP-centroid values from the Procrustes-aligned + # positions_by_window so the compass matches the trajectories (MP-mean centroids). + if comp_sel in (1, 2): + try: + positions_by_window, axis_def = load_positions(db_path) + if axis_def is None: + axis_def = {} + # choose the current parliament window if present + window = ( + "current_parliament" + if "current_parliament" in positions_by_window + else sorted(positions_by_window.keys())[-1] + ) + pos = positions_by_window.get(window, {}) + + # build party -> list of MP x/y coords + party_map = load_party_map(db_path) + per_party_coords: dict = {} + for ent, (x, y) in pos.items(): + party = party_map.get(ent) + if party is None: + continue + per_party_coords.setdefault(party, []).append((x, y)) + + # construct party_scores mapping: prefer MP centroid [x,y], fallback to default vector + party_scores = {} + for party in set( + list(per_party_coords.keys()) + list(party_scores_default.keys()) + ): + coords = per_party_coords.get(party) + if coords: + xs = [c[0] for c in coords] + ys = [c[1] for c in coords] + party_scores[party] = [float(np.mean(xs)), float(np.mean(ys))] + else: + # fallback: use the default single-window SVD mean vector + party_scores[party] = party_scores_default.get(party, []) + + except Exception: + # On any error, fall back to the old behaviour + logger.exception( + "Failed to derive party centroids from positions_by_window; falling back to load_party_axis_scores" + ) + party_scores = party_scores_default + else: + party_scores = party_scores_default + + # Convert party_scores (possibly [x,y] lists or legacy vectors) into explicit (x,y) coords + party_coords: dict = {} + for p, v in party_scores.items(): + try: + if v and len(v) >= 2: + party_coords[p] = (float(v[0]), float(v[1])) + except Exception: + continue + _render_party_axis_chart( - party_scores, comp_sel, theme, bootstrap_data=bootstrap_data + party_coords, comp_sel, theme, bootstrap_data=bootstrap_data ) # Batch-fetch motion details (title, date, policy_area, url, body_text, voting_results) @@ -2022,16 +2865,25 @@ def run_app() -> None: # About section with st.sidebar.expander("ℹ️ Over", expanded=False): try: - con = duckdb.connect(database=db_path, read_only=True) - n_motions = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0] - n_fused = con.execute("SELECT COUNT(*) FROM fused_embeddings").fetchone()[0] - n_sim = con.execute("SELECT COUNT(*) FROM similarity_cache").fetchone()[0] - con.close() - st.markdown( - f"**Moties:** {n_motions:,} \n" - f"**Fused embeddings:** {n_fused:,} \n" - f"**Similarity cache:** {n_sim:,}" - ) + if _DUCKDB_AVAILABLE: + con = duckdb.connect(database=db_path, read_only=True) + n_motions = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0] + n_fused = con.execute( + "SELECT COUNT(*) FROM fused_embeddings" + ).fetchone()[0] + n_sim = con.execute("SELECT COUNT(*) FROM similarity_cache").fetchone()[ + 0 + ] + con.close() + st.markdown( + f"**Moties:** {n_motions:,} \n" + f"**Fused embeddings:** {n_fused:,} \n" + f"**Similarity cache:** {n_sim:,}" + ) + else: + st.warning( + "DuckDB niet beschikbaar in deze Python-omgeving; DB diagnostics zijn niet beschikbaar." + ) except Exception as e: st.warning(f"DB niet bereikbaar: {e}") diff --git a/explorer_helpers.py b/explorer_helpers.py new file mode 100644 index 0000000..9ed5e11 --- /dev/null +++ b/explorer_helpers.py @@ -0,0 +1,227 @@ +"""Helper utilities used by explorer.py. + +Primary export: +- compute_party_coords: compute per-party (x_mean, y_mean) from positions_by_window. + +This module is intentionally free of Streamlit side-effects to be easy to unit test. +""" + +from __future__ import annotations + +import logging +import math +import re +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +def _strip_paren(s: str) -> str: + # helper used in plan to try to strip parenthetical variants + return s.split("(")[0].strip() + + +def inspect_positions_for_issues( + positions_by_window: Dict[str, Dict[str, Tuple[float, float]]], + party_map: Dict[str, str], +) -> Dict[str, Any]: + """Inspect positions_by_window for simple issues/summary. + + Returns a dictionary with keys including the previous ones (windows_count, + window_labels, mp_id_set, party_map_count, parties_with_centroid_counts, + mismatched_mp_ids_sample) plus: + - mp_positions_count: int (num unique MP ids seen) + - mp_positions_sample: list[str] (sorted sample up to 10) + - windows_with_no_positions: list[str] + + This helper remains pure and import-safe so unit tests can exercise it. + """ + windows = list(positions_by_window.keys()) + windows_count = len(windows) + window_labels = sorted(windows)[:10] + + mp_id_set: Set[str] = set() + parties_with_centroid_counts: Dict[str, int] = {} + mismatched: Set[str] = set() + windows_with_no_positions: List[str] = [] + + for win, pos in positions_by_window.items(): + if not pos: + windows_with_no_positions.append(win) + continue + present_parties: Set[str] = set() + for ent in pos.keys(): + if not ent: + continue + mp_id_set.add(ent) + party = party_map.get(ent) + if party is None: + # try stripping paren variant + party = party_map.get(_strip_paren(ent)) + if party: + present_parties.add(party) + else: + mismatched.add(ent) + + for p in present_parties: + parties_with_centroid_counts[p] = parties_with_centroid_counts.get(p, 0) + 1 + + mismatched_mp_ids_sample = sorted(list(mismatched))[:10] + + mp_positions_sample = sorted(list(mp_id_set))[:10] + mp_positions_count = len(mp_id_set) + + return { + "windows_count": windows_count, + "window_labels": window_labels, + "mp_id_set": mp_id_set, + "party_map_count": len(party_map), + "parties_with_centroid_counts": parties_with_centroid_counts, + "mismatched_mp_ids_sample": mismatched_mp_ids_sample, + "mp_positions_sample": mp_positions_sample, + "mp_positions_count": mp_positions_count, + "windows_with_no_positions": windows_with_no_positions, + } + + +def compute_party_coords( + positions_by_window: Dict[str, Dict[str, Tuple[float, float]]], + party_map: Dict[str, str], + window_id: str, + fallback_party_scores: Optional[Dict[str, List[float]]] = None, +) -> Tuple[Dict[str, Tuple[float, float]], Set[str]]: + """ + Compute per-party centroids (x_mean, y_mean) for a specific window. + + Args: + positions_by_window: mapping window_id -> {entity_name: (x, y)} + party_map: mapping mp_name -> party abbreviation (Normalized) + window_id: which window to compute centroids for (key into positions_by_window) + fallback_party_scores: optional mapping party -> numeric vector (len>=2). When a + party has no MPs in the window and fallback_party_scores contains an entry, + the first two elements of that vector will be used as a fallback (x,y). + + Returns: + (party_coords, fallback_used) where: + - party_coords: {party: (x_mean, y_mean)} for parties with a computed coord or fallback. + - fallback_used: set of party names where fallback_party_scores was used. + """ + pos = positions_by_window.get(window_id, {}) or {} + + per_party: Dict[str, List[Tuple[float, float]]] = {} + for ent, xy in pos.items(): + if not ent or xy is None: + continue + try: + x, y = float(xy[0]), float(xy[1]) + except Exception: + # skip malformed coords + continue + party = party_map.get(ent) + if party is None: + # try stripped name fallback + party = party_map.get(_strip_paren(ent)) + if not party or party == "Unknown": + continue + per_party.setdefault(party, []).append((x, y)) + + party_coords: Dict[str, Tuple[float, float]] = {} + fallback_used: Set[str] = set() + + # compute means for parties that have MPs + for party, coords in per_party.items(): + xs = [c[0] for c in coords] + ys = [c[1] for c in coords] + # defensive: drop nan/inf + xs = [float(x) for x in xs if not (math.isnan(x) or math.isinf(x))] + ys = [float(y) for y in ys if not (math.isnan(y) or math.isinf(y))] + if not xs or not ys: + continue + party_coords[party] = (float(np.mean(xs)), float(np.mean(ys))) + + # fallback: use supplied party vectors if a party has no MPs in this window + if fallback_party_scores: + for party, vec in fallback_party_scores.items(): + if party in party_coords: + continue + if not vec: + continue + try: + # vec may be list, np.array, etc. + if len(vec) >= 2: + x_f, y_f = float(vec[0]), float(vec[1]) + if ( + math.isnan(x_f) + or math.isnan(y_f) + or math.isinf(x_f) + or math.isinf(y_f) + ): + continue + party_coords[party] = (x_f, y_f) + fallback_used.add(party) + except Exception: + continue + + if fallback_used: + logger.warning( + "compute_party_coords used fallback for parties: %s", + sorted(list(fallback_used)), + ) + + return party_coords, fallback_used + + +def compute_party_centroids( + positions_by_window: Dict[str, Dict[str, Tuple[float, float]]], + party_map: Dict[str, str], + windows: List[str], +) -> Tuple[Dict[str, List[Tuple[float, float]]], Dict[str, Any]]: + """Compute per-party centroids across multiple windows. + + Returns (party_centroids, metadata) + - party_centroids: mapping party -> list of (x,y) tuples of length len(windows). + Entries without MPs are (np.nan, np.nan). + - metadata: dict with keys 'per_party_counts', 'total_windows', 'parties' + """ + party_centroids: Dict[str, List[Tuple[float, float]]] = {} + # collect all parties from party_map values + parties = sorted(set(party_map.values())) + # if no parties known, return empty dict but still metadata + if not parties: + return {}, { + "per_party_counts": {}, + "total_windows": len(windows), + "parties": [], + } + + # initialize lists + for p in parties: + party_centroids[p] = [] + + # for each window, compute party coords using compute_party_coords for that window + for w in windows: + coords, _ = compute_party_coords(positions_by_window or {}, party_map, w) + for p in parties: + if p in coords: + # ensure numeric floats + party_centroids[p].append((float(coords[p][0]), float(coords[p][1]))) + else: + party_centroids[p].append((float(np.nan), float(np.nan))) + + # metadata + per_party_counts: Dict[str, int] = {} + for p, vals in party_centroids.items(): + count = 0 + for x, y in vals: + if not (np.isnan(x) or np.isnan(y)): + count += 1 + per_party_counts[p] = count + + metadata = { + "per_party_counts": per_party_counts, + "total_windows": len(windows), + "parties": parties, + } + return party_centroids, metadata diff --git a/tests/test_diagnose_no_plot_trajectories.py b/tests/test_diagnose_no_plot_trajectories.py new file mode 100644 index 0000000..1d452fa --- /dev/null +++ b/tests/test_diagnose_no_plot_trajectories.py @@ -0,0 +1,49 @@ +import os +import types + +import explorer + + +def test_load_positions_empty_sets_diagnostics(monkeypatch): + # Monkeypatch load_positions to return empty positions + monkeypatch.setattr( + explorer, "load_positions", lambda db_path, window_size: ({}, {}) + ) + monkeypatch.setenv("EXPLORER_DEBUG_TRAJECTORIES", "1") + + # Call build_trajectories_tab; it should set diagnostics and return without exception + explorer.build_trajectories_tab(db_path="unused", window_size="annual") + + assert ( + explorer._last_trajectories_diagnostics.get("stage") == "load_positions_empty" + ) + + +def test_select_helper_exception_is_captured(monkeypatch): + # Provide a minimal non-empty positions_by_window + positions = {"W1": {"mp1": (0.1, 0.2)}} + + def fake_load_positions(db_path, window_size): + return positions, {} + + monkeypatch.setattr(explorer, "load_positions", fake_load_positions) + # Ensure party_map maps the mp so centroids/path that invoke select_trajectory_plot_data + monkeypatch.setattr(explorer, "load_party_map", lambda db_path: {"mp1": "P1"}) + + # Patch select_trajectory_plot_data to raise + def bad_helper(*args, **kwargs): + raise ValueError("boom") + + monkeypatch.setattr(explorer, "select_trajectory_plot_data", bad_helper) + monkeypatch.setenv("EXPLORER_DEBUG_TRAJECTORIES", "1") + + explorer.build_trajectories_tab(db_path="unused", window_size="annual") + + # Ensure the helper function has diagnostics attached and module diagnostics updated + assert getattr(explorer.select_trajectory_plot_data, "_last_diagnostics", None) + assert "exception" in explorer.select_trajectory_plot_data._last_diagnostics + assert ( + explorer._last_trajectories_diagnostics.get("stage") + == "select_helper_exception" + ) + assert "ValueError" in explorer._last_trajectories_diagnostics.get("exception", "") diff --git a/tests/test_explorer_helpers_diagnostics.py b/tests/test_explorer_helpers_diagnostics.py new file mode 100644 index 0000000..5657685 --- /dev/null +++ b/tests/test_explorer_helpers_diagnostics.py @@ -0,0 +1,22 @@ +import numpy as np +from explorer_helpers import inspect_positions_for_issues + + +def test_inspect_positions_for_issues_basic(): + positions_by_window = { + "w1": {"mp1": (1.0, 2.0), "mp2": (float("nan"), float("nan"))}, + "w2": {}, + } + party_map = {"mp1": "P1"} + d = inspect_positions_for_issues(positions_by_window, party_map) + + # basic keys still present + assert d["windows_count"] == 2 + assert isinstance(d["mp_id_set"], set) + # new diagnostics + assert "mp_positions_count" in d + assert d["mp_positions_count"] >= 1 + assert "mp_positions_sample" in d + assert isinstance(d["mp_positions_sample"], list) + assert "windows_with_no_positions" in d + assert isinstance(d["windows_with_no_positions"], list)