You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
306 lines
12 KiB
306 lines
12 KiB
"""CLI orchestrator for the parliamentary embedding pipeline.
|
|
|
|
Runs all phases in sequence:
|
|
1. fetch_mp_metadata — pull MP party + tenure from OData
|
|
2. extract_mp_votes — parse voting_results JSON → mp_votes rows
|
|
3. svd per window — build vote matrix, SVD, Procrustes-align
|
|
4. text embeddings — fill any gaps in the embeddings table
|
|
5. fuse per window — concatenate SVD + text vectors → fused_embeddings
|
|
|
|
Usage:
|
|
uv run python -m pipeline.run_pipeline [options]
|
|
|
|
Options:
|
|
--db-path PATH Path to the DuckDB file (default: data/motions.db)
|
|
--start-date DATE Window start (YYYY-MM-DD, default: 2 years ago)
|
|
--end-date DATE Window end (YYYY-MM-DD, default: today)
|
|
--window-size {quarterly,annual} Time window granularity (default: quarterly)
|
|
--svd-k INT SVD dimensionality (default: 50)
|
|
--text-model TEXT Text embedding model name (default: from ai_provider)
|
|
--skip-metadata Skip fetching MP metadata from OData
|
|
--skip-extract Skip extracting MP votes from voting_results
|
|
--skip-svd Skip SVD computation
|
|
--skip-text Skip text embedding gap-fill
|
|
--skip-fusion Skip vector fusion
|
|
--dry-run Print actions but make no DB writes
|
|
"""
|
|
|
|
import argparse
|
|
import calendar
|
|
import logging
|
|
import sys
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import date, timedelta
|
|
from typing import List, Tuple
|
|
|
|
from database import MotionDatabase
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _generate_windows(
|
|
start: date, end: date, granularity: str
|
|
) -> List[Tuple[str, str, str]]:
|
|
"""Return list of (window_id, start_str, end_str) tuples.
|
|
|
|
window_id format:
|
|
quarterly → "2024-Q1", "2024-Q2", …
|
|
annual → "2024"
|
|
"""
|
|
windows = []
|
|
cursor = date(start.year, start.month, 1)
|
|
|
|
if granularity == "annual":
|
|
cursor = date(start.year, 1, 1)
|
|
while cursor <= end:
|
|
year_end = date(cursor.year, 12, 31)
|
|
w_end = min(year_end, end)
|
|
windows.append((str(cursor.year), cursor.isoformat(), w_end.isoformat()))
|
|
cursor = date(cursor.year + 1, 1, 1)
|
|
else:
|
|
# quarterly
|
|
quarter_starts = {1: 1, 2: 4, 3: 7, 4: 10}
|
|
quarter_ends = {1: 3, 2: 6, 3: 9, 4: 12}
|
|
|
|
# Align cursor to quarter start
|
|
q = (cursor.month - 1) // 3 + 1
|
|
cursor = date(cursor.year, quarter_starts[q], 1)
|
|
|
|
while cursor <= end:
|
|
q = (cursor.month - 1) // 3 + 1
|
|
q_end_month = quarter_ends[q]
|
|
last_day = calendar.monthrange(cursor.year, q_end_month)[1]
|
|
q_end = date(cursor.year, q_end_month, last_day)
|
|
w_end = min(q_end, end)
|
|
window_id = f"{cursor.year}-Q{q}"
|
|
windows.append((window_id, cursor.isoformat(), w_end.isoformat()))
|
|
cursor = q_end + timedelta(days=1)
|
|
|
|
return windows
|
|
|
|
|
|
def run(args: argparse.Namespace) -> int:
    """Execute the pipeline phases in order. Returns exit code (0 = success).

    Phases (each independently skippable via its --skip-* flag):
        1. fetch MP metadata     2. extract MP votes     3. SVD per window
        4. text embedding gap-fill                       5. fusion per window

    With --dry-run, each phase only logs what it would do; no DB writes occur.

    Args:
        args: Parsed CLI namespace as produced by :func:`build_parser`.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    db_path = args.db_path
    dry_run = args.dry_run

    if dry_run:
        _logger.info("DRY RUN — no writes will be made")

    # Resolve date range: default is roughly the last 2 years ending today.
    end_date = date.fromisoformat(args.end_date) if args.end_date else date.today()
    start_date = (
        date.fromisoformat(args.start_date)
        if args.start_date
        else end_date - timedelta(days=730)
    )

    _logger.info(
        "Pipeline run: %s → %s (%s windows), db=%s",
        start_date,
        end_date,
        args.window_size,
        db_path,
    )

    db = MotionDatabase(db_path)

    # ── Phase 1: MP metadata ────────────────────────────────────────────────
    if not args.skip_metadata:
        _logger.info("Phase 1: fetching MP metadata from OData")
        if not dry_run:
            # Lazy import: keeps startup (and skipped runs) cheap.
            from pipeline.fetch_mp_metadata import fetch_mp_metadata

            n = fetch_mp_metadata(db_path=db.db_path)
            _logger.info("  mp_metadata: processed=%d", n)
        else:
            _logger.info("  [dry-run] would call fetch_mp_metadata(db)")
    else:
        _logger.info("Phase 1: skipped (--skip-metadata)")

    # ── Phase 2: Extract MP votes ────────────────────────────────────────────
    if not args.skip_extract:
        _logger.info("Phase 2: extracting MP votes from voting_results")
        if not dry_run:
            from pipeline.extract_mp_votes import extract_mp_votes

            result = extract_mp_votes(db_path=db.db_path)
            _logger.info(
                "  mp_votes: inserted=%d motions_scanned=%d skipped=%d",
                result["mp_rows_inserted"],
                result["motions_scanned"],
                result["motions_skipped"],
            )
        else:
            _logger.info("  [dry-run] would call extract_mp_votes(db)")
    else:
        _logger.info("Phase 2: skipped (--skip-extract)")

    # ── Phase 3: SVD per window ──────────────────────────────────────────────
    if not args.skip_svd:
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info(
            "Phase 3: SVD for %d windows (k=%d, parallel)", len(windows), args.svd_k
        )
        from pipeline.svd_pipeline import compute_svd_for_window

        if dry_run:
            for window_id, _w_start, _w_end in windows:
                _logger.info("  [dry-run] would run SVD for window %s", window_id)
        else:
            # Compute all windows in parallel (numpy/scipy SVD releases the GIL).
            # IMPORTANT: collect ALL results before writing — DuckDB rejects mixing
            # read-only and read-write connections in the same process.
            # The `with` block waits for all threads to finish before we exit it,
            # ensuring all read-only connections are closed before writes begin.
            futures = {}
            # max(1, …) guards the empty-window case: ThreadPoolExecutor raises
            # ValueError when max_workers is 0 (min(0, workers) would be 0).
            max_workers = max(1, min(len(windows), (args.svd_workers or 4)))
            with ThreadPoolExecutor(max_workers=max_workers) as pool:
                for window_id, w_start, w_end in windows:
                    fut = pool.submit(
                        compute_svd_for_window,
                        db.db_path,
                        window_id,
                        w_start,
                        w_end,
                        args.svd_k,
                    )
                    futures[fut] = window_id
            # All threads are done here — all read-only connections are closed.
            # Now write results sequentially.
            for fut, window_id in futures.items():
                try:
                    result = fut.result()
                except Exception as exc:
                    # One failed window must not abort the rest of the run.
                    _logger.error("  window %s raised: %s", window_id, exc)
                    continue

                if result["k_used"] == 0:
                    # Window had no votes to factorize — nothing to store.
                    _logger.info("  window %s: no data, skipped", window_id)
                    continue

                rows = result["mp_rows"] + result["motion_rows"]
                db.batch_store_svd_vectors(window_id, rows)
                _logger.info(
                    "  window %s: k_used=%d stored_mp=%d stored_motion=%d",
                    window_id,
                    result["k_used"],
                    len(result["mp_rows"]),
                    len(result["motion_rows"]),
                )
    else:
        _logger.info("Phase 3: skipped (--skip-svd)")

    # ── Phase 4: Text embeddings ──────────────────────────────────────────────
    if not args.skip_text:
        _logger.info("Phase 4: ensuring text embeddings")
        if not dry_run:
            from pipeline.text_pipeline import ensure_text_embeddings

            stored, existing, no_text, errors = ensure_text_embeddings(
                db_path=db_path, model=args.text_model, batch_size=args.text_batch_size
            )
            _logger.info(
                "  embeddings: stored=%d existing=%d no_text=%d errors=%d",
                stored,
                existing,
                no_text,
                errors,
            )
        else:
            _logger.info("  [dry-run] would call ensure_text_embeddings")
    else:
        _logger.info("Phase 4: skipped (--skip-text)")

    # ── Phase 5: Fusion per window ────────────────────────────────────────────
    if not args.skip_fusion:
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info("Phase 5: fusing vectors for %d windows", len(windows))
        from pipeline.fusion import fuse_for_window

        for window_id, _w_start, _w_end in windows:
            if not dry_run:
                result = fuse_for_window(
                    window_id=window_id,
                    db_path=db_path,
                    model=args.text_model,
                )
                _logger.info(
                    "  window %s: fused=%d skipped_no_svd=%d skipped_no_text=%d errors=%d",
                    window_id,
                    result.get("inserted", 0),
                    result.get("skipped_missing_svd", 0),
                    result.get("skipped_missing_text", 0),
                    result.get("errors", 0),
                )
            else:
                _logger.info("  [dry-run] would fuse window %s", window_id)
    else:
        _logger.info("Phase 5: skipped (--skip-fusion)")

    _logger.info("Pipeline complete.")
    return 0
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the pipeline orchestrator."""
    parser = argparse.ArgumentParser(
        description="Parliamentary embedding pipeline orchestrator",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # Database / date-range options.
    add("--db-path", default="data/motions.db", help="Path to DuckDB file")
    add("--start-date", default=None, help="Window start YYYY-MM-DD")
    add("--end-date", default=None, help="Window end YYYY-MM-DD")
    add(
        "--window-size",
        choices=["quarterly", "annual"],
        default="quarterly",
        help="Time window granularity",
    )

    # Per-phase tuning knobs.
    add("--svd-k", type=int, default=50, help="SVD dimensions")
    add(
        "--svd-workers",
        type=int,
        default=None,
        help="Parallel workers for SVD (default: min(windows, 4))",
    )
    add(
        "--text-model",
        default=None,
        help="Text embedding model (default: ai_provider default)",
    )
    add(
        "--text-batch-size",
        type=int,
        default=200,
        help="Number of texts per embedding API call (default: 200)",
    )

    # Phase-skip flags.
    add("--skip-metadata", action="store_true", help="Skip MP metadata fetch")
    add("--skip-extract", action="store_true", help="Skip MP vote extraction")
    add("--skip-svd", action="store_true", help="Skip SVD computation")
    add("--skip-text", action="store_true", help="Skip text embedding gap-fill")
    add("--skip-fusion", action="store_true", help="Skip vector fusion")
    add(
        "--dry-run",
        action="store_true",
        help="Print what would happen without writing anything",
    )
    return parser
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = build_parser()
|
|
args = parser.parse_args()
|
|
sys.exit(run(args))
|
|
|