You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
motief/tests/test_similarity_db_helpers.py

67 lines
2.2 KiB

import json
import math
from pathlib import Path

from database import MotionDatabase
def test_similarity_cache_roundtrip(tmp_path: Path):
    """Round-trip similarity rows through MotionDatabase's cache helpers.

    Verifies that:

    * a freshly created database starts with empty embeddings and
      similarity-cache storage (DuckDB tables, or the JSON fallback files
      when the database runs in file mode);
    * rows written with ``store_similarity_batch`` are returned by
      ``get_cached_similarities`` ordered by descending score;
    * ``clear_similarity_cache`` empties the cache for the given vector type.
    """
    db_file = tmp_path / "motions.db"
    # Constructing MotionDatabase is expected to initialize the schema
    # (or the JSON fallback files when it falls back to file mode).
    db = MotionDatabase(db_path=str(db_file))

    if getattr(db, "_file_mode", False):
        # File-mode fallback: storage lives in JSON files next to db_path.
        emb_file = Path(str(db_file) + ".embeddings.json")
        sim_file = Path(str(db_file) + ".similarity_cache.json")
        assert emb_file.exists()
        assert sim_file.exists()
        assert json.loads(emb_file.read_text(encoding="utf-8")) == []
        assert json.loads(sim_file.read_text(encoding="utf-8")) == []
    else:
        # Import duckdb only when the DuckDB backend is actually in use.
        import duckdb

        conn = duckdb.connect(str(db_file))
        try:
            embeddings_count = conn.execute(
                "SELECT COUNT(*) FROM embeddings"
            ).fetchone()[0]
            similarity_count = conn.execute(
                "SELECT COUNT(*) FROM similarity_cache"
            ).fetchone()[0]
        finally:
            # Close the connection even if a query fails (e.g. a missing
            # table), so it is not left open after the test errors out.
            conn.close()
        assert embeddings_count == 0
        assert similarity_count == 0

    # Insert two similarity rows via the batch helper. They are listed in
    # ascending score order, so the descending read-back below checks that
    # ordering comes from the helper rather than from insertion order.
    rows = [
        {
            "source_motion_id": 1,
            "target_motion_id": 2,
            "score": 0.5,
            "vector_type": "text",
            "window_id": None,
        },
        {
            "source_motion_id": 1,
            "target_motion_id": 3,
            "score": 0.9,
            "vector_type": "text",
            "window_id": None,
        },
    ]
    db.store_similarity_batch(rows)

    # Read back cached similarities and verify ordering (highest score first).
    results = db.get_cached_similarities(source_motion_id=1, vector_type="text")
    assert len(results) == 2
    # Results may be dicts from the DB or file-backed dicts; scores are
    # coerced with float() and compared with a small absolute tolerance.
    assert results[0]["target_motion_id"] == 3
    assert math.isclose(float(results[0]["score"]), 0.9, abs_tol=1e-6)
    assert results[1]["target_motion_id"] == 2
    assert math.isclose(float(results[1]["score"]), 0.5, abs_tol=1e-6)

    # Clear the cache for this vector type and verify nothing is returned.
    db.clear_similarity_cache(vector_type="text")
    results_after_clear = db.get_cached_similarities(
        source_motion_id=1, vector_type="text"
    )
    assert results_after_clear == []