You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_fusion.py

81 lines
2.5 KiB

import json
import pytest
# duckdb is optional for test runs; skip test if not available
duckdb = pytest.importorskip("duckdb")
from database import MotionDatabase
def test_fuse_for_window(tmp_path):
    """Fuse SVD and text embeddings for one window and check the stored rows.

    Three motions get 4-dim SVD vectors in window ``2024-Q1``, but only
    motions 1 and 2 get 16-dim text embeddings. ``fuse_for_window`` is
    therefore expected to insert two fused rows (4 + 16 = 20 dims each) and
    report one motion skipped for missing text.
    """
    from contextlib import closing

    db_path = str(tmp_path / "motions.db")
    window_id = "2024-Q1"
    insert_embedding_sql = (
        "INSERT INTO embeddings (motion_id, model, vector, created_at) "
        "VALUES (?, ?, ?, current_timestamp)"
    )

    # Create MotionDatabase (this initializes the schema except embeddings).
    db = MotionDatabase(db_path=db_path)

    # The embeddings table comes from a migration MotionDatabase does not run,
    # so create it here. closing() guarantees the connection is released even
    # if a statement raises.
    with closing(duckdb.connect(db_path)) as conn:
        conn.execute("CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1")
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings (
                id INTEGER DEFAULT nextval('embeddings_id_seq'),
                motion_id INTEGER NOT NULL,
                model TEXT NOT NULL,
                vector JSON NOT NULL,
                created_at TIMESTAMP DEFAULT current_timestamp,
                PRIMARY KEY (id)
            )
            """
        )

    # Insert 3 synthetic SVD vectors (k=4), all in the same window.
    svd1 = [0.1, 0.2, 0.3, 0.4]
    svd2 = [0.2, 0.1, 0.0, -0.1]
    svd3 = [0.9, 0.8, 0.7, 0.6]
    db.store_svd_vector(window_id, "motion", "1", svd1)
    db.store_svd_vector(window_id, "motion", "2", svd2)
    db.store_svd_vector(window_id, "motion", "3", svd3)

    # Text embeddings (16 dims) for motions 1 and 2 only; motion 3 gets none,
    # which is what should make it count as skipped_missing_text.
    text1 = [float(i) / 100.0 for i in range(16)]
    text2 = [float(i) / 50.0 for i in range(16)]
    with closing(duckdb.connect(db_path)) as conn:
        for params in (
            (1, "text-model-1", json.dumps(text1)),
            (2, "text-model-1", json.dumps(text2)),
        ):
            conn.execute(insert_embedding_sql, params)

    # Imported inside the test so an import problem fails this test instead
    # of breaking collection of the whole module.
    from pipeline.fusion import fuse_for_window

    result = fuse_for_window(window_id, db_path=db_path)
    assert result["inserted"] == 2
    assert result["skipped_missing_text"] == 1

    # Verify the fused embeddings that were stored.
    with closing(duckdb.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT motion_id, vector, svd_dims, text_dims "
            "FROM fused_embeddings WHERE window_id = ?",
            (window_id,),
        ).fetchall()

    # Expect exactly two rows: one each for motions 1 and 2. motion_id is
    # compared as a string because the column type of
    # fused_embeddings.motion_id is not fixed by this test (the SVD ids were
    # passed as strings, the text ids as ints).
    assert len(rows) == 2
    assert sorted(str(row[0]) for row in rows) == ["1", "2"]
    for _motion_id, vector_json, svd_dims, text_dims in rows:
        vec = json.loads(vector_json)
        assert svd_dims == 4
        assert text_dims == 16
        assert len(vec) == 20