"""SVD embedding of parliamentary voting records.

Builds MP x motion vote matrices per time window, expands party-level vote
rows to individual MPs via mp_metadata, runs truncated SVD, and provides
Procrustes alignment of embeddings between windows.
"""
import json
|
|
import logging
|
|
import re
|
|
from typing import Optional, Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
|
|
try:
    from scipy.sparse import csr_matrix
    from scipy.sparse.linalg import svds
    from scipy.linalg import orthogonal_procrustes

    _HAS_SCIPY = True
except Exception:
    # Lightweight fallbacks for environments without scipy.
    # csr_matrix becomes a no-op: downstream code only needs something
    # numpy can consume, and the dense array already qualifies.
    csr_matrix = lambda x: x

    def svds(a, k=1):
        """Fallback mimicking scipy.sparse.linalg.svds on dense input.

        scipy's svds returns the k LARGEST singular triplets with the
        singular values in ASCENDING order. numpy's svd returns all
        triplets in descending order, so take the first k and reverse.
        (The previous fallback returned the last k components, i.e. the
        SMALLEST singular values — the opposite of scipy's behaviour.)
        """
        U, s, Vt = np.linalg.svd(np.array(a), full_matrices=False)
        return U[:, :k][:, ::-1], s[:k][::-1], Vt[:k, :][::-1, :]

    def orthogonal_procrustes(A, B):
        """Fallback orthogonal Procrustes: R minimizing ||A @ R - B||_F.

        Returns (R, scale). scale is a placeholder (scipy returns the sum
        of singular values; callers in this module ignore it).
        """
        U, _, Vt = np.linalg.svd(A.T.dot(B))
        R = U.dot(Vt)
        scale = 1.0
        return R, scale

    _HAS_SCIPY = False
|
|
import duckdb
|
|
|
|
from database import MotionDatabase
|
|
|
|
# Module-level logger, named after this module per logging convention.
_logger = logging.getLogger(__name__)

# Map textual votes to numeric values for SVD.
# "Voor"/"Tegen" are Dutch for "in favour"/"against"; abstentions and
# unknown votes map to 0.0 so they contribute nothing to the embedding.
VOTE_MAP = {
    "Voor": 1.0,
    "voor": 1.0,
    "Tegen": -1.0,
    "tegen": -1.0,
    "Geen stem": 0.0,
    "Onbekend": 0.0,
    "Onbekend stem": 0.0,
    "Blanco": 0.0,
}

# Mapping from short party names (as they appear in party-level vote rows)
# to canonical party names in mp_metadata. Parties not listed here are either
# already matching or are skipped (no valid mp_metadata coverage).
_PARTY_NAME_MAP = {
    "NSC": "Nieuw Sociaal Contract",
    "Gündogan": "Gündoğan",
    "Keijzer": "Lid Keijzer",
    # Pre-merger: both GroenLinks and PvdA votes map to the merged faction
    "GroenLinks": "GroenLinks-PvdA",
    "PvdA": "GroenLinks-PvdA",
    # Omtzigt initially sat alone before founding NSC
    "Omtzigt": "Nieuw Sociaal Contract",
}

