You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
motief/analysis/right_wing/extremity_rescore_2d.py

362 lines
13 KiB

#!/usr/bin/env python3
"""Two-dimensional extremity rescoring orchestrator.
Scores Dutch parliamentary motions on two independent dimensions:
1. stijl_extremiteit (stylistic extremity, 1-5)
2. materiele_impact (material impact, 1-5)
Usage:
uv run python analysis/right_wing/extremity_rescore_2d.py --db data/motions.db
uv run python analysis/right_wing/extremity_rescore_2d.py --db data/motions.db --dry-run
"""
from __future__ import annotations

import argparse
import json
import logging
import random
import re
from pathlib import Path
from typing import Any

import duckdb
# Configure the root logger when the module is imported (basicConfig is a
# no-op if the root logger already has handlers).
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ── prompt / schema loading ──────────────────────────────────────────────────
# Default skill file: three `.parent` hops up from this file (the directory that
# contains analysis/), then .opencode/skills/score-extremity/SKILL.md.
SKILL_MD_PATH = Path(__file__).parent.parent.parent.joinpath(
    ".opencode", "skills", "score-extremity", "SKILL.md"
)
def load_skill(skill_path: str | None = None) -> dict[str, Any]:
"""Read SKILL.md and extract prompt template and output schemas.
Returns:
dict with keys "prompt_template", "single_schema", "batch_schema".
"""
path = Path(skill_path) if skill_path else SKILL_MD_PATH
if not path.exists():
raise FileNotFoundError(f"Skill file not found: {path}")
content = path.read_text(encoding="utf-8")
# Extract prompt template from ```text ... ``` block
prompt_match = re.search(r"```text\n(.*?)```", content, re.DOTALL)
prompt_template = prompt_match.group(1).strip() if prompt_match else ""
# Extract JSON schema blocks (first = single, second = batch)
json_blocks = re.findall(r"```json\n(.*?)```", content, re.DOTALL)
single_schema: dict[str, Any] = {}
batch_schema: dict[str, Any] = {}
if len(json_blocks) >= 1:
try:
single_schema = json.loads(json_blocks[0].strip())
except json.JSONDecodeError:
logger.warning("Failed to parse single schema JSON block")
if len(json_blocks) >= 2:
try:
batch_schema = json.loads(json_blocks[1].strip())
except json.JSONDecodeError:
logger.warning("Failed to parse batch schema JSON block")
return {
"prompt_template": prompt_template,
"single_schema": single_schema,
"batch_schema": batch_schema,
}
# ── sampling ─────────────────────────────────────────────────────────────────
def sample_motions(
    db_path: str,
    n_per_bucket: int = 25,
    seed: int = 42,
) -> list[dict[str, Any]]:
    """Draw a stratified random sample of classified right-wing motions.

    Joins ``right_wing_motions`` (rows with ``classified = TRUE``) with
    ``motions`` and ``extremity_scores`` (rows with a non-NULL ``text_score``
    and a NULL ``error``). It then groups rows by ``int(text_score)`` and
    draws up to ``n_per_bucket`` motions from every bucket that occurs.

    Rows are fetched in a deterministic SQL order and sampled in Python with
    ``random.Random(seed)``. The same database contents and seed therefore
    always yield the same sample. A join without ORDER BY has no guaranteed
    row order, so ``setseed`` + ``ORDER BY RANDOM()`` was not reproducible.

    Args:
        db_path: Path to the DuckDB database.
        n_per_bucket: Maximum number of motions taken from each bucket.
        seed: Seed for the sampling RNG (any int).

    Returns:
        List of dicts with keys: motion_id, title, text, layman, text_score,
        grouped by ascending bucket. Empty if a required table is missing or
        no rows match.

    Raises:
        ValueError: If ``n_per_bucket`` is negative.
    """
    if n_per_bucket < 0:
        # A negative bound would otherwise mean "all but the last N" per bucket.
        raise ValueError(f"n_per_bucket must be >= 0, got {n_per_bucket}")

    con = duckdb.connect(db_path)
    try:
        tables = {t[0] for t in con.execute("SHOW TABLES").fetchall()}
        missing = {"right_wing_motions", "motions", "extremity_scores"} - tables
        if missing:
            logger.warning("Missing tables: %s, returning empty sample", missing)
            return []
        # Deterministic base order; the randomness is applied below in Python.
        rows = con.execute(
            """
            SELECT m.id, m.title, m.body_text, m.layman_explanation, e.text_score
            FROM right_wing_motions r
            JOIN motions m ON r.motion_id = m.id
            JOIN extremity_scores e ON r.motion_id = e.motion_id
            WHERE r.classified = TRUE
              AND e.text_score IS NOT NULL
              AND e.error IS NULL
            ORDER BY m.id, e.text_score
            """
        ).fetchall()
    finally:
        con.close()

    if not rows:
        return []

    # Bucket by text_score.
    # NOTE(review): int() truncates; if text_score can be fractional, confirm
    # truncation (rather than rounding) is the intended bucketing.
    buckets: dict[int, list[dict[str, Any]]] = {}
    for mid, title, body_text, layman, text_score in rows:
        score_bucket = int(text_score)
        buckets.setdefault(score_bucket, []).append({
            "motion_id": mid,
            "title": title or "",
            "text": body_text or "",
            "layman": layman or "",
            "text_score": score_bucket,
        })

    # One RNG shared across buckets, visited in sorted order -> reproducible.
    rng = random.Random(seed)
    result: list[dict[str, Any]] = []
    for bucket_id in sorted(buckets):
        bucket = buckets[bucket_id]
        result.extend(rng.sample(bucket, min(n_per_bucket, len(bucket))))

    logger.info(
        "Sampled %d motions from %d buckets (n_per_bucket=%d)",
        len(result), len(buckets), n_per_bucket,
    )
    return result
