You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/scripts/qa_similarity.py

150 lines
5.1 KiB

"""Quick QA script that samples motions and checks similarity cache quality.
Writes a short JSON summary into thoughts/ledgers/qa_similarity_{ts}.json
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
from datetime import datetime
from typing import List
_logger = logging.getLogger(__name__)
def sample_motion_ids(sample_size: int) -> List[int]:
    """Return up to ``sample_size`` motion ids, preferring an injected database.

    Resolution order:
      1. A ``database`` module exposing ``db.sample_motions`` (tests can
         inject a fake via ``sys.modules``).
      2. A direct DuckDB connection opened on ``database.db.db_path``.
      3. An empty list when neither source is available (best-effort QA
         script; failures are deliberately non-fatal).
    """
    # Prefer any dynamically-provided database object from the 'database'
    # module so tests can inject a fake via sys.modules.
    db_obj = None
    try:
        database_mod = __import__("database")
        db_obj = getattr(database_mod, "db", None)
        if db_obj is not None and hasattr(db_obj, "sample_motions"):
            return db_obj.sample_motions(sample_size)
    except Exception:
        # 'database' module missing or broken; fall through to DuckDB.
        pass
    # Fallback: query DuckDB directly. NOTE: the previous code referenced an
    # undefined global `db` here, which always raised NameError and silently
    # returned []. Use the db_path from the resolved db object instead.
    conn = None
    try:
        db_path = getattr(db_obj, "db_path", None)
        if db_path:
            conn = __import__("duckdb").connect(db_path)
    except Exception:
        conn = None
    if conn is None:
        # fallback: read from motions.json if present (file-backed mode)
        # Not implemented: return empty
        return []
    try:
        rows = conn.execute("SELECT id FROM motions").fetchall()
        ids = [r[0] for r in rows]
        if not ids:
            return []
        return random.sample(ids, min(sample_size, len(ids)))
    except Exception:
        # Table missing or query failed: treat as "no ids available".
        return []
    finally:
        try:
            conn.close()
        except Exception:
            pass
def run_qa(db_path: str, sample_size: int = 50, top_k: int = 5) -> dict:
    """Sample motions and inspect their cached similarity rows.

    Builds a summary dict with one entry per sampled motion recording how
    many of its top-k similarity rows look suspicious (a near-perfect score
    against a *different* motion). ``db_path`` is accepted for interface
    compatibility; sampling itself goes through the database module.
    """
    summary = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "sample_size": sample_size,
        "top_k": top_k,
        "results": [],
    }
    motion_ids = sample_motion_ids(sample_size)
    if not motion_ids:
        summary["error"] = "no motion ids available"
        return summary

    # Resolve db at runtime so tests can substitute a fake module.
    try:
        db_obj = getattr(__import__("database"), "db", None)
    except Exception:
        db_obj = None

    def fetch_similarities(mid):
        # Prefer the resolved db object; otherwise fall back to a direct
        # module-level import (best-effort: any failure yields []).
        if db_obj and hasattr(db_obj, "get_cached_similarities"):
            return db_obj.get_cached_similarities(mid, top_k=top_k)
        try:
            from database import db as fallback_db

            return fallback_db.get_cached_similarities(
                mid, vector_type="fused", top_k=top_k
            )
        except Exception:
            return []

    def count_suspicious(mid, rows):
        # A row is suspicious when its score exceeds 0.99999 yet it points
        # at a motion other than the queried one.
        hits = 0
        for row in rows:
            try:
                score = float(row.get("score", 0.0))
                target = row.get("target_motion_id")
                if target is None:
                    target = row.get("id")
                if score > 0.99999 and int(target) != int(mid):
                    hits += 1
            except Exception:
                # Be tolerant of unexpected structures in similarity rows
                continue
        return hits

    for mid in motion_ids:
        sims = fetch_similarities(mid)
        summary["results"].append(
            {
                "motion_id": mid,
                "top_k": len(sims),
                "suspicious": count_suspicious(mid, sims),
            }
        )
    return summary
def _default_db_path() -> str | None:
    """Best-effort lookup of ``database.db.db_path``; None when unavailable.

    The previous code referenced an undefined global ``db`` for this,
    which raised NameError whenever no --db-path was supplied.
    """
    try:
        return getattr(getattr(__import__("database"), "db", None), "db_path", None)
    except Exception:
        return None


def main(db_path: str | None = None, sample_size: int = 50, top_k: int = 5) -> dict:
    """Wrapper used by CLI and tests.

    When called with no args, this behaves like the prior CLI entrypoint and
    will parse command-line args and write a ledger file. Tests call main()
    directly with explicit parameters and expect a dict summary to be
    returned (and a ledger to be written). To maintain compatibility we
    support both usage patterns.

    Returns the run_qa summary augmented with a ``motions`` mapping and the
    ``ledger_path`` of the JSON file written under thoughts/ledgers/.
    """
    # If invoked as CLI, db_path will be None and we should parse args and
    # write the ledger file as before.
    if db_path is None:
        logging.basicConfig(level=logging.INFO)
        parser = argparse.ArgumentParser(description="QA similarity cache sampler")
        parser.add_argument("--db-path", required=False, help="Path to motions.db")
        parser.add_argument("--sample-size", type=int, default=50)
        parser.add_argument("--top-k", type=int, default=5)
        args = parser.parse_args()
        db_path = args.db_path or _default_db_path()
        sample_size = args.sample_size
        top_k = args.top_k
    summary = run_qa(db_path or _default_db_path(), sample_size=sample_size, top_k=top_k)
    # Provide a convenience mapping of motion_id -> result for easier consumption
    # by callers/tests which expect a `motions` mapping.
    summary["motions"] = {r["motion_id"]: r for r in summary.get("results", [])}
    ledger_dir = os.path.join("thoughts", "ledgers")
    os.makedirs(ledger_dir, exist_ok=True)
    ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    path = os.path.join(ledger_dir, f"qa_similarity_{ts}.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, ensure_ascii=False, indent=2)
    print(f"Wrote QA summary to {path}")
    return {"ledger_path": path, **summary}


if __name__ == "__main__":
    main()