# Party names for which we have no usable mp_metadata (tiny noise, skip expansion)
_SKIP_PARTIES = {"Brinkman", "Bontes", "Krol", "Van Kooten-Arissen"}
|
|
|
|
# Special-character corrections for individual vote name parts
|
|
_NAME_CHAR_FIXES: Dict[str, str] = {
|
|
"Gündogan": "Gündoğan",
|
|
}
|
|
|
|
|
|
def _votes_name_to_meta_format(votes_name: str) -> str:
|
|
"""Convert an mp_votes individual-record name to mp_metadata canonical format.
|
|
|
|
mp_votes format: ``{surname} {lowercase_tussenvoegsel}, {initials} ({FirstName})``
|
|
e.g. ``Dijk van, I. (Inge)`` → ``Van Dijk, I.``
|
|
``Beer de, M.E.E.`` → ``De Beer, M.E.E.``
|
|
``Abassi el, I.`` → ``El Abassi, I.``
|
|
``Baarle van, S.R.T.`` → ``Van Baarle, S.R.T.``
|
|
|
|
mp_metadata format: ``{Capital_tussenvoegsel} {Achternaam}, {initials}``
|
|
|
|
Steps:
|
|
1. Split on ``, `` → name_part, initials_part.
|
|
2. Strip parenthetical first name from initials_part.
|
|
3. In name_part, isolate trailing lowercase words as tussenvoegsel;
|
|
the rest is the achternaam.
|
|
4. Reconstruct as ``{Capitalized tussenvoegsel} {achternaam}, {initials}``.
|
|
5. Apply special-character fixes.
|
|
"""
|
|
if "," not in votes_name:
|
|
return votes_name
|
|
|
|
comma_idx = votes_name.index(",")
|
|
name_part = votes_name[:comma_idx].strip()
|
|
initials_part = votes_name[comma_idx + 1 :].strip()
|
|
|
|
# Remove parenthetical first name, e.g. "(Inge)" or "(Jan-Willem)"
|
|
initials_part = re.sub(r"\s*\([^)]+\)$", "", initials_part).strip()
|
|
|
|
# Split name_part into words; trailing lowercase words are tussenvoegsel
|
|
words = name_part.split()
|
|
# Find split point: last run of lowercase words at the end
|
|
split = len(words)
|
|
for i in range(len(words) - 1, -1, -1):
|
|
if words[i][0].islower():
|
|
split = i
|
|
else:
|
|
break
|
|
achternaam_words = words[:split]
|
|
tussenvoegsel_words = words[split:]
|
|
|
|
if tussenvoegsel_words:
|
|
# Capitalize the first letter of the first tussenvoegsel word
|
|
tussenvoegsel_words[0] = tussenvoegsel_words[0].capitalize()
|
|
canonical = (
|
|
" ".join(tussenvoegsel_words + achternaam_words) + ", " + initials_part
|
|
)
|
|
else:
|
|
canonical = " ".join(achternaam_words) + ", " + initials_part
|
|
|
|
# Apply special-character fixes
|
|
for bad, good in _NAME_CHAR_FIXES.items():
|
|
canonical = canonical.replace(bad, good)
|
|
|
|
return canonical
|
|
|
|
|
|
def _build_expanded_rows(
    db_path: str, start_date: str, end_date: str
) -> List[Tuple[int, str, str, str]]:
    """Build vote rows expanding party-level votes to individual MPs.

    For motions that have only party-level vote records (mp_name is a party code,
    not a 'Lastname, F.' individual), each party vote is expanded to all individual
    MPs of that party who were active on the motion date (via mp_metadata).

    For motions that already have individual MP records, those rows are kept
    (after converting names to the mp_metadata canonical format).

    Args:
        db_path: Path to the DuckDB database file.
        start_date: Inclusive window start, compared against mp_votes.date.
        end_date: Inclusive window end.

    Returns:
        List of (motion_id, mp_name, vote, date_str) tuples.
    """
    # Single local import (the original imported defaultdict twice and had an
    # unused ``import datetime``).
    from collections import defaultdict

    # Read-only connection allows concurrent parallel workers on the same DB.
    conn = duckdb.connect(db_path, read_only=True)
    try:
        # Load all vote rows for the window
        vote_rows = conn.execute(
            "SELECT motion_id, mp_name, vote, date FROM mp_votes "
            "WHERE date BETWEEN ? AND ?",
            (start_date, end_date),
        ).fetchall()

        # Load mp_metadata (name, party, van, tot_en_met)
        meta_rows = conn.execute(
            "SELECT mp_name, party, van, tot_en_met FROM mp_metadata"
        ).fetchall()
    finally:
        conn.close()

    if not vote_rows:
        return []

    # mp_metadata lookup: canonical_party -> list of (mp_name, van, tot_en_met)
    party_to_mps: Dict[str, List[Tuple]] = defaultdict(list)
    for mp_name, party, van, tot_en_met in meta_rows:
        if party and mp_name:
            party_to_mps[party].append((mp_name, van, tot_en_met))

    def get_active_mps(canonical_party: str, motion_date) -> List[str]:
        """Return MP names active in canonical_party on motion_date."""
        result = []
        for mp_name, van, tot_en_met in party_to_mps.get(canonical_party, []):
            # Membership must have started on/before the motion date...
            if van is None or van > motion_date:
                continue
            # ...and not have ended before it (None = still active).
            if tot_en_met is not None and tot_en_met < motion_date:
                continue
            result.append(mp_name)
        return result

    # Group rows by motion_id, separating individual vs party-level records.
    motion_individual: Dict[int, List] = defaultdict(list)
    motion_party: Dict[int, List] = defaultdict(list)

    for motion_id, mp_name, vote, date in vote_rows:
        mid = int(motion_id)
        # Individual MPs have a comma in the name (e.g. "Bergkamp, V.A.")
        if "," in str(mp_name):
            motion_individual[mid].append((mp_name, vote, date))
        else:
            motion_party[mid].append((mp_name, vote, date))

    # Build the final expanded rows
    expanded: List[Tuple[int, str, str, str]] = []

    all_motion_ids = set(motion_individual.keys()) | set(motion_party.keys())
    for mid in all_motion_ids:
        if mid in motion_individual and motion_individual[mid]:
            # Motion already has individual MP rows — convert to mp_metadata
            # name format, then use directly; skip party rows for this motion.
            for mp_name, vote, date in motion_individual[mid]:
                canonical_name = _votes_name_to_meta_format(str(mp_name))
                expanded.append((mid, canonical_name, vote, str(date)))
        else:
            # Party-only motion — expand each party row to individual MPs
            for party_name, vote, date in motion_party[mid]:
                if party_name in _SKIP_PARTIES:
                    continue
                canonical = _PARTY_NAME_MAP.get(party_name, party_name)
                active_mps = get_active_mps(canonical, date)
                if not active_mps:
                    # If we have no mp_metadata for this party (common in tests
                    # or minimal DB fixtures), fall back to using the party code
                    # itself as a single representative row rather than dropping
                    # the motion. This keeps downstream pipelines (SVD, tests)
                    # working when detailed mp_metadata is not present.
                    _logger.debug(
                        "No active MPs found for party %s (canonical: %s) on %s; falling back to party-level row",
                        party_name,
                        canonical,
                        date,
                    )
                    expanded.append((mid, canonical, vote, str(date)))
                else:
                    for mp_name in active_mps:
                        expanded.append((mid, mp_name, vote, str(date)))

    return expanded
