You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/pipeline/svd_pipeline.py

264 lines
7.7 KiB

import json
import logging
from typing import Optional, Dict, List, Tuple
import numpy as np
try:
    from scipy.sparse import csr_matrix
    from scipy.sparse.linalg import svds
    from scipy.linalg import orthogonal_procrustes

    _HAS_SCIPY = True
except ImportError:
    # Lightweight NumPy-only fallbacks for environments without scipy. They
    # work on dense arrays and mirror the scipy call signatures used here.

    def csr_matrix(x):
        """Return *x* unchanged; the fallback path operates on dense arrays."""
        return x

    def svds(a, k=1):
        """Return the ``k`` largest singular triplets of *a* via a dense SVD.

        Mirrors ``scipy.sparse.linalg.svds``: ``U`` has shape ``(m, k)``,
        ``s`` has shape ``(k,)`` and ``Vt`` has shape ``(k, n)``.
        ``np.linalg.svd`` returns singular values in *descending* order, so
        the largest ``k`` are the FIRST ``k`` components (taking the last
        ``k`` would select the smallest ones). They are then reversed into
        ascending order, matching scipy's historical output; callers that
        depend on a particular order should still sort explicitly.
        """
        U, s, Vt = np.linalg.svd(np.asarray(a), full_matrices=False)
        return U[:, :k][:, ::-1], s[:k][::-1], Vt[:k, :][::-1, :]

    def orthogonal_procrustes(A, B):
        """Return ``(R, scale)`` with orthogonal ``R`` minimising ``||A @ R - B||_F``.

        Mirrors ``scipy.linalg.orthogonal_procrustes``: ``scale`` is the sum
        of the singular values of ``A.T @ B``, and ``ValueError`` is raised
        unless both inputs are 2-D arrays of the same shape.
        """
        A = np.asarray(A)
        B = np.asarray(B)
        if A.ndim != 2 or B.ndim != 2:
            raise ValueError("expected 2-D arrays for orthogonal Procrustes")
        if A.shape != B.shape:
            raise ValueError(
                "orthogonal Procrustes needs equal shapes, got %s and %s"
                % (A.shape, B.shape)
            )
        U, w, Vt = np.linalg.svd(A.T.dot(B))
        return U.dot(Vt), float(w.sum())

    _HAS_SCIPY = False
import duckdb
from database import MotionDatabase
_logger = logging.getLogger(__name__)

# Lookup table that turns the textual ``vote`` column of ``mp_votes`` into a
# number for the SVD vote matrix: the two substantive labels (listed in both
# capitalisations) map to +1.0 and -1.0, every other listed label to 0.0.
# The labels look Dutch (presumably "for" / "against" / no, unknown or blank
# vote) — confirm against the data source. The lookups in this module fall
# back to 0.0 for labels not present here.
VOTE_MAP = dict(
    [
        ("Voor", 1.0),
        ("voor", 1.0),
        ("Tegen", -1.0),
        ("tegen", -1.0),
        ("Geen stem", 0.0),
        ("Onbekend", 0.0),
        ("Onbekend stem", 0.0),
        ("Blanco", 0.0),
    ]
)
def _safe_k(mat: np.ndarray, k: int) -> int:
"""Return a safe k for svds: must be < min(mat.shape)."""
if mat is None:
return 0
m, n = mat.shape
min_dim = min(m, n)
# svds requires k < min_dim
if min_dim <= 1:
return 0
return min(k, min_dim - 1)
def _build_vote_matrix(
    db: MotionDatabase, start_date: str, end_date: str
) -> Tuple[np.ndarray, List[str], List[int]]:
    """Build a dense vote matrix (MP x motion) for votes in a date window.

    Reads ``motion_id, mp_name, vote`` from ``mp_votes`` for rows whose
    ``date`` is ``BETWEEN`` *start_date* and *end_date*, and encodes each
    vote through ``VOTE_MAP``. The exact label is tried first, then the
    whitespace-stripped label, and anything else becomes 0.0. Rows are
    ordered by sorted MP name and columns by sorted (integer) motion id.
    When the same (MP, motion) pair appears more than once, the last row
    wins.

    Returns:
        ``(matrix, mp_names, motion_ids)`` where ``matrix`` has shape
        ``(len(mp_names), len(motion_ids))``. If no votes match, returns
        ``(np.zeros((0, 0)), [], [])``.
    """
    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(
            "SELECT motion_id, mp_name, vote FROM mp_votes WHERE date BETWEEN ? AND ?",
            (start_date, end_date),
        ).fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()
    if not rows:
        return np.zeros((0, 0)), [], []

    motion_ids = sorted({int(r[0]) for r in rows})
    mp_names = sorted({r[1] for r in rows})
    mp_index = {name: i for i, name in enumerate(mp_names)}
    motion_index = {mid: j for j, mid in enumerate(motion_ids)}

    mat = np.zeros((len(mp_names), len(motion_ids)), dtype=float)
    for motion_id, mp_name, vote in rows:
        stripped = vote.strip() if isinstance(vote, str) else vote
        # VOTE_MAP values and the 0.0 default are floats, so no conversion
        # (or conversion error handling) is needed.
        mat[mp_index[mp_name], motion_index[int(motion_id)]] = VOTE_MAP.get(
            vote, VOTE_MAP.get(stripped, 0.0)
        )
    return mat, mp_names, motion_ids
