You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
5.7 KiB
179 lines
5.7 KiB
#!/usr/bin/env python3
|
|
"""Score ALL motions with 2D extremity (stijl + materieel) using subagents.
|
|
|
|
Usage:
|
|
# Sanity check: score 200 random motions, print summary
|
|
uv run python analysis/right_wing/extremity_score_all.py --sample 200
|
|
|
|
# Full run: output all batches as JSON for subagent dispatch
|
|
uv run python analysis/right_wing/extremity_score_all.py --all --output /tmp/all_batches.json
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import duckdb
|
|
|
|
from analysis.right_wing.extremity_rescore_2d import (
|
|
load_skill, format_batches, validate_single_result, store_scores,
|
|
)
|
|
|
|
# Configure the root logger at import time: INFO level, timestamped lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Default database file: three directory levels above this script, then
# data/motions.db. Stored as a plain string because it is passed to
# duckdb.connect() further down.
DB_PATH = str(Path(__file__).parent.parent.parent.joinpath("data", "motions.db"))
|
|
|
|
|
|
def sample_all_motions(db_path: str, n: int | None = None, seed: int = 42) -> list[dict]:
    """Return not-yet-scored motions from the full ``motions`` table, in random order.

    Only rows whose ``body_text`` is non-NULL and non-blank are eligible.
    Motions whose id already appears in ``extremity_scores_2d`` are skipped.

    Args:
        db_path: Path to the DuckDB database.
        n: Maximum number of motions to return. ``None`` (or ``0``) means
            every eligible unscored motion.
        seed: Random seed; it is divided by 1_000_000 before being handed to
            DuckDB's ``setseed``.
            NOTE(review): setseed presumably only accepts a small fractional
            range, so very large seeds may be rejected — confirm.

    Returns:
        List of dicts with keys: motion_id, title, text, layman. String
        values are stripped; NULL columns become empty strings.
    """
    con = duckdb.connect(db_path)
    try:
        # Seed DuckDB's RNG so that ORDER BY RANDOM() below is repeatable.
        con.execute(f"SELECT setseed({seed / 1_000_000.0})")

        already_ids = {
            r[0]
            for r in con.execute(
                "SELECT motion_id FROM extremity_scores_2d"
            ).fetchall()
        }

        rows = con.execute("""
            SELECT id, title, body_text, layman_explanation
            FROM motions
            WHERE body_text IS NOT NULL
              AND length(trim(body_text)) > 0
            ORDER BY RANDOM()
        """).fetchall()
    finally:
        # Everything needed is in memory now; release the database handle
        # before doing the pure-Python filtering.
        con.close()

    motions = []
    for row in rows:
        mid = row[0]
        if mid in already_ids:
            continue
        motions.append({
            "motion_id": mid,
            "title": (row[1] or "").strip(),
            "text": (row[2] or "").strip(),
            "layman": (row[3] or "").strip(),
        })
        if n and len(motions) >= n:
            break

    total = len(rows)
    new = len(motions)
    # Count already-scored motions among the *eligible* rows only. The scores
    # table may contain ids that are not in `rows` (e.g. motions with blank
    # body text), and subtracting those would drive the "skipped" figure
    # negative.
    scored_eligible = sum(1 for row in rows if row[0] in already_ids)
    logger.info(
        "Found %d motions total, %d already scored, %d new (%d skipped)",
        total, scored_eligible, new,
        total - scored_eligible - new,
    )
    return motions
|
|
|
|
|
|
def prepare_batches(
    db_path: str, n: int | None = None, batch_size: int = 20,
) -> tuple[list[dict], list[list[str]]]:
    """Draw unscored motions and turn them into prompt batches.

    The prompt template is read from the ``"prompt_template"`` entry of the
    value returned by ``load_skill()``; the motions come from
    ``sample_all_motions`` and are grouped by ``format_batches``.

    Args:
        db_path: Path to the DuckDB database, forwarded to
            ``sample_all_motions``.
        n: Maximum number of motions to draw (``None`` = no limit).
        batch_size: Forwarded to ``format_batches``.

    Returns:
        A ``(motions, batches)`` tuple: the sampled motion dicts and the
        batches produced by ``format_batches``.
    """
    prompt_template = load_skill()["prompt_template"]

    sampled = sample_all_motions(db_path, n=n)
    grouped = format_batches(sampled, prompt_template, batch_size=batch_size)

    logger.info(
        "%d motions → %d batches (batch_size=%d)",
        len(sampled), len(grouped), batch_size,
    )
    return sampled, grouped
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point: build prompt batches and write them as JSON.

    Parses the CLI flags, prepares batches via ``prepare_batches``, prints a
    short preview to stdout and writes the full batch structure to
    ``--output`` (or to ``/tmp/extremity_all_batches.json`` when no output
    path is given).

    Returns:
        Process exit code; ``0`` both on success and when there is nothing to
        dispatch. Invalid flag combinations exit through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description="Score ALL motions with 2D extremity scoring"
    )
    parser.add_argument("--sample", type=int, metavar="N",
                        help="Number of motions to sample for sanity check")
    parser.add_argument("--all", action="store_true",
                        help="Prepare all unscored motions for dispatch")
    parser.add_argument("--batch-size", type=int, default=20,
                        help="Motions per subagent batch (default: 20)")
    parser.add_argument("--output", type=str,
                        help="Write batch JSON to this file")
    parser.add_argument("--preview", type=int, default=3,
                        help="Number of batch previews to print (default: 3)")
    args = parser.parse_args()

    if not args.sample and not args.all:
        parser.error("Must specify --sample N or --all")

    # When both flags are given, --sample wins and limits the run to N motions.
    n = args.sample if args.sample else None
    motions, batches = prepare_batches(DB_PATH, n=n, batch_size=args.batch_size)

    if not batches:
        logger.info("No batches to dispatch.")
        return 0

    # Print preview
    print(f"\n{'='*60}")
    print(f"Motions: {len(motions)}  Batches: {len(batches)}  Batch size: {args.batch_size}")
    print(f"{'='*60}")

    # Clamp at zero so a negative --preview cannot inflate the "more batches"
    # count printed below.
    preview_n = max(0, min(args.preview, len(batches)))
    for i, batch in enumerate(batches[:preview_n]):
        print(f"\n--- Batch {i+1}/{len(batches)} ---")
        for j, prompt_text in enumerate(batch):
            first_line = prompt_text.split("\n")[0] if prompt_text else "(empty)"
            print(f"  {j+1}. {first_line[:120]}...")

    if len(batches) > preview_n:
        print(f"\n... and {len(batches) - preview_n} more batches")

    # Build output structure
    # NOTE(review): motion_ids are recovered by slicing `motions` in
    # batch_size steps, which assumes format_batches chunks the motions in
    # order with exactly that size — confirm against format_batches.
    size = args.batch_size
    output = {
        "total_motions": len(motions),
        "total_batches": len(batches),
        "batch_size": size,
        "batches": [
            {
                "batch_id": i,
                "motion_ids": [m["motion_id"] for m in motions[i * size:(i + 1) * size]],
                "motion_count": len(batch),
                "prompts": batch,
            }
            for i, batch in enumerate(batches)
        ],
    }

    # Fall back to the default location when --output is not given.
    target = args.output if args.output else "/tmp/extremity_all_batches.json"
    # The JSON is dumped with ensure_ascii=False, so it may contain non-ASCII
    # characters; write it as UTF-8 explicitly instead of relying on the
    # platform's default encoding.
    Path(target).write_text(
        json.dumps(output, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    logger.info("Wrote %d batches to %s", len(batches), target)

    return 0
|
|
|
|
|
|
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
|