|
|
|
|
|
|
def _safe_k(mat: np.ndarray, k: int) -> int:
|
|
"""Return a safe k for svds: must be < min(mat.shape)."""
|
|
if mat is None:
|
|
return 0
|
|
m, n = mat.shape
|
|
min_dim = min(m, n)
|
|
# svds requires k < min_dim
|
|
if min_dim <= 1:
|
|
return 0
|
|
return min(k, min_dim - 1)
|
|
|
|
|
|
def _build_vote_matrix(
    db: MotionDatabase, start_date: str, end_date: str
) -> Tuple[np.ndarray, List[str], List[int]]:
    """Build dense vote matrix (mp x motion) for votes between start_date and end_date.

    Rows are MPs (sorted by name), columns are motions (sorted by id); cell
    values come from VOTE_MAP, with unknown vote strings falling back to 0.0.

    Returns:
        (matrix, mp_names, motion_ids)
    """
    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(
            "SELECT motion_id, mp_name, vote FROM mp_votes WHERE date BETWEEN ? AND ?",
            (start_date, end_date),
        ).fetchall()
    finally:
        # Always close, even if the query raises; matches the try/finally
        # style used in _build_expanded_rows and avoids leaking the handle.
        conn.close()

    if not rows:
        return np.zeros((0, 0)), [], []

    motion_ids = sorted({int(r[0]) for r in rows})
    mp_names = sorted({r[1] for r in rows})

    mat = np.zeros((len(mp_names), len(motion_ids)), dtype=float)

    mp_index = {name: i for i, name in enumerate(mp_names)}
    motion_index = {mid: j for j, mid in enumerate(motion_ids)}

    for motion_id, mp_name, vote in rows:
        i = mp_index[mp_name]
        j = motion_index[int(motion_id)]
        # Look up the raw value first, then a whitespace-stripped variant.
        val = VOTE_MAP.get(
            vote, VOTE_MAP.get(vote.strip() if isinstance(vote, str) else vote, 0.0)
        )
        try:
            mat[i, j] = float(val)
        except Exception:
            # Defensive: VOTE_MAP values are floats, but guard against a
            # patched/extended map containing non-numeric values.
            mat[i, j] = 0.0

    return mat, mp_names, motion_ids
