"""Quick QA script that samples motions and checks similarity cache quality.

Writes a short JSON summary into thoughts/ledgers/qa_similarity_{ts}.json
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import random
from datetime import datetime, timezone
from typing import Any, List

_logger = logging.getLogger(__name__)

# Similarity scores strictly above this value are treated as "identical".
# An identical score against a *different* motion is what this QA run flags.
_DUPLICATE_SCORE_THRESHOLD = 0.99999

# Ledger files are written relative to the current working directory.
_LEDGER_DIR = os.path.join("thoughts", "ledgers")


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _resolve_db() -> Any:
    """Return ``database.db`` resolved at call time, or ``None``.

    The lookup is deliberately lazy (not a module-level import) so tests can
    inject a fake ``database`` module through ``sys.modules``.
    """
    try:
        import database  # type: ignore[import-not-found]
    except Exception:  # any import-time failure just means "no project db"
        _logger.debug("database module unavailable", exc_info=True)
        return None
    return getattr(database, "db", None)


def sample_motion_ids(sample_size: int, db_path: str | None = None) -> List[int]:
    """Return up to ``sample_size`` randomly chosen motion ids.

    Sources, in order of preference:

    1. ``database.db.sample_motions(sample_size)`` when the (possibly
       injected) database object provides it.
    2. A read-only DuckDB connection to ``db_path`` (or ``database.db.db_path``
       when ``db_path`` is not given), sampling from ``SELECT id FROM motions``.

    Every failure degrades to an empty list so the caller can report
    "no motion ids available" instead of crashing.

    Args:
        sample_size: Maximum number of ids to return; values <= 0 yield ``[]``.
        db_path: Optional path to the DuckDB file used by the fallback sampler.

    Returns:
        A list of motion ids (possibly empty).
    """
    if sample_size <= 0:
        return []

    db_obj = _resolve_db()
    sampler = getattr(db_obj, "sample_motions", None)
    if sampler is not None:
        try:
            return sampler(sample_size)
        except Exception:
            _logger.warning(
                "database.db.sample_motions failed; falling back to DuckDB",
                exc_info=True,
            )

    path = db_path or getattr(db_obj, "db_path", None)
    if not path or not os.path.exists(path):
        # File-backed (motions.json) mode is not implemented.
        return []

    try:
        import duckdb  # optional dependency, only needed for this fallback
    except ImportError:
        _logger.debug("duckdb is not installed; cannot sample motion ids")
        return []

    try:
        # read_only=True: a QA pass must never create or modify the database.
        conn = duckdb.connect(path, read_only=True)
        try:
            rows = conn.execute("SELECT id FROM motions").fetchall()
        finally:
            conn.close()
    except Exception:
        _logger.warning("could not read motion ids from %s", path, exc_info=True)
        return []

    ids = [row[0] for row in rows]
    return random.sample(ids, min(sample_size, len(ids)))


def _count_suspicious(
    motion_id: Any,
    sims: List[Any],
    threshold: float = _DUPLICATE_SCORE_THRESHOLD,
) -> int:
    """Count similarity rows that look like near-identical *other* motions.

    A row is suspicious when its ``score`` is strictly greater than
    ``threshold`` and its target (``target_motion_id``, falling back to
    ``id``) differs from ``motion_id``. Rows with an unexpected shape are
    skipped rather than failing the whole QA run.
    """
    suspicious = 0
    for row in sims:
        try:
            score = float(row.get("score", 0.0))
            target = row.get("target_motion_id")
            if target is None:
                target = row.get("id")
            if score > threshold and int(target) != int(motion_id):
                suspicious += 1
        except (AttributeError, TypeError, ValueError, OverflowError):
            continue
    return suspicious


def run_qa(db_path: str | None, sample_size: int = 50, top_k: int = 5) -> dict:
    """Sample motions and summarize the quality of their cached similarities.

    Args:
        db_path: DuckDB path used when the project ``database.db`` cannot
            sample motions itself; ``None``/empty falls back to
            ``database.db.db_path``.
        sample_size: Number of motions to sample.
        top_k: Number of cached neighbours requested per motion.

    Returns:
        A dict with ``timestamp``, ``sample_size``, ``top_k`` and ``results``
        (one ``{"motion_id", "top_k", "suspicious"}`` entry per sampled
        motion), plus ``error`` when no motion ids could be sampled.
    """
    summary = {
        "timestamp": _utc_now().isoformat().replace("+00:00", "Z"),
        "sample_size": sample_size,
        "top_k": top_k,
        "results": [],
    }
    ids = sample_motion_ids(sample_size, db_path=db_path)
    if not ids:
        summary["error"] = "no motion ids available"
        return summary

    # Resolve once, at runtime, so tests can substitute a fake module.
    get_sims = getattr(_resolve_db(), "get_cached_similarities", None)
    if get_sims is None:
        _logger.warning(
            "database.db.get_cached_similarities unavailable; "
            "reporting empty similarity lists"
        )

    for mid in ids:
        sims = get_sims(mid, top_k=top_k) if get_sims is not None else None
        # The cache may return None (uncached motion) or a lazy iterable.
        sims = list(sims) if sims else []
        summary["results"].append(
            {
                "motion_id": mid,
                "top_k": len(sims),
                "suspicious": _count_suspicious(mid, sims),
            }
        )
    return summary


def _write_ledger(summary: dict, ledger_dir: str = _LEDGER_DIR) -> str:
    """Write ``summary`` as pretty JSON into ``ledger_dir`` and return the path."""
    # Serialize first so a non-JSON-able value cannot leave a truncated file.
    payload = json.dumps(summary, ensure_ascii=False, indent=2)
    os.makedirs(ledger_dir, exist_ok=True)
    ts = _utc_now().strftime("%Y%m%dT%H%M%SZ")
    path = os.path.join(ledger_dir, f"qa_similarity_{ts}.json")
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(payload)
    return path


def main(
    db_path: str | None = None,
    sample_size: int = 50,
    top_k: int = 5,
    argv: List[str] | None = None,
) -> dict:
    """Wrapper used by CLI and tests.

    When called with no ``db_path``, this behaves like the CLI entrypoint: it
    parses command-line arguments (from ``argv``, or ``sys.argv`` when
    ``argv`` is ``None``) which override ``sample_size``/``top_k``. Tests call
    ``main()`` with explicit parameters. Both modes run the QA, write a
    ledger file and return the summary dict.

    Returns:
        The ``run_qa`` summary extended with ``motions`` (motion_id -> result)
        and ``ledger_path`` (the written JSON file).
    """
    if db_path is None:
        logging.basicConfig(level=logging.INFO)
        parser = argparse.ArgumentParser(description="QA similarity cache sampler")
        parser.add_argument("--db-path", required=False, help="Path to motions.db")
        parser.add_argument("--sample-size", type=int, default=50)
        parser.add_argument("--top-k", type=int, default=5)
        args = parser.parse_args(argv)
        # None here is fine: the sampler falls back to database.db.db_path.
        db_path = args.db_path
        sample_size = args.sample_size
        top_k = args.top_k

    summary = run_qa(db_path, sample_size=sample_size, top_k=top_k)
    # Convenience mapping of motion_id -> result for callers/tests which
    # expect a `motions` mapping.
    summary["motions"] = {r["motion_id"]: r for r in summary.get("results", [])}

    path = _write_ledger(summary)
    print(f"Wrote QA summary to {path}")
    return {"ledger_path": path, **summary}


if __name__ == "__main__":
    main()