You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
67 lines
2.2 KiB
67 lines
2.2 KiB
import json
import math
from pathlib import Path

from database import MotionDatabase
|
|
|
|
|
|
def _assert_similarity_storage_empty(db: "MotionDatabase", db_file: Path) -> None:
    """Assert that a freshly created ``db`` has empty embedding/similarity storage.

    Two storage layouts are checked, depending on ``db._file_mode``:

    * truthy: JSON side-car files ``<db_file>.embeddings.json`` and
      ``<db_file>.similarity_cache.json`` must exist and hold empty lists;
    * otherwise: the ``embeddings`` and ``similarity_cache`` tables inside the
      DuckDB file ``db_file`` must have zero rows.
    """
    if getattr(db, "_file_mode", False):
        emb_file = Path(str(db_file) + ".embeddings.json")
        sim_file = Path(str(db_file) + ".similarity_cache.json")
        assert emb_file.exists()
        assert sim_file.exists()
        assert json.loads(emb_file.read_text(encoding="utf-8")) == []
        assert json.loads(sim_file.read_text(encoding="utf-8")) == []
        return

    # Import duckdb only when the database is not in file mode.
    import duckdb

    conn = duckdb.connect(str(db_file))
    try:
        embeddings_count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
        similarity_count = conn.execute(
            "SELECT COUNT(*) FROM similarity_cache"
        ).fetchone()[0]
    finally:
        # Close even when a query fails (e.g. a table is missing) so the
        # database file is not left open by this test.
        conn.close()

    assert embeddings_count == 0
    assert similarity_count == 0


def test_similarity_cache_roundtrip(tmp_path: Path):
    """Store, read back, and clear similarity rows through ``MotionDatabase``.

    Runs against whichever backend ``MotionDatabase`` ends up using (JSON file
    mode or DuckDB) and checks that:

    * a freshly constructed database starts with empty storage;
    * ``get_cached_similarities`` returns the stored rows highest score first;
    * ``clear_similarity_cache(vector_type=...)`` removes those rows.

    Args:
        tmp_path: pytest-provided temporary directory for the database file.
    """
    db_file = tmp_path / "motions.db"

    # Constructing MotionDatabase is expected to initialise the schema/files.
    db = MotionDatabase(db_path=str(db_file))
    _assert_similarity_storage_empty(db, db_file)

    rows = [
        {
            "source_motion_id": 1,
            "target_motion_id": 2,
            "score": 0.5,
            "vector_type": "text",
            "window_id": None,
        },
        {
            "source_motion_id": 1,
            "target_motion_id": 3,
            "score": 0.9,
            "vector_type": "text",
            "window_id": None,
        },
    ]
    db.store_similarity_batch(rows)

    # Rows are inserted lowest score first, so getting [3, 2] back shows the
    # reader orders by score (highest first) rather than echoing insertion order.
    results = db.get_cached_similarities(source_motion_id=1, vector_type="text")
    # Results may be DB-backed or file-backed dicts; both are accessed by key.
    assert [r["target_motion_id"] for r in results] == [3, 2]
    assert math.isclose(float(results[0]["score"]), 0.9, abs_tol=1e-6)
    assert math.isclose(float(results[1]["score"]), 0.5, abs_tol=1e-6)

    # Clearing the "text" vector type must leave nothing cached for source 1.
    db.clear_similarity_cache(vector_type="text")
    results_after_clear = db.get_cached_similarities(
        source_motion_id=1, vector_type="text"
    )
    assert results_after_clear == []
|
|
|