You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
7.7 KiB
219 lines
7.7 KiB
from pathlib import Path
|
|
|
|
# Pick the database driver used by this test module: DuckDB when it can be
# imported, otherwise the standard-library sqlite3 module. DB_BACKEND records
# which of the two module names was bound.
try:
    import duckdb
except Exception:
    # Any failure while importing duckdb falls back to sqlite3.
    import sqlite3

    DB_BACKEND = "sqlite3"
else:
    DB_BACKEND = "duckdb"
|
|
|
|
|
|
# Migrations exercised by the test below. Each entry is a 3-tuple of
#   (path to the migration SQL file, table name, column names that must exist).
MIGRATIONS = [
    (
        "migrations/2026_03_21__create_mp_votes.sql",
        "mp_votes",
        ["id", "motion_id", "mp_name", "party", "vote", "date", "created_at"],
    ),
    (
        "migrations/2026_03_21__create_mp_metadata.sql",
        "mp_metadata",
        ["mp_name", "party", "van", "tot_en_met", "persoon_id"],
    ),
    (
        "migrations/2026_03_21__create_svd_vectors.sql",
        "svd_vectors",
        ["id", "window_id", "entity_type", "entity_id", "vector", "model", "created_at"],
    ),
    (
        "migrations/2026_03_21__create_fused_embeddings.sql",
        "fused_embeddings",
        ["id", "motion_id", "window_id", "vector", "svd_dims", "text_dims", "created_at"],
    ),
]
|
|
|
|
|
|
def _sqlite_compatible(sql):
    """Return *sql* rewritten so that sqlite3 can execute it.

    Three textual rewrites are applied, in order: lines that start with
    ``CREATE SEQUENCE`` are dropped, ``DEFAULT nextval('...')`` clauses are
    removed, and the ``JSON`` type keyword is replaced by ``TEXT``.
    """
    import re  # local import keeps this helper self-contained

    # NOTE(review): only the line that *starts* with CREATE SEQUENCE is
    # dropped; this assumes each such statement fits on one line -- confirm
    # against the migration files.
    kept = [
        line
        for line in sql.splitlines()
        if not line.strip().upper().startswith("CREATE SEQUENCE")
    ]
    script = "\n".join(kept)
    script = re.sub(r"DEFAULT\s+nextval\('[^']+'\)", "", script, flags=re.IGNORECASE)
    return re.sub(r"\bJSON\b", "TEXT", script, flags=re.IGNORECASE)


def _table_columns(conn, table_name):
    """Return the column names that ``PRAGMA table_info`` reports for *table_name*."""
    rows = conn.execute(f"PRAGMA table_info('{table_name}')").fetchall()
    # Element 1 of every returned row is read as the column name.
    return [row[1] for row in rows]


