You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
6.6 KiB
207 lines
6.6 KiB
import duckdb
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from database import MotionDatabase
|
|
|
|
|
|
def _create_motion_and_get_id(
    db: MotionDatabase, url: str, title: str, layman: str = "x"
) -> int:
    """Insert a minimal test motion and return its database id.

    The motion record uses ``title`` for both its title and description,
    fixed placeholder values for the remaining fields, and ``layman`` as
    the layman explanation. The id is read back by querying the
    ``motions`` table by ``url`` over a separate DuckDB connection opened
    on ``db.db_path``.

    Args:
        db: Database object under test; ``insert_motion`` and ``db_path``
            are the only members used here.
        url: URL stored on the motion and used to find the row afterwards.
        title: Text used for both the ``title`` and ``description`` fields.
        layman: Value stored as ``layman_explanation``.

    Returns:
        The ``id`` of the motion row whose ``url`` matches, as an ``int``.

    Raises:
        AssertionError: If ``insert_motion`` returns a falsy value or no
            row with the given ``url`` can be found.
    """
    motion = {
        "title": title,
        "description": title,
        "date": "2023-01-01",
        "policy_area": "test",
        "voting_results": {},
        "winning_margin": 0.5,
        "layman_explanation": layman,
        "url": url,
    }
    ok = db.insert_motion(motion)
    assert ok, "insert_motion failed"

    conn = duckdb.connect(db.db_path)
    try:
        row = conn.execute(
            "SELECT id FROM motions WHERE url = ?", (url,)
        ).fetchone()
    finally:
        # Close the connection even when the query raises, so a failing
        # lookup does not leave the database file held open.
        conn.close()

    assert row is not None, "couldn't find inserted motion"
    return int(row[0])
|
|
|
|
|
|
def test_match_mps_basic(tmp_path: Path):
    """Check match_mps_for_votes scores MPs against a user's votes.

    Three MPs vote on four motions. The user votes "Voor" on all four, so
    the test expects Gamma (four "voor" votes) as the top result at 100%
    and Alpha (three "voor" votes) to be present at 75%. User votes are
    capitalised while the stored MP votes are lower-case; the assertions
    expect the two to be matched regardless.
    """
    db = MotionDatabase(str(tmp_path / "test_motions.db"))

    # Four motions, numbered 1..4.
    motion_ids = [
        _create_motion_and_get_id(db, f"http://m{n}", f"Motion {n}")
        for n in range(1, 5)
    ]

    alpha = "Alpha, A."
    beta = "Beta, B."
    gamma = "Gamma, G."

    # Per-MP votes, listed in motion order (motions 1..4):
    #   Alpha agrees with the user on 3 of 4, Beta on 0 of 4, Gamma on 4 of 4.
    ballots = {
        alpha: ["voor", "voor", "tegen", "voor"],
        beta: ["tegen", "tegen", "tegen", "tegen"],
        gamma: ["voor", "voor", "voor", "voor"],
    }

    for position, motion_id in enumerate(motion_ids):
        for name, pattern in ballots.items():
            db.insert_mp_vote(
                motion_id=motion_id,
                mp_name=name,
                vote=pattern[position],
                date="2023-01-01",
                party=None,
            )

    # The user votes "Voor" everywhere, i.e. exactly like Gamma.
    user_votes = dict.fromkeys(motion_ids, "Voor")

    results = db.match_mps_for_votes(user_votes, limit=10)
    assert results, "No results returned"

    # The best match must be Gamma with full agreement.
    best = results[0]
    assert best["mp_name"] == gamma
    assert best["matched"] == 4
    assert best["overlap"] == 4
    assert best["agreement_pct"] == 100.0

    # Alpha must be present with 3 of 4 matches (its rank is not asserted).
    returned_names = [entry["mp_name"] for entry in results]
    assert alpha in returned_names
    alpha_entry = next(entry for entry in results if entry["mp_name"] == alpha)
    assert alpha_entry["matched"] == 3
    assert alpha_entry["overlap"] == 4
    assert alpha_entry["agreement_pct"] == 75.0
|
|
|
|
|
|
def test_choose_discriminating_motions(tmp_path: Path):
    """Check that the motion which splits the candidates is chosen first.

    Of three motions only the first has the candidates voting differently;
    on the other two all candidates vote the same way. With ``k=1`` the
    test expects the first motion to be returned.
    """
    db = MotionDatabase(str(tmp_path / "test_motions2.db"))

    # Three motions: one that splits the MPs, one unanimous "voor",
    # one unanimous "tegen".
    motion_ids = [
        _create_motion_and_get_id(db, f"http://d{n}", f"DMotion {n}")
        for n in range(1, 4)
    ]
    splitter, all_for, all_against = motion_ids

    alice = "Alice, A."
    bob = "Bob, B."
    carol = "Carol, C."

    # Motion 1: Alice votes "voor", Bob and Carol vote "tegen".
    for mp, vote in ((alice, "voor"), (bob, "tegen"), (carol, "tegen")):
        db.insert_mp_vote(splitter, mp, vote, "2023-01-01")

    # Motions 2 and 3: everybody votes the same way.
    for mp in (alice, bob, carol):
        db.insert_mp_vote(all_for, mp, "voor", "2023-01-01")
        db.insert_mp_vote(all_against, mp, "tegen", "2023-01-01")

    chosen = db.choose_discriminating_motions(
        [alice, bob, carol], excluded_motion_ids=[], k=1
    )
    assert chosen, "No discriminating motion returned"
    # The splitting motion is the only one on which the candidates differ.
    assert chosen[0] == splitter
