You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.4 KiB
79 lines
2.4 KiB
import json
|
|
|
|
import duckdb
|
|
import pytest
|
|
|
|
from database import MotionDatabase
|
|
|
|
|
|
def test_fuse_for_window(tmp_path):
    """Fuse SVD and text embeddings for one window and verify what is stored.

    Three motions receive a 4-dim SVD vector, but only motions 1 and 2
    receive a 16-dim text embedding.  ``fuse_for_window`` is therefore
    expected to insert two fused rows (4 + 16 = 20 dims each) and to report
    one motion skipped for missing text.
    """
    window_id = "2024-Q1"
    db_path = str(tmp_path / "motions.db")

    # Create MotionDatabase (this will initialize schema except embeddings)
    db = MotionDatabase(db_path=db_path)

    # Create embeddings table (migration not run by MotionDatabase).
    # Using the connection as a context manager closes it even if a
    # statement fails, so a broken setup cannot leave the file handle open.
    with duckdb.connect(db_path) as conn:
        conn.execute("CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1")
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS embeddings (
                id INTEGER DEFAULT nextval('embeddings_id_seq'),
                motion_id INTEGER NOT NULL,
                model TEXT NOT NULL,
                vector JSON NOT NULL,
                created_at TIMESTAMP DEFAULT current_timestamp,
                PRIMARY KEY (id)
            )
            """
        )

    # Insert 3 synthetic SVD vectors (k=4), in motion order 1, 2, 3.
    svd_vectors = {
        "1": [0.1, 0.2, 0.3, 0.4],
        "2": [0.2, 0.1, 0.0, -0.1],
        "3": [0.9, 0.8, 0.7, 0.6],
    }
    for motion_id, svd in svd_vectors.items():
        db.store_svd_vector(window_id, "motion", motion_id, svd)

    # Insert text embeddings for motions 1 and 2 only (16 dims); motion 3
    # deliberately has none so the fusion step has to skip it.
    text_embeddings = {
        1: [i / 100.0 for i in range(16)],
        2: [i / 50.0 for i in range(16)],
    }
    with duckdb.connect(db_path) as conn:
        for motion_id, text in text_embeddings.items():
            conn.execute(
                "INSERT INTO embeddings (motion_id, model, vector, created_at) "
                "VALUES (?, ?, ?, current_timestamp)",
                (motion_id, "text-model-1", json.dumps(text)),
            )

    # Import fuse function here to ensure module available
    from pipeline.fusion import fuse_for_window

    result = fuse_for_window(window_id, db_path=db_path)

    assert result["inserted"] == 2
    assert result["skipped_missing_text"] == 1

    # Verify fused embeddings stored
    with duckdb.connect(db_path) as conn:
        rows = conn.execute(
            "SELECT motion_id, vector, svd_dims, text_dims "
            "FROM fused_embeddings WHERE window_id = ?",
            (window_id,),
        ).fetchall()

    # Exactly motions 1 and 2 must be fused; motion 3 has no text embedding.
    # int() normalizes the id whether fused_embeddings stores it as text or
    # as an integer (the SVD side was keyed with "1", embeddings with 1).
    assert len(rows) == 2
    assert {int(motion_id) for motion_id, *_ in rows} == {1, 2}

    for _motion_id, vector_json, svd_dims, text_dims in rows:
        vec = json.loads(vector_json)
        assert svd_dims == 4
        assert text_dims == 16
        # The fused vector is the SVD part plus the text part.
        assert len(vec) == svd_dims + text_dims == 20
|
|