def _check_round_trip(conn, table_name):
    """Insert one row into *table_name* and assert that it reads back intact.

    On DuckDB the ``id`` column is left out of the INSERT; on sqlite3 the
    ``DEFAULT nextval(...)`` clause has been stripped from the schema, so an
    explicit ``id`` is supplied and used for the lookup. Table names other
    than the four handled below are ignored.
    """
    on_duckdb = DB_BACKEND == "duckdb"

    if table_name == "mp_votes":
        if on_duckdb:
            conn.execute(
                "INSERT INTO mp_votes (motion_id, mp_name, party, vote, date) VALUES (1, 'Jane Doe', 'PartyX', 'Yea', '2026-03-21')"
            )
            row = conn.execute(
                "SELECT motion_id, mp_name, party, vote, date FROM mp_votes WHERE motion_id=1"
            ).fetchone()
        else:
            conn.execute(
                "INSERT INTO mp_votes (id, motion_id, mp_name, party, vote, date) VALUES (1, 1, 'Jane Doe', 'PartyX', 'Yea', '2026-03-21')"
            )
            row = conn.execute(
                "SELECT motion_id, mp_name, party, vote, date FROM mp_votes WHERE id=1"
            ).fetchone()
        assert row is not None, "mp_votes round-trip returned no row"
        # DuckDB returns datetime.date for DATE columns; normalise to string.
        assert (*row[:4], str(row[4])) == (1, "Jane Doe", "PartyX", "Yea", "2026-03-21")

    elif table_name == "mp_metadata":
        conn.execute(
            "INSERT INTO mp_metadata (mp_name, party, van, tot_en_met, persoon_id) VALUES ('Jane Doe', 'PartyX', '2020-01-01', '2024-12-31', 'pid-123')"
        )
        row = conn.execute(
            "SELECT mp_name, party, van, tot_en_met, persoon_id FROM mp_metadata WHERE mp_name='Jane Doe'"
        ).fetchone()
        assert row is not None, "mp_metadata round-trip returned no row"
        # DuckDB returns datetime.date for DATE columns; normalise to string.
        assert (row[0], row[1], str(row[2]), str(row[3]), row[4]) == (
            "Jane Doe",
            "PartyX",
            "2020-01-01",
            "2024-12-31",
            "pid-123",
        )

    elif table_name == "svd_vectors":
        if on_duckdb:
            conn.execute(
                "INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model) VALUES ('w1', 'typeA', 'e1', '[1,2,3]', 'm1')"
            )
            row = conn.execute(
                "SELECT window_id, entity_type, entity_id, vector, model FROM svd_vectors WHERE window_id='w1'"
            ).fetchone()
        else:
            conn.execute(
                "INSERT INTO svd_vectors (id, window_id, entity_type, entity_id, vector, model) VALUES (1, 'w1', 'typeA', 'e1', '[1,2,3]', 'm1')"
            )
            row = conn.execute(
                "SELECT window_id, entity_type, entity_id, vector, model FROM svd_vectors WHERE id=1"
            ).fetchone()
        assert row is not None, "svd_vectors round-trip returned no row"
        # The JSON value is compared in its string form on both backends.
        assert (row[0], row[1], row[2], str(row[3]), row[4]) == (
            "w1",
            "typeA",
            "e1",
            "[1,2,3]",
            "m1",
        )

    elif table_name == "fused_embeddings":
        if on_duckdb:
            conn.execute(
                "INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims) VALUES (2, 'w2', '[0.1,0.2]', 16, 128)"
            )
            row = conn.execute(
                "SELECT motion_id, window_id, vector, svd_dims, text_dims FROM fused_embeddings WHERE motion_id=2"
            ).fetchone()
        else:
            conn.execute(
                "INSERT INTO fused_embeddings (id, motion_id, window_id, vector, svd_dims, text_dims) VALUES (1, 2, 'w2', '[0.1,0.2]', 16, 128)"
            )
            row = conn.execute(
                "SELECT motion_id, window_id, vector, svd_dims, text_dims FROM fused_embeddings WHERE id=1"
            ).fetchone()
        assert row is not None, "fused_embeddings round-trip returned no row"
        # The JSON value is compared in its string form on both backends.
        assert (row[0], row[1], str(row[2]), row[3], row[4]) == (
            2,
            "w2",
            "[0.1,0.2]",
            16,
            128,
        )


def test_run_migrations_and_tables(tmp_path):
    """Apply every migration in MIGRATIONS and verify the resulting tables.

    For each ``(sql_path, table_name, expected_cols)`` entry the test checks
    that the migration file exists, executes it against a fresh database file
    under *tmp_path* (rewritten first when the backend is sqlite3), asserts
    that every expected column is present, and performs one insert/select
    round-trip.

    Args:
        tmp_path: directory in which the ``test.db`` database file is created.
    """
    db_path = tmp_path / "test.db"
    if DB_BACKEND == "duckdb":
        conn = duckdb.connect(str(db_path))
    else:
        conn = sqlite3.connect(str(db_path))

    # try/finally so the connection is closed even when an assertion fails.
    try:
        for sql_path, table_name, expected_cols in MIGRATIONS:
            migration = Path(sql_path)
            assert migration.exists(), f"Migration file {sql_path} must exist"
            sql = migration.read_text(encoding="utf-8")

            if DB_BACKEND == "sqlite3":
                # executescript runs the whole (multi-statement) migration.
                conn.executescript(_sqlite_compatible(sql))
            else:
                conn.execute(sql)

            col_names = _table_columns(conn, table_name)
            for col in expected_cols:
                assert col in col_names, (
                    f"Column {col} missing in table {table_name}, got {col_names}"
                )

            # Simple insert + select to validate a basic round-trip.
            _check_round_trip(conn, table_name)
    finally:
        conn.close()
|
|
|