feat(explorer): harden load_party_axis_scores (close DB, deterministic params)

Plan: docs/superpowers/plans/2026-03-24-svd-tab-redesign.md
main
Sven Geboers 1 month ago
parent 35dbc8118a
commit 9f3ae15a16
  1. 105
      explorer.py

@ -189,16 +189,21 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
Returns:
{party_name: [float * k]} k = 50 for the canonical 2025 window
"""
con = None
try:
# Use a deterministic, ordered list for parameter binding
party_list = sorted(CURRENT_PARLIAMENT_PARTIES)
if not party_list:
return {}
con = duckdb.connect(database=db_path, read_only=True)
placeholders = ", ".join("?" for _ in CURRENT_PARLIAMENT_PARTIES)
placeholders = ", ".join("?" for _ in party_list)
rows = con.execute(
f"SELECT entity_id, vector FROM svd_vectors "
f"WHERE entity_type='mp' AND window_id='2025' "
f"AND entity_id IN ({placeholders})",
list(CURRENT_PARLIAMENT_PARTIES),
party_list,
).fetchall()
con.close()
return {
row[0]: json.loads(row[1]) if isinstance(row[1], str) else list(row[1])
for row in rows
@ -206,6 +211,9 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
except Exception:
logger.exception("Failed to load party axis scores")
return {}
finally:
if con is not None:
con.close()
@st.cache_data(show_spinner="Moties laden…")
@ -231,6 +239,97 @@ def load_motions_df(db_path: str) -> pd.DataFrame:
con.close()
def _render_party_axis_chart(
party_scores: Dict[str, List[float]], comp_sel: int
) -> None:
"""Render a 1D horizontal Plotly scatter of party positions on SVD axis `comp_sel`.
party_scores: mapping party -> list-like vector (50-dim)
comp_sel: 1-based component index
"""
# Validate component selection
if not isinstance(comp_sel, int) or comp_sel < 1:
st.caption("_Ongeldige SVD-as geselecteerd._")
return
if not party_scores:
st.caption("_Partijdata niet beschikbaar_")
return
axis_idx = comp_sel - 1
parties: List[str] = []
xs: List[float] = []
for party, vec in party_scores.items():
# Ensure vec is indexable/sequence-like
if not isinstance(vec, (list, tuple, np.ndarray)):
# skip malformed entries
continue
try:
raw = vec[axis_idx]
# Convert to float safely
val = float(raw)
except Exception:
# skip entries that cannot be indexed or converted
continue
parties.append(party)
xs.append(val)
if not xs:
st.caption("_Partijdata niet beschikbaar_")
return
try:
x_min = min(xs)
x_max = max(xs)
except Exception:
st.caption("_Onvoldoende gegevens om asbereik te berekenen_")
return
# If min == max, apply symmetric padding around the value.
if x_min == x_max:
padding = 0.5 if x_min == 0 else abs(x_min) * 0.1
if padding <= 0:
padding = 0.5
x_min = x_min - padding
x_max = x_max + padding
else:
# Expand range slightly for visual padding
x_min = x_min * 1.15
x_max = x_max * 1.15
# Build horizontal scatter: y is constant (0) but offset for label placement
ys = [0 for _ in xs]
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=xs,
y=ys,
mode="markers+text",
text=parties,
textposition="top center",
marker=dict(
size=10, color=[PARTY_COLOURS.get(p, "#9E9E9E") for p in parties]
),
hovertemplate="%{text}<br>x: %{x:.3f}<extra></extra>",
)
)
fig.update_layout(
title=f"Partijposities op SVD-as {comp_sel}",
xaxis_title="Negatief ← — → Positief",
yaxis=dict(visible=False),
xaxis=dict(range=[x_min, x_max]),
height=300,
margin=dict(t=40, b=40, l=40, r=40),
showlegend=False,
)
st.plotly_chart(fig, use_container_width=True)
def query_similar(
db_path: str,
source_motion_id: int,

Loading…
Cancel
Save