You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
1.8 KiB
63 lines
1.8 KiB
import json
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from database import db as motion_db
|
|
from pipeline.svd_pipeline import (
|
|
_safe_k,
|
|
_build_vote_matrix,
|
|
_procrustes_align,
|
|
run_svd_for_window,
|
|
)
|
|
|
|
|
|
def test_safe_k_and_build_and_run(tmp_path):
    """Check ``_safe_k``, ``_build_vote_matrix`` and ``run_svd_for_window``.

    A fully populated 5 MP x 6 motion vote matrix is written to a throwaway
    database file under ``tmp_path``. The test then expects:

    * ``_build_vote_matrix`` to return a ``(5, 6)`` matrix for the window;
    * ``_safe_k`` to reduce an oversized ``k=10`` to 4 and to keep ``k=3``;
    * ``run_svd_for_window`` with ``k=10`` to report ``k_used == 4`` and
      5 stored MP entries / 6 stored motion entries.

    ``motion_db`` is a module-level object shared by every test, so its
    ``db_path`` is pointed at the temporary file only for the duration of
    this test and restored afterwards; otherwise later tests would keep
    using this test's throwaway database.
    """
    np.random.seed(0)

    # Synthetic dataset: 5 MPs x 6 motions, one (zero-padded ISO) date per
    # motion: 2020-01-01 ... 2020-01-06.
    mps = [f"MP_{i}" for i in range(5)]
    motions = list(range(100, 106))
    dates = [f"2020-01-{day:02d}" for day in range(1, len(motions) + 1)]
    votes = ["Voor", "Tegen", "Geen stem"]

    original_db_path = motion_db.db_path
    motion_db.db_path = str(tmp_path / "test.db")
    try:
        motion_db._init_database()

        # Populate every MP/motion cell, cycling through the vote values.
        for j, (motion_id, date) in enumerate(zip(motions, dates)):
            for i, mp in enumerate(mps):
                vote = votes[(i + j) % len(votes)]
                motion_db.insert_mp_vote(motion_id, mp, vote, date=date)

        # Only the matrix is checked here; names/ids are intentionally unused.
        mat, _mp_names, _motion_ids = _build_vote_matrix(
            motion_db, "2020-01-01", "2020-01-10"
        )
        assert mat.shape == (len(mps), len(motions))

        # _safe_k: with k=10 -> min_dim=5 -> returns 4
        assert _safe_k(mat, 10) == 4
        assert _safe_k(mat, 3) == 3

        # run_svd_for_window with k=10 -> should use k_used=4
        res = run_svd_for_window(
            motion_db, "w1", "2020-01-01", "2020-01-10", k=10
        )
        assert res["k_used"] == 4
        assert res["stored_mp"] == len(mps)
        assert res["stored_motion"] == len(motions)
    finally:
        # Undo the global redirection so other tests see the original DB path.
        motion_db.db_path = original_db_path
|
|
|
|
|
|
def test_procrustes_align():
    """Check that ``_procrustes_align`` moves a transformed copy toward its reference.

    The "current" anchors are the 10 x 3 reference anchors multiplied by a
    random orthogonal matrix (the Q factor of a QR decomposition, so it may
    be a rotation or a reflection) plus Gaussian noise scaled by 0.1. The
    distance to the reference (``np.linalg.norm`` of the difference) must
    shrink after alignment.
    """
    np.random.seed(0)

    reference = np.random.randn(10, 3)
    orthogonal, _upper = np.linalg.qr(np.random.randn(3, 3))
    current = reference @ orthogonal + 0.1 * np.random.randn(10, 3)

    # Measure before aligning, in case the helper modifies its input.
    distance_before = np.linalg.norm(current - reference)
    aligned = _procrustes_align(reference, current)
    distance_after = np.linalg.norm(aligned - reference)

    assert distance_after < distance_before
|
|
|