You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
2.1 KiB
65 lines
2.1 KiB
def test_similarity_compute_and_lookup(tmp_path):
    """End-to-end check: compute similarities, then look up neighbors.

    Three motions are stored with fused embeddings in window ``"W1"``:

    * motion 1 -> ``[1, 0, 0]``
    * motion 2 -> ``[0, 1, 0]``
    * motion 3 -> ``[1, 1, 0]``

    Under cosine, dot-product and Euclidean measures alike, motion 3 is
    closer to motion 1 than motion 2 is. So the lookup for motion 1 must
    rank motion 3 first. The test is skipped when ``duckdb`` is not
    installed.
    """
    import pytest

    duckdb = pytest.importorskip("duckdb")

    from database import MotionDatabase
    import similarity.compute as compute
    import similarity.lookup as lookup

    db_path = str(tmp_path / "motions.db")

    # Build MotionDatabase on tmp_path. The raw INSERTs below rely on it
    # having created the ``motions`` table.
    db = MotionDatabase(db_path=db_path)

    # Insert three motions directly. This avoids insert_motion, which
    # expects migration-added columns. try/finally closes the connection
    # even if an assertion fails part-way through the loop.
    motion_ids = []
    conn = duckdb.connect(db_path)
    try:
        for i in range(1, 4):
            url = f"http://example/{i}"
            conn.execute(
                "INSERT INTO motions (title, url) VALUES (?, ?)",
                (f"motion {i}", url),
            )
            row = conn.execute(
                "SELECT id FROM motions WHERE url = ?", (url,)
            ).fetchone()
            assert row is not None
            motion_ids.append(row[0])
    finally:
        conn.close()

    # Store one fused embedding per motion for window "W1".
    # svd_dims=1 + text_dims=2 matches the 3-element vectors.
    vectors = [[1, 0, 0], [0, 1, 0], [1, 1, 0]]
    for motion_id, vec in zip(motion_ids, vectors):
        rid = db.store_fused_embedding(
            motion_id=motion_id,
            window_id="W1",
            vector=vec,
            svd_dims=1,
            text_dims=2,
        )
        # -1 is presumably store_fused_embedding's failure sentinel.
        assert rid != -1

    # Compute top-1 similarities. The exact row count depends on the
    # compute_similarities implementation (e.g. how it treats symmetric
    # pairs), so accept exactly 2 or 3 rows.
    inserted = compute.compute_similarities(
        vector_type="fused", window_id="W1", top_k=1, db_path=db_path
    )
    assert inserted in (2, 3)

    # Look up neighbors for motion 1.
    neighbors = lookup.get_similar_motions(
        motion_id=motion_ids[0],
        vector_type="fused",
        window_id="W1",
        top_k=2,
        db_path=db_path,
    )
    assert len(neighbors) >= 1

    # Motion 3 ([1, 1, 0]) is closer to motion 1 ([1, 0, 0]) than motion 2
    # ([0, 1, 0]) is. It must therefore be ranked first, however many
    # neighbors were returned.
    assert neighbors[0]["motion_id"] == motion_ids[2]
    # Any further neighbors must follow in non-increasing score order.
    assert all(
        prev["score"] >= nxt["score"]
        for prev, nxt in zip(neighbors, neighbors[1:])
    )
|