# ── batch formatting ─────────────────────────────────────────────────────────
def format_batches(
    motions: list[dict[str, Any]],
    prompt_template: str,
    batch_size: int = 10,
) -> list[list[str]]:
    """Fill the prompt template for every motion and split the prompts into batches.

    The template is filled with ``str.format`` using the ``title``, ``text`` and
    ``layman`` values of each motion (missing keys become ``""``). Literal
    braces in the template must therefore be doubled (``{{`` / ``}}``), and any
    other ``{placeholder}`` raises ``KeyError``.

    Args:
        motions: List of dicts with keys title, text, layman.
        prompt_template: Template string with {title}, {text}, {layman} placeholders.
        batch_size: Number of motions per batch; must be >= 1.

    Returns:
        Consecutive batches preserving the input order. Each batch is a list of
        filled prompt strings, one per motion, and only the last may be shorter.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative size would silently produce no batches at all.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    prompts = [
        prompt_template.format(
            title=m.get("title", ""),
            text=m.get("text", ""),
            layman=m.get("layman", ""),
        )
        for m in motions
    ]
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]
# ── validation ───────────────────────────────────────────────────────────────
EXPECTED_FIELDS = [
"stijl_extremiteit",
"stijl_toelichting",
"materiele_impact",
"materiele_toelichting",
]
def validate_single_result(result: dict[str, Any]) -> tuple[bool, str | None]:
"""Validate a single motion 2d scoring result.
Returns:
(True, None) if valid, (False, error_message) otherwise.
"""
# Check all required fields exist
for field in EXPECTED_FIELDS:
if field not in result:
return False, f"missing field: {field}"
# Validate stijl_extremiteit (int, 1-5)
se = result["stijl_extremiteit"]
if not isinstance(se, int) or se < 1 or se > 5:
return False, f"stijl_extremiteit out of range 1-5: {se}"
# Validate materiele_impact (int, 1-5)
mi = result["materiele_impact"]
if not isinstance(mi, int) or mi < 1 or mi > 5:
return False, f"materiele_impact out of range 1-5: {mi}"
return True, None
# ── storage ──────────────────────────────────────────────────────────────────
def store_scores(db_path: str, results: list[dict[str, Any]]) -> int:
    """Upsert 2d scores into the extremity_scores_2d table, atomically.

    Creates the table if it doesn't exist. All rows are written in a single
    transaction: if any row fails, nothing from this call is kept. Rows are
    upserted by ``motion_id`` (``INSERT OR REPLACE``), so a later result for
    the same motion overwrites an earlier one.

    Args:
        db_path: Path to DuckDB database.
        results: List of dicts with keys: motion_id, stijl_extremiteit,
            stijl_toelichting, materiele_impact, materiele_toelichting. The
            two ``*_toelichting`` keys are optional (stored as NULL).

    Returns:
        Number of results written (one per entry in ``results``).

    Raises:
        KeyError: If a result lacks motion_id, stijl_extremiteit or
            materiele_impact (raised before any row is written).
    """
    con = duckdb.connect(db_path)
    try:
        # NOTE(review): motion_id is declared INTEGER; confirm it matches the
        # type of motions.id.
        con.execute(
            """
            CREATE TABLE IF NOT EXISTS extremity_scores_2d (
                motion_id INTEGER PRIMARY KEY,
                stijl_extremiteit INTEGER NOT NULL,
                stijl_toelichting TEXT,
                materiele_impact INTEGER NOT NULL,
                materiele_toelichting TEXT
            )
            """
        )
        # Build every parameter tuple first so a malformed result fails
        # before anything is written.
        rows = [
            (
                r["motion_id"],
                r["stijl_extremiteit"],
                r.get("stijl_toelichting"),
                r["materiele_impact"],
                r.get("materiele_toelichting"),
            )
            for r in results
        ]
        if rows:
            con.begin()
            try:
                con.executemany(
                    """
                    INSERT OR REPLACE INTO extremity_scores_2d
                    (motion_id, stijl_extremiteit, stijl_toelichting, materiele_impact, materiele_toelichting)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    rows,
                )
                con.commit()
            except Exception:
                # Undo the partial batch, then let the caller see the error.
                con.rollback()
                raise
        count = len(rows)
        logger.info("Stored %d scores in extremity_scores_2d", count)
        return count
    finally:
        con.close()
