You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_svd_pipeline.py

63 lines
1.8 KiB

import json
import numpy as np
import pytest
from database import db as motion_db
from pipeline.svd_pipeline import (
_safe_k,
_build_vote_matrix,
_procrustes_align,
run_svd_for_window,
)
def test_safe_k_and_build_and_run(tmp_path):
    """Check vote-matrix construction, rank capping and a full SVD window run.

    A throwaway database file under ``tmp_path`` is filled with a complete
    5 MPs x 6 motions vote grid, then:

    * ``_build_vote_matrix`` must return a ``(5, 6)`` matrix for the window,
    * ``_safe_k`` must cap a requested ``k=10`` at 4 and leave ``k=3`` alone,
    * ``run_svd_for_window`` must report ``k_used == 4`` and store 5 MP and
      6 motion entries.

    The shared ``motion_db`` object is repointed at the temporary file for the
    duration of the test and restored afterwards, so later tests do not
    silently keep using this test's database.
    """
    np.random.seed(0)

    # Remember where the shared database object pointed before redirecting it.
    original_db_path = motion_db.db_path
    motion_db.db_path = str(tmp_path / "test.db")
    try:
        # Create the schema in the fresh test database file.
        motion_db._init_database()

        # Synthetic dataset: 5 MPs x 6 motions, one motion per day.
        mps = [f"MP_{i}" for i in range(5)]
        motions = list(range(100, 106))
        dates = ["2020-01-0" + str(i + 1) for i in range(6)]
        votes = ["Voor", "Tegen", "Geen stem"]

        # Fill the full matrix; the vote cycles through the three options
        # along the (MP index + motion index) diagonal.
        for j, motion_id in enumerate(motions):
            for i, mp in enumerate(mps):
                vote = votes[(i + j) % len(votes)]
                motion_db.insert_mp_vote(motion_id, mp, vote, date=dates[j])

        mat, mp_names, motion_ids = _build_vote_matrix(
            motion_db, "2020-01-01", "2020-01-10"
        )
        assert mat.shape == (5, 6)

        # _safe_k: with k=10 -> min_dim=5 -> returns 4; a small k passes through.
        assert _safe_k(mat, 10) == 4
        assert _safe_k(mat, 3) == 3

        # run_svd_for_window with k=10 -> should use k_used=4
        res = run_svd_for_window(
            motion_db, "w1", "2020-01-01", "2020-01-10", k=10
        )
        assert res["k_used"] == 4
        assert res["stored_mp"] == 5
        assert res["stored_motion"] == 6
    finally:
        # Undo the redirect even when an assertion fails.
        motion_db.db_path = original_db_path
def test_procrustes_align():
    """Check that ``_procrustes_align`` pulls a rotated, noisy copy toward its reference.

    Ten 3-D reference anchors are rotated by a random orthogonal matrix and
    perturbed with small Gaussian noise; after alignment the distance to the
    reference must be strictly smaller than before.
    """
    np.random.seed(0)
    n_anchors, n_dims = 10, 3

    # Draw order (reference, rotation seed, noise) fixes the random values.
    reference = np.random.randn(n_anchors, n_dims)
    # The Q factor of a QR decomposition serves as the orthogonal rotation.
    rotation = np.linalg.qr(np.random.randn(n_dims, n_dims))[0]
    noise = np.random.randn(n_anchors, n_dims)
    current = np.dot(reference, rotation) + 0.1 * noise

    def gap_to_reference(points):
        # Norm of the difference between a point set and the reference anchors.
        return np.linalg.norm(points - reference)

    before = gap_to_reference(current)
    after = gap_to_reference(_procrustes_align(reference, current))
    assert after < before