"""Quick QA script that samples motions and checks similarity cache quality.
|
|
|
|
Writes a short JSON summary into thoughts/ledgers/qa_similarity_{ts}.json
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
from datetime import datetime
|
|
from typing import List
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def sample_motion_ids(sample_size: int) -> List[int]:
    """Return up to ``sample_size`` motion ids sampled from the database.

    Resolution order:
      1. A dynamically-provided ``database.db`` object exposing
         ``sample_motions`` (lets tests inject a fake via ``sys.modules``).
      2. A direct DuckDB connection to ``database.db.db_path``.
      3. Empty list when neither backend is available.

    Args:
        sample_size: Maximum number of ids to return.

    Returns:
        A list of motion ids (possibly empty); never raises.
    """
    # naive: select all motion ids from DB and sample.
    # Prefer any dynamically-provided database object from the 'database'
    # module so tests can inject a fake via sys.modules.
    db_obj = None
    try:
        database_mod = __import__("database")
        db_obj = getattr(database_mod, "db", None)
        if db_obj is not None and hasattr(db_obj, "sample_motions"):
            return db_obj.sample_motions(sample_size)
    except Exception:
        # Best-effort: fall through to the DuckDB path below.
        pass

    # BUG FIX: the original referenced an undefined global ``db`` here
    # (``db.db_path``), which raised NameError inside the try and made the
    # DuckDB branch permanently unreachable. Use the resolved ``db_obj``.
    conn = None
    db_path = getattr(db_obj, "db_path", None)
    if db_path:
        try:
            conn = __import__("duckdb").connect(db_path)
        except Exception:
            conn = None

    if conn is None:
        # fallback: read from motions.json if present (file-backed mode)
        # Not implemented: return empty
        return []

    try:
        rows = conn.execute("SELECT id FROM motions").fetchall()
        ids = [r[0] for r in rows]
        if not ids:
            return []
        return random.sample(ids, min(sample_size, len(ids)))
    except Exception:
        return []
    finally:
        # Always release the connection (the original leaked it when
        # random.sample raised after the explicit close on the happy path).
        try:
            conn.close()
        except Exception:
            pass
|
def _fetch_similarities(db_obj, fallback_db, mid: int, top_k: int) -> list:
    """Fetch cached similarity rows for ``mid`` from whichever backend exists."""
    if db_obj and hasattr(db_obj, "get_cached_similarities"):
        return db_obj.get_cached_similarities(mid, top_k=top_k)
    if fallback_db is None:
        return []
    try:
        return fallback_db.get_cached_similarities(
            mid, vector_type="fused", top_k=top_k
        )
    except Exception:
        return []


def _count_suspicious(sims: list, mid: int) -> int:
    """Count rows whose score exceeds 0.99999 yet point at a different motion."""
    suspicious = 0
    for r in sims:
        try:
            score = float(r.get("score", 0.0))
            target = r.get("target_motion_id")
            if target is None:
                target = r.get("id")
            if score > 0.99999 and int(target) != int(mid):
                suspicious += 1
        except Exception:
            # Be tolerant of unexpected structures in similarity rows
            continue
    return suspicious


def run_qa(db_path: str, sample_size: int = 50, top_k: int = 5) -> dict:
    """Sample motions and audit the quality of their cached similarities.

    Args:
        db_path: Path to the motions DB. NOTE(review): currently unused here;
            sampling is delegated to :func:`sample_motion_ids` — confirm intent.
        sample_size: Number of motion ids to sample.
        top_k: Number of cached neighbours to inspect per motion.

    Returns:
        A summary dict with ``timestamp``, ``sample_size``, ``top_k`` and a
        ``results`` list of per-motion entries, or an ``error`` key when no
        motion ids are available.
    """
    summary = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "sample_size": sample_size,
        "top_k": top_k,
        "results": [],
    }

    ids = sample_motion_ids(sample_size)
    if not ids:
        summary["error"] = "no motion ids available"
        return summary

    # Resolve db at runtime so tests can substitute a fake module.
    try:
        database_mod = __import__("database")
        db_obj = getattr(database_mod, "db", None)
    except Exception:
        db_obj = None

    # PERF FIX: the original re-imported ``database`` inside the loop for
    # every sampled motion id; resolve the fallback handle once up front.
    fallback_db = None
    if not (db_obj and hasattr(db_obj, "get_cached_similarities")):
        try:
            from database import db as fallback_db
        except Exception:
            fallback_db = None

    for mid in ids:
        sims = _fetch_similarities(db_obj, fallback_db, mid, top_k)
        suspicious = _count_suspicious(sims, mid)
        summary["results"].append(
            {"motion_id": mid, "top_k": len(sims), "suspicious": suspicious}
        )

    return summary
|
def _resolve_default_db_path() -> str | None:
    """Best-effort lookup of the default DB path from the runtime 'database' module."""
    try:
        return getattr(__import__("database").db, "db_path", None)
    except Exception:
        return None


def main(db_path: str | None = None, sample_size: int = 50, top_k: int = 5) -> dict:
    """Wrapper used by CLI and tests.

    When called with no args, this behaves like the prior CLI entrypoint and
    will parse command-line args and write a ledger file. Tests call main()
    directly with explicit parameters and expect a dict summary to be
    returned (and a ledger to be written). To maintain compatibility we
    support both usage patterns.

    Args:
        db_path: Path to motions.db; ``None`` triggers CLI argument parsing.
        sample_size: Number of motion ids to sample (overridden by CLI args).
        top_k: Neighbours to inspect per motion (overridden by CLI args).

    Returns:
        The QA summary dict, augmented with ``motions`` (id -> result mapping)
        and ``ledger_path`` (where the JSON summary was written).
    """
    # If invoked as CLI, db_path will be None and we should parse args and
    # write the ledger file as before.
    if db_path is None:
        logging.basicConfig(level=logging.INFO)
        parser = argparse.ArgumentParser(description="QA similarity cache sampler")
        parser.add_argument("--db-path", required=False, help="Path to motions.db")
        parser.add_argument("--sample-size", type=int, default=50)
        parser.add_argument("--top-k", type=int, default=5)
        args = parser.parse_args()
        # BUG FIX: the original fell back to an undefined global ``db``
        # (``db.db_path``), raising NameError whenever --db-path was omitted.
        db_path = args.db_path or _resolve_default_db_path()
        sample_size = args.sample_size
        top_k = args.top_k

    summary = run_qa(
        db_path or _resolve_default_db_path(), sample_size=sample_size, top_k=top_k
    )
    # Provide a convenience mapping of motion_id -> result for easier consumption
    # by callers/tests which expect a `motions` mapping.
    summary["motions"] = {r["motion_id"]: r for r in summary.get("results", [])}
    ledger_dir = os.path.join("thoughts", "ledgers")
    os.makedirs(ledger_dir, exist_ok=True)
    ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    path = os.path.join(ledger_dir, f"qa_similarity_{ts}.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, ensure_ascii=False, indent=2)
    print(f"Wrote QA summary to {path}")
    return {"ledger_path": path, **summary}
|
# CLI entrypoint: parse arguments, run the QA pass, write the ledger JSON.
if __name__ == "__main__":
    main()