# ── orchestrator ─────────────────────────────────────────────────────────────
def rescore_2d(
    db_path: str,
    n_per_bucket: int = 25,
    batch_size: int = 10,
    dry_run: bool = False,
    *,
    seed: int = 42,
    skill_path: str | None = None,
) -> dict[str, Any]:
    """Plan a two-dimensional extremity rescoring run.

    Loads the skill prompt and samples motions from
    right_wing_motions/extremity_scores. It then fills the prompt template and
    groups the prompts into batches. Subagent dispatch is NOT implemented here.
    In non-dry-run mode the function only logs that the batches are ready, and
    ``subagents_spawned`` is always 0.

    Args:
        db_path: Path to DuckDB database.
        n_per_bucket: Number of motions to sample per text_score bucket.
        batch_size: Motions per subagent batch.
        dry_run: If True, only log the plan.
        seed: Sampling seed forwarded to ``sample_motions``.
        skill_path: Optional skill file forwarded to ``load_skill``
            (None means the default location).

    Returns:
        Dict with "motions_count", "batch_count" and "dry_run". The
        non-dry-run path also includes "subagents_spawned".
    """
    skill = load_skill(skill_path)
    prompt_template = skill["prompt_template"]
    motions = sample_motions(db_path, n_per_bucket=n_per_bucket, seed=seed)
    if not motions:
        logger.warning("No motions to rescore.")
        return {"motions_count": 0, "batch_count": 0, "dry_run": dry_run}

    batches = format_batches(motions, prompt_template, batch_size=batch_size)
    logger.info("Plan: %d motions in %d batches (batch_size=%d)", len(motions), len(batches), batch_size)
    summary: dict[str, Any] = {
        "motions_count": len(motions),
        "batch_count": len(batches),
    }

    if dry_run:
        logger.info("DRY RUN — no subagents will be spawned.")
        return {**summary, "dry_run": True}

    # ── subagent dispatch (placeholder) ──────────────────────────────────
    # Intended flow when run from an agent context (e.g. the `task` tool):
    #   for batch_idx, batch_prompts in enumerate(batches):
    #       combined_prompt = "\n\n---\n\n".join(batch_prompts)
    #       result = task(description=f"Score batch {batch_idx + 1}/{len(batches)}",
    #                     prompt=combined_prompt, subagent_type="general")
    #       validated = [r for r in json.loads(result)["motions"]
    #                    if validate_single_result(r)[0]]
    #       store_scores(db_path, validated)
    # The expected reply shape ("motions" list) is an assumption; it should
    # match the batch_schema from SKILL.md.
    logger.info(
        "Subagent dispatch placeholder: %d batches ready for scoring. "
        "Run via an agent context (e.g. opencode task) to execute.",
        len(batches),
    )
    return {**summary, "dry_run": False, "subagents_spawned": 0}
# ── CLI ──────────────────────────────────────────────────────────────────────
def main() -> int:
    """Parse CLI arguments, run the rescoring orchestrator and print its summary as JSON.

    Returns:
        Process exit code (0 on success). Invalid numeric options make argparse
        print a usage error and exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Two-dimensional extremity rescoring orchestrator"
    )
    parser.add_argument("--db", default="data/motions.db", help="Path to DuckDB database")
    parser.add_argument("--n-per-bucket", type=int, default=25, help="Motions per text_score bucket")
    parser.add_argument("--batch-size", type=int, default=10, help="Motions per subagent batch")
    parser.add_argument("--dry-run", action="store_true", help="Print plan without spawning subagents")
    args = parser.parse_args()

    # Reject values the sampling/batching code cannot use meaningfully. This
    # gives a clean usage error instead of a traceback or a silently wrong plan.
    if args.n_per_bucket < 0:
        parser.error("--n-per-bucket must be >= 0")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    result = rescore_2d(
        db_path=args.db,
        n_per_bucket=args.n_per_bucket,
        batch_size=args.batch_size,
        dry_run=args.dry_run,
    )
    print(json.dumps(result, indent=2))
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)