fix(explorer): fix scree plot data and add bar+line combo chart

- Use all individual MPs (not party aggregates) for L2-norm computation;
  party-aggregated vectors have near-zero values on some dims due to
  Procrustes alignment, producing spurious zeros
- Sort importances descending so scree plot is properly monotonic
- Relabel x-axis as 'Rang' since dim ordering after Procrustes alignment
  no longer matches original singular value order
- Add Scatter line trace connecting bar tops for elbow visibility
main
Sven Geboers 1 month ago
parent c5cbc89c1f
commit cf22ffc093
  1. 65
      explorer.py

@ -252,23 +252,27 @@ def load_party_axis_scores(db_path: str) -> Dict[str, List[float]]:
@st.cache_data(show_spinner="Scree-plot laden…") @st.cache_data(show_spinner="Scree-plot laden…")
def load_scree_data(db_path: str) -> List[float]: def load_scree_data(db_path: str) -> List[float]:
"""Return a list of component importances (L2-norm of party scores per dimension). """Return component importances (L2-norm per SVD dimension), sorted descending.
Uses the same svd_vectors data as load_party_axis_scores but aggregates across Uses ALL individual MP vectors (entity_type='mp', window='current_parliament'),
all components (0-indexed). Returns a list of length == vector dimensionality (50). excluding party-aggregated rows. Since the stored vectors are U*s (scaled by
singular values), the L2-norm of all MP scores per dimension approximates the
singular value for that dimension. Sorting descending gives the proper scree shape.
Note: Procrustes alignment across sub-windows may scramble the original dimension
ordering, so we sort by magnitude rather than relying on dimension index order.
""" """
try: try:
con = duckdb.connect(database=db_path, read_only=True) con = duckdb.connect(database=db_path, read_only=True)
party_list = sorted(CURRENT_PARLIAMENT_PARTIES)
placeholders = ", ".join("?" for _ in party_list)
rows = con.execute( rows = con.execute(
f"SELECT vector FROM svd_vectors " "SELECT entity_id, vector FROM svd_vectors "
f"WHERE entity_type='mp' AND window_id='current_parliament' " "WHERE entity_type='mp' AND window_id='current_parliament'"
f"AND entity_id IN ({placeholders})",
party_list,
).fetchall() ).fetchall()
# Individual MPs have "Lastname, F." format; party rows are short codes without commas
vectors: List[List[float]] = [] vectors: List[List[float]] = []
for (raw_vec,) in rows: for entity_id, raw_vec in rows:
if "," not in entity_id:
continue # skip party-aggregated rows
if isinstance(raw_vec, str): if isinstance(raw_vec, str):
vec = json.loads(raw_vec) vec = json.loads(raw_vec)
elif isinstance(raw_vec, (bytes, bytearray)): elif isinstance(raw_vec, (bytes, bytearray)):
@ -289,7 +293,7 @@ def load_scree_data(db_path: str) -> List[float]:
col = [v[dim] for v in vectors if dim < len(v)] col = [v[dim] for v in vectors if dim < len(v)]
l2 = sum(x**2 for x in col) ** 0.5 l2 = sum(x**2 for x in col) ** 0.5
importances.append(l2) importances.append(l2)
return importances return sorted(importances, reverse=True)
except Exception: except Exception:
logger.exception("Failed to load scree data") logger.exception("Failed to load scree data")
return [] return []
@ -301,33 +305,47 @@ def load_scree_data(db_path: str) -> List[float]:
def _render_scree_plot(importances: List[float], n_show: int = 15) -> None: def _render_scree_plot(importances: List[float], n_show: int = 15) -> None:
"""Render a bar chart showing relative component importance (scree plot). """Render a bar+line combo chart showing relative SVD component importance.
Bars show the L2-norm (singular value proxy) per rank; a line connects the tops
of the bars to make the 'elbow' in the scree curve easy to spot.
Args: Args:
importances: List of L2-norm scores per component (0-indexed). importances: List of importance values sorted descending (from load_scree_data).
n_show: How many components to display (default: first 15). n_show: How many components to display (default: first 15).
""" """
if not importances: if not importances:
return return
data = importances[:n_show] data = importances[:n_show]
components = list(range(1, len(data) + 1)) ranks = list(range(1, len(data) + 1))
colours = [ bar_colour = "#90CAF9"
PARTY_COLOURS.get("PVV", "#1565C0") if i == 0 else "#90CAF9" line_colour = "#1565C0"
for i in range(len(data)) fig = go.Figure()
] fig.add_trace(
fig = go.Figure(
go.Bar( go.Bar(
x=components, x=ranks,
y=data, y=data,
marker_color=colours, marker_color=bar_colour,
hovertemplate="As %{x}<br>Gewicht: %{y:.2f}<extra></extra>", hovertemplate="Rang %{x}<br>Gewicht: %{y:.2f}<extra></extra>",
showlegend=False,
)
)
fig.add_trace(
go.Scatter(
x=ranks,
y=data,
mode="lines+markers",
line={"color": line_colour, "width": 2},
marker={"size": 6, "color": line_colour},
hoverinfo="skip",
showlegend=False,
) )
) )
fig.update_layout( fig.update_layout(
height=220, height=220,
margin={"l": 10, "r": 10, "t": 10, "b": 30}, margin={"l": 10, "r": 10, "t": 10, "b": 30},
xaxis={ xaxis={
"title": "SVD-as", "title": "Rang",
"tickmode": "linear", "tickmode": "linear",
"tick0": 1, "tick0": 1,
"dtick": 1, "dtick": 1,
@ -342,6 +360,7 @@ def _render_scree_plot(importances: List[float], n_show: int = 15) -> None:
}, },
plot_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)",
bargap=0.2,
) )
st.plotly_chart(fig, use_container_width=True) st.plotly_chart(fig, use_container_width=True)

Loading…
Cancel
Save