|
|
|
|
|
|
def _procrustes_align(
|
|
reference_anchor: np.ndarray,
|
|
current_anchor: np.ndarray,
|
|
min_overlap: int = 3,
|
|
) -> np.ndarray:
|
|
"""Align current_anchor to reference_anchor using orthogonal Procrustes.
|
|
|
|
This function will only attempt alignment when there is a reasonable number of
|
|
overlapping rows (default: min_overlap). If the overlap is too small or if any
|
|
input is invalid, the original current_anchor is returned unchanged.
|
|
|
|
Returns transformed_current_anchor
|
|
"""
|
|
# basic validation
|
|
if reference_anchor is None or current_anchor is None:
|
|
return current_anchor
|
|
|
|
if not isinstance(reference_anchor, np.ndarray) or not isinstance(
|
|
current_anchor, np.ndarray
|
|
):
|
|
return current_anchor
|
|
|
|
# Determine overlap by number of available rows. If too small, skip alignment.
|
|
n_ref = reference_anchor.shape[0]
|
|
n_cur = current_anchor.shape[0]
|
|
overlap = min(n_ref, n_cur)
|
|
if overlap < min_overlap:
|
|
_logger.debug(
|
|
"Procrustes alignment skipped: overlap %s < min_overlap %s",
|
|
overlap,
|
|
min_overlap,
|
|
)
|
|
return current_anchor
|
|
|
|
# Use only the overlapping rows to compute the orthogonal transform.
|
|
ref_sub = reference_anchor[:overlap, :]
|
|
cur_sub = current_anchor[:overlap, :]
|
|
|
|
try:
|
|
# orthogonal_procrustes(A, B) returns R, scale such that A @ R = B * scale
|
|
# We want to transform current_anchor to align with reference_anchor so
|
|
# call orthogonal_procrustes(cur_sub, ref_sub) and apply resulting R/scale
|
|
R, _scale = orthogonal_procrustes(cur_sub, ref_sub)
|
|
transformed = current_anchor.dot(R)
|
|
return transformed
|
|
except Exception:
|
|
_logger.exception("Procrustes alignment failed")
|
|
return current_anchor
|
|
|
|
|
|
def compute_svd_for_window(
    db_path: str,
    window_id: str,
    start_date: str,
    end_date: str,
    k: int = 50,
) -> Dict:
    """Pure-compute SVD for a window. Safe to run in a subprocess.

    Opens the DB in read-only mode (allows concurrent parallel workers) and
    does NOT write to it — the caller persists the results. Party-level vote
    rows are first expanded to individual MP rows via mp_metadata so the
    vote matrix contains only individual MPs (no party aggregates), which
    prevents the block-diagonal structure that makes SVD axes disjoint.

    Returns:
        Dict with keys window_id, k_used, mp_rows, motion_rows, where
        *_rows are List[Tuple[entity_type, entity_id, vector, model]].
        On any failure or empty input, a dict with k_used == 0 and empty
        row lists is returned.
    """
    empty = {"window_id": window_id, "k_used": 0, "mp_rows": [], "motion_rows": []}

    # Expand party votes into individual MP votes for the window.
    rows = _build_expanded_rows(db_path, start_date, end_date)
    if not rows:
        return empty

    motion_ids = sorted({int(row[0]) for row in rows})
    mp_names = sorted({row[1] for row in rows})

    mat = np.zeros((len(mp_names), len(motion_ids)), dtype=float)
    mp_index = {name: idx for idx, name in enumerate(mp_names)}
    motion_index = {mid: idx for idx, mid in enumerate(motion_ids)}

    for motion_id, mp_name, vote, _date in rows:
        row_i = mp_index[mp_name]
        col_j = motion_index[int(motion_id)]
        # Raw lookup first, then a whitespace-stripped variant, else 0.0.
        cleaned = vote.strip() if isinstance(vote, str) else vote
        val = VOTE_MAP.get(vote, VOTE_MAP.get(cleaned, 0.0))
        try:
            mat[row_i, col_j] = float(val)
        except Exception:
            mat[row_i, col_j] = 0.0

    if mat.size == 0:
        return empty

    k_used = _safe_k(mat, k)
    if k_used <= 0:
        return empty

    try:
        U, s, Vt = svds(csr_matrix(mat), k=k_used)

        # svds gives no ordering guarantee we rely on; sort descending by s.
        order = np.argsort(s)[::-1]
        s = s[order]
        U = U[:, order]
        Vt = Vt[order, :]

        # Scale both factor matrices by the singular values.
        mp_vecs = (U * s.reshape(1, -1)).tolist()
        motion_vecs = (Vt.T * s.reshape(1, -1)).tolist()

        mp_rows = [
            ("mp", name, mp_vecs[idx], None) for idx, name in enumerate(mp_names)
        ]
        motion_rows = [
            ("motion", str(mid), motion_vecs[idx], None)
            for idx, mid in enumerate(motion_ids)
        ]

        return {
            "window_id": window_id,
            "k_used": k_used,
            "mp_rows": mp_rows,
            "motion_rows": motion_rows,
        }
    except Exception:
        _logger.exception("SVD failed for window %s", window_id)
        return empty
|
|
|
|
|
|
def run_svd_for_window(
    db: MotionDatabase,
    window_id: str,
    start_date: str,
    end_date: str,
    k: int = 50,
) -> Dict:
    """Run SVD on votes in given date window and store vectors in DB.

    Thin wrapper around compute_svd_for_window that persists the resulting
    vectors via db.batch_store_svd_vectors.

    Returns:
        Metadata dict with keys: k_used, stored_mp, stored_motion.
    """
    result = compute_svd_for_window(db.db_path, window_id, start_date, end_date, k)
    if result["k_used"] == 0:
        return {"k_used": 0, "stored_mp": 0, "stored_motion": 0}

    rows = result["mp_rows"] + result["motion_rows"]
    # Persist the vectors. Reported counts come from the computed rows;
    # the store call's return value is intentionally unused (the previous
    # version bound it to a dead local).
    db.batch_store_svd_vectors(window_id, rows)
    return {
        "k_used": result["k_used"],
        "stored_mp": len(result["mp_rows"]),
        "stored_motion": len(result["motion_rows"]),
    }
|
|
|