|
|
|
|
|
|
def test_match_excludes_zero_overlap(tmp_path: Path):
    """An MP who voted on none of the user's motions must be left out of the results."""
    db = MotionDatabase(str(tmp_path / "zo.db"))

    voted_motion = _create_motion_and_get_id(db, "http://zo1", "ZO Motion 1")
    other_motion = _create_motion_and_get_id(db, "http://zo2", "ZO Motion 2")

    overlapping_mp = "Overlap, O."
    absent_mp = "Noshow, N."

    db.insert_mp_vote(voted_motion, overlapping_mp, "voor", "2023-01-01")
    # The absent MP has a vote only on the motion the user did NOT vote on.
    db.insert_mp_vote(other_motion, absent_mp, "voor", "2023-01-01")

    results = db.match_mps_for_votes({voted_motion: "Voor"}, limit=10)
    returned_names = [entry["mp_name"] for entry in results]

    assert overlapping_mp in returned_names, "mp_overlap should appear"
    assert (
        absent_mp not in returned_names
    ), "mp_noshow had no overlap and must be excluded"
|
|
|
|
|
|
def test_invalid_input_empty_user_votes(tmp_path: Path):
    """match_mps_for_votes must reject an empty votes dict with a ValueError.

    The error message is expected to contain the text "non-empty".
    """
    import pytest

    database = MotionDatabase(str(tmp_path / "inv.db"))

    no_votes = {}
    with pytest.raises(ValueError, match="non-empty"):
        database.match_mps_for_votes(no_votes)
|
|
|
|
|
|
def test_invalid_input_empty_candidates(tmp_path: Path):
    """choose_discriminating_motions must reject an empty candidate list with a ValueError."""
    import pytest

    database = MotionDatabase(str(tmp_path / "inv2.db"))

    no_candidates = []
    with pytest.raises(ValueError):
        database.choose_discriminating_motions(no_candidates, excluded_motion_ids=[])
|
|
|
|
|
|
def test_geen_stem_not_counted_in_overlap(tmp_path: Path):
    """A user's 'Geen stem' answer must count toward neither overlap nor matched."""
    db = MotionDatabase(str(tmp_path / "gs.db"))

    skipped_motion = _create_motion_and_get_id(db, "http://gs1", "GS Motion 1")
    answered_motion = _create_motion_and_get_id(db, "http://gs2", "GS Motion 2")

    # The MP voted "voor" on both motions.
    mp = "Alpha, A."
    for motion_id in (skipped_motion, answered_motion):
        db.insert_mp_vote(motion_id, mp, "voor", "2023-01-01")

    # The user abstains ("Geen stem") on the first motion and votes
    # "Voor" on the second.
    user_votes = {skipped_motion: "Geen stem", answered_motion: "Voor"}
    results = db.match_mps_for_votes(user_votes, limit=10)
    assert results, "Expected at least one result"

    first = results[0]
    # Only the answered motion may be counted; the skipped one is ignored.
    assert first["overlap"] == 1
    assert first["matched"] == 1
    assert first["agreement_pct"] == 100.0
|
|
|
|
|
|
def test_choose_excluded_motions_respected(tmp_path: Path):
    """choose_discriminating_motions must never return an excluded motion id."""
    db = MotionDatabase(str(tmp_path / "excl.db"))

    splitting_motion = _create_motion_and_get_id(db, "http://ex1", "EX Motion 1")
    agreeing_motion = _create_motion_and_get_id(db, "http://ex2", "EX Motion 2")

    alice = "Alice, A."
    bob = "Bob, B."

    # The MPs disagree on the first motion and agree on the second.
    recorded_votes = (
        (splitting_motion, alice, "voor"),
        (splitting_motion, bob, "tegen"),
        (agreeing_motion, alice, "voor"),
        (agreeing_motion, bob, "voor"),
    )
    for motion_id, mp, vote in recorded_votes:
        db.insert_mp_vote(motion_id, mp, vote, "2023-01-01")

    # With the splitting motion excluded, only the agreeing motion is
    # left, so that is the one expected back.
    chosen = db.choose_discriminating_motions(
        [alice, bob], excluded_motion_ids=[splitting_motion], k=1
    )
    assert splitting_motion not in chosen, "Excluded motion must not be returned"
    assert (
        agreeing_motion in chosen
    ), "mid2 should be chosen as only available motion"
|
|
|