def _procrustes_align(
reference_anchor: np.ndarray,
current_anchor: np.ndarray,
min_overlap: int = 3,
) -> np.ndarray:
"""Align current_anchor to reference_anchor using orthogonal Procrustes.
This function will only attempt alignment when there is a reasonable number of
overlapping rows (default: min_overlap). If the overlap is too small or if any
input is invalid, the original current_anchor is returned unchanged.
Returns transformed_current_anchor
"""
# basic validation
if reference_anchor is None or current_anchor is None:
return current_anchor
if not isinstance(reference_anchor, np.ndarray) or not isinstance(
current_anchor, np.ndarray
):
return current_anchor
# Determine overlap by number of available rows. If too small, skip alignment.
n_ref = reference_anchor.shape[0]
n_cur = current_anchor.shape[0]
overlap = min(n_ref, n_cur)
if overlap < min_overlap:
_logger.debug(
"Procrustes alignment skipped: overlap %s < min_overlap %s",
overlap,
min_overlap,
)
return current_anchor
# Use only the overlapping rows to compute the orthogonal transform.
ref_sub = reference_anchor[:overlap, :]
cur_sub = current_anchor[:overlap, :]
try:
# orthogonal_procrustes(A, B) returns R, scale such that A @ R = B * scale
# We want to transform current_anchor to align with reference_anchor so
# call orthogonal_procrustes(cur_sub, ref_sub) and apply resulting R/scale
R, _scale = orthogonal_procrustes(cur_sub, ref_sub)
transformed = current_anchor.dot(R)
return transformed
except Exception:
_logger.exception("Procrustes alignment failed")
return current_anchor
def _fetch_window_votes(db_path: str, start_date: str, end_date: str) -> List[Tuple]:
    """Return raw ``(motion_id, mp_name, vote)`` rows in the window, read-only."""
    # A read-only connection is opened so parallel workers can read concurrently.
    conn = duckdb.connect(db_path, read_only=True)
    try:
        return conn.execute(
            "SELECT motion_id, mp_name, vote FROM mp_votes WHERE date BETWEEN ? AND ?",
            (start_date, end_date),
        ).fetchall()
    finally:
        conn.close()


def _rows_to_vote_matrix(
    rows: List[Tuple],
) -> Tuple[np.ndarray, List[str], List[int]]:
    """Encode vote rows as a dense (MP x motion) matrix plus sorted axis labels."""
    motion_ids = sorted({int(r[0]) for r in rows})
    mp_names = sorted({r[1] for r in rows})
    mp_index = {name: i for i, name in enumerate(mp_names)}
    motion_index = {mid: j for j, mid in enumerate(motion_ids)}
    mat = np.zeros((len(mp_names), len(motion_ids)), dtype=float)
    for motion_id, mp_name, vote in rows:
        stripped = vote.strip() if isinstance(vote, str) else vote
        # Exact label first, then the stripped label, else 0.0 (always a float).
        mat[mp_index[mp_name], motion_index[int(motion_id)]] = VOTE_MAP.get(
            vote, VOTE_MAP.get(stripped, 0.0)
        )
    return mat, mp_names, motion_ids


