You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_similarity_compute.py

65 lines
2.1 KiB

def test_similarity_compute_and_lookup(tmp_path):
    """Exercise similarity computation followed by neighbour lookup.

    Steps:
      1. Create a ``MotionDatabase`` backed by a file under ``tmp_path``.
      2. Insert three motions with raw SQL and collect their ids.
      3. Store one fused embedding per motion for window ``"W1"``.
      4. Run ``compute_similarities`` with ``top_k=1``.
      5. Ask ``get_similar_motions`` for the neighbours of motion 1 and
         check that motion 3 is ranked first.

    The test is skipped when ``duckdb`` is not installed.
    """
    import pytest

    # Skip (rather than fail) when duckdb is unavailable.
    duckdb = pytest.importorskip("duckdb")

    from database import MotionDatabase
    import similarity.compute as compute
    import similarity.lookup as lookup

    db_path = str(tmp_path / "motions.db")

    # NOTE(review): constructing MotionDatabase is expected to create the
    # schema that the raw INSERTs below rely on - confirm against database.py.
    db = MotionDatabase(db_path=db_path)

    # Insert three motions directly (avoid insert_motion, which expects
    # migration-added columns).
    motion_ids = []
    conn = duckdb.connect(db_path)
    try:
        for i in range(1, 4):
            url = f"http://example/{i}"
            conn.execute(
                "INSERT INTO motions (title, url) VALUES (?, ?)",
                (f"motion {i}", url),
            )
            row = conn.execute(
                "SELECT id FROM motions WHERE url = ?", (url,)
            ).fetchone()
            assert row is not None
            motion_ids.append(row[0])
    finally:
        # Always release the connection, even when an insert or assertion
        # above fails, so the database file is not left open.
        conn.close()

    # Store one fused embedding per motion for window 'W1'.
    vectors = [[1, 0, 0], [0, 1, 0], [1, 1, 0]]
    for motion_id, vec in zip(motion_ids, vectors):
        rid = db.store_fused_embedding(
            motion_id=motion_id, window_id="W1", vector=vec, svd_dims=1, text_dims=2
        )
        # -1 is treated here as the "store failed" sentinel.
        assert rid != -1

    # Compute similarities.
    inserted = compute.compute_similarities(
        vector_type="fused", window_id="W1", top_k=1, db_path=db_path
    )
    # The exact number of stored rows depends on the implementation;
    # this test accepts either 2 or 3.
    assert inserted in (2, 3)

    # Look up neighbours for motion 1.
    neighbors = lookup.get_similar_motions(
        motion_id=motion_ids[0],
        vector_type="fused",
        window_id="W1",
        top_k=2,
        db_path=db_path,
    )
    assert len(neighbors) >= 1

    # Motion 3 ([1,1,0]) is closer to motion 1 ([1,0,0]) than motion 2
    # ([0,1,0]) is, so it must come first however many neighbours are returned.
    assert neighbors[0]["motion_id"] == motion_ids[2]

    # When a second neighbour is returned, results must be ordered by
    # non-increasing score.
    if len(neighbors) >= 2:
        assert neighbors[0]["score"] >= neighbors[1]["score"]