"""Generate thoughts/explorer/top_svd_top_motions.json from svd_vectors.
|
|
|
|
For each SVD component, finds the top N motions by absolute score (split
|
|
equally between positive and negative pole), joins with the motions table,
|
|
and writes the result to the output JSON file.
|
|
|
|
Usage:
|
|
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window current_parliament
|
|
uv run python3 scripts/generate_svd_json.py --db data/motions.db --window 2025
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
if ROOT not in sys.path:
|
|
sys.path.insert(0, ROOT)
|
|
|
|
logger = logging.getLogger("generate_svd_json")
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None) -> int:
|
|
p = argparse.ArgumentParser(
|
|
description="Generate SVD top-motions JSON for a window."
|
|
)
|
|
p.add_argument("--db", default="data/motions.db", help="Path to motions.db")
|
|
p.add_argument(
|
|
"--window", default="current_parliament", help="SVD window_id to use"
|
|
)
|
|
p.add_argument(
|
|
"--top-n",
|
|
type=int,
|
|
default=10,
|
|
help="Top N motions per component (split pos/neg)",
|
|
)
|
|
p.add_argument(
|
|
"--components", type=int, default=10, help="Number of SVD components to include"
|
|
)
|
|
p.add_argument(
|
|
"--out",
|
|
default="thoughts/explorer/top_svd_top_motions.json",
|
|
help="Output JSON file path",
|
|
)
|
|
args = p.parse_args(argv)
|
|
|
|
try:
|
|
import duckdb
|
|
except ImportError:
|
|
logger.error("duckdb not available")
|
|
return 2
|
|
|
|
con = duckdb.connect(database=args.db, read_only=True)
|
|
|
|
# Load all motion SVD vectors for the window
|
|
logger.info("Loading motion SVD vectors for window='%s' ...", args.window)
|
|
rows = con.execute(
|
|
"SELECT entity_id, vector FROM svd_vectors "
|
|
"WHERE entity_type='motion' AND window_id=?",
|
|
[args.window],
|
|
).fetchall()
|
|
|
|
if not rows:
|
|
logger.error(
|
|
"No motion vectors found for window='%s' in %s", args.window, args.db
|
|
)
|
|
con.close()
|
|
return 3
|
|
|
|
logger.info("Loaded %d motion vectors", len(rows))
|
|
|
|
# Parse vectors into {motion_id: list[float]}
|
|
motion_scores: Dict[int, List[float]] = {}
|
|
for entity_id, raw_vec in rows:
|
|
try:
|
|
if isinstance(raw_vec, str):
|
|
vec = json.loads(raw_vec)
|
|
elif isinstance(raw_vec, (bytes, bytearray)):
|
|
vec = json.loads(raw_vec.decode())
|
|
elif isinstance(raw_vec, list):
|
|
vec = raw_vec
|
|
else:
|
|
vec = list(raw_vec)
|
|
motion_scores[int(entity_id)] = [
|
|
float(v) if v is not None else 0.0 for v in vec
|
|
]
|
|
except Exception:
|
|
logger.warning("Failed to parse vector for motion_id=%s", entity_id)
|
|
|
|
logger.info("Parsed %d motion vectors", len(motion_scores))
|
|
|
|
n_positive = args.top_n // 2
|
|
n_negative = args.top_n - n_positive
|
|
|
|
output_rows: List[Dict[str, Any]] = []
|
|
all_motion_ids: List[int] = []
|
|
|
|
# Collect top motions per component
|
|
per_component: List[List[Tuple[int, float]]] = []
|
|
for comp_idx in range(args.components):
|
|
scored: List[Tuple[int, float]] = []
|
|
for mid, vec in motion_scores.items():
|
|
if comp_idx < len(vec):
|
|
scored.append((mid, vec[comp_idx]))
|
|
|
|
scored.sort(key=lambda x: x[1], reverse=True)
|
|
top_positive = scored[:n_positive]
|
|
top_negative = scored[-n_negative:]
|
|
combined = top_positive + list(reversed(top_negative))
|
|
per_component.append(combined)
|
|
all_motion_ids.extend(mid for mid, _ in combined)
|
|
|
|
# Batch-fetch motion details
|
|
unique_ids = list(set(all_motion_ids))
|
|
if not unique_ids:
|
|
logger.error("No motion IDs to fetch")
|
|
con.close()
|
|
return 4
|
|
|
|
logger.info("Fetching details for %d unique motions ...", len(unique_ids))
|
|
placeholders = ", ".join("?" for _ in unique_ids)
|
|
detail_rows = con.execute(
|
|
f"SELECT id, title, body_text, date, policy_area FROM motions WHERE id IN ({placeholders})",
|
|
unique_ids,
|
|
).fetchall()
|
|
con.close()
|
|
|
|
details_map: Dict[int, tuple] = {row[0]: row for row in detail_rows}
|
|
logger.info("Fetched details for %d motions", len(details_map))
|
|
|
|
# Build output rows
|
|
for comp_idx, top_motions in enumerate(per_component):
|
|
comp_num = comp_idx + 1
|
|
for mid, score in top_motions:
|
|
detail = details_map.get(mid)
|
|
output_rows.append(
|
|
{
|
|
"component": comp_num,
|
|
"motion_id": mid,
|
|
"score": score,
|
|
"title": detail[1] if detail else None,
|
|
"body_text": detail[2] if detail else None,
|
|
"date": str(detail[3])[:10] if detail and detail[3] else None,
|
|
"policy_area": detail[4] if detail else None,
|
|
}
|
|
)
|
|
|
|
output: Dict[str, Any] = {"window": args.window, "rows": output_rows}
|
|
|
|
out_dir = os.path.dirname(args.out)
|
|
if out_dir:
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
|
|
with open(args.out, "w", encoding="utf-8") as f:
|
|
json.dump(output, f, ensure_ascii=False, indent=2)
|
|
|
|
logger.info(
|
|
"Written %d rows (%d components) to %s",
|
|
len(output_rows),
|
|
args.components,
|
|
args.out,
|
|
)
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise SystemExit(main())
|
|
|