def _svd_embeddings(mat: np.ndarray, k_used: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return (MP, motion) embeddings scaled by singular values, largest first."""
    U, s, Vt = svds(csr_matrix(mat), k=k_used)
    # svds does not guarantee descending order; sort components explicitly.
    order = np.argsort(s)[::-1]
    s = s[order]
    return U[:, order] * s, Vt[order, :].T * s


def compute_svd_for_window(
    db_path: str,
    window_id: str,
    start_date: str,
    end_date: str,
    k: int = 50,
) -> Dict:
    """Compute truncated-SVD embeddings for one date window, without writing.

    Opens the DB read-only, which allows concurrent parallel workers, and is
    safe to run in a subprocess. It does NOT write to the DB; the caller is
    responsible for persisting the results.

    The requested rank *k* is clamped with ``_safe_k``. MP vectors are the
    rows of ``U * s`` and motion vectors the rows of ``Vt.T * s``, with
    components ordered by decreasing singular value.

    Returns:
        dict with keys ``window_id``, ``k_used``, ``mp_rows``, ``motion_rows``.
        The ``*_rows`` are lists of ``(entity_type, entity_id, vector, model)``
        tuples, where ``entity_type`` is ``"mp"`` or ``"motion"``, motion ids
        are stringified, and ``model`` is ``None``. When there are no votes,
        no usable rank, or the SVD raises (logged), ``k_used`` is 0 and both
        row lists are empty.
    """
    empty = {"window_id": window_id, "k_used": 0, "mp_rows": [], "motion_rows": []}

    rows = _fetch_window_votes(db_path, start_date, end_date)
    if not rows:
        return empty

    # Non-empty rows guarantee at least one MP and one motion, so the matrix
    # is never empty here; _safe_k rejects too-small shapes.
    mat, mp_names, motion_ids = _rows_to_vote_matrix(rows)
    k_used = _safe_k(mat, k)
    if k_used <= 0:
        return empty

    try:
        mp_vecs, motion_vecs = _svd_embeddings(mat, k_used)
        mp_rows = [
            ("mp", mp_name, vec, None)
            for mp_name, vec in zip(mp_names, mp_vecs.tolist())
        ]
        motion_rows = [
            ("motion", str(mid), vec, None)
            for mid, vec in zip(motion_ids, motion_vecs.tolist())
        ]
    except Exception:
        _logger.exception("SVD failed for window %s", window_id)
        return empty

    return {
        "window_id": window_id,
        "k_used": k_used,
        "mp_rows": mp_rows,
        "motion_rows": motion_rows,
    }
def run_svd_for_window(
    db: MotionDatabase,
    window_id: str,
    start_date: str,
    end_date: str,
    k: int = 50,
) -> Dict:
    """Compute SVD embeddings for a date window and store them via *db*.

    Delegates the computation to ``compute_svd_for_window(db.db_path, ...)``.
    The resulting MP rows followed by motion rows are then passed to
    ``db.batch_store_svd_vectors(window_id, rows)``. Nothing is stored when
    ``k_used`` is 0.

    Returns:
        dict with keys ``k_used``, ``stored_mp`` and ``stored_motion``. The
        ``stored_*`` values count the rows handed to the store, not a
        confirmed number written. All three are 0 when no SVD was produced.
    """
    result = compute_svd_for_window(db.db_path, window_id, start_date, end_date, k)
    if result["k_used"] == 0:
        return {"k_used": 0, "stored_mp": 0, "stored_motion": 0}

    mp_rows = result["mp_rows"]
    motion_rows = result["motion_rows"]
    # NOTE(review): the store's return value is ignored; if it reports how
    # many rows were actually written, consider surfacing that instead.
    db.batch_store_svd_vectors(window_id, mp_rows + motion_rows)
    return {
        "k_used": result["k_used"],
        "stored_mp": len(mp_rows),
        "stored_motion": len(motion_rows),
    }