# motief/pipeline/run_pipeline.py
"""CLI orchestrator for the parliamentary embedding pipeline.
Runs all phases in sequence:
1. fetch_mp_metadata — pull MP party + tenure from OData
2. extract_mp_votes — parse voting_results JSON → mp_votes rows
3. svd per window — build vote matrix, SVD, Procrustes-align
4. text embeddings — fill any gaps in the embeddings table
5. fuse per window — concatenate SVD + text vectors → fused_embeddings
Usage:
uv run python -m pipeline.run_pipeline [options]
Options:
--db-path PATH Path to the DuckDB file (default: data/motions.db)
--start-date DATE Window start (YYYY-MM-DD, default: 2 years ago)
--end-date DATE Window end (YYYY-MM-DD, default: today)
--window-size {quarterly,annual} Time window granularity (default: quarterly)
--svd-k INT SVD dimensionality (default: 50)
--text-model TEXT Text embedding model name (default: from ai_provider)
--skip-metadata Skip fetching MP metadata from OData
--skip-extract Skip extracting MP votes from voting_results
--skip-svd Skip SVD computation
--skip-text Skip text embedding gap-fill
--skip-fusion Skip vector fusion
--dry-run Print actions but make no DB writes
"""
import argparse
import calendar
import logging
import sys
from concurrent.futures import ThreadPoolExecutor
from datetime import date, timedelta
from typing import List, Tuple

from database import MotionDatabase

_logger = logging.getLogger(__name__)


def _generate_windows(
    start: date, end: date, granularity: str
) -> List[Tuple[str, str, str]]:
    """Return list of (window_id, start_str, end_str) tuples.

    window_id format:
        quarterly → "2024-Q1", "2024-Q2", …
        annual    → "2024"
    """
    windows = []
    cursor = date(start.year, start.month, 1)
    if granularity == "annual":
        cursor = date(start.year, 1, 1)
        while cursor <= end:
            year_end = date(cursor.year, 12, 31)
            w_end = min(year_end, end)
            windows.append((str(cursor.year), cursor.isoformat(), w_end.isoformat()))
            cursor = date(cursor.year + 1, 1, 1)
    else:
        # quarterly
        quarter_starts = {1: 1, 2: 4, 3: 7, 4: 10}
        quarter_ends = {1: 3, 2: 6, 3: 9, 4: 12}
        # Align cursor to quarter start
        q = (cursor.month - 1) // 3 + 1
        cursor = date(cursor.year, quarter_starts[q], 1)
        while cursor <= end:
            q = (cursor.month - 1) // 3 + 1
            q_end_month = quarter_ends[q]
            last_day = calendar.monthrange(cursor.year, q_end_month)[1]
            q_end = date(cursor.year, q_end_month, last_day)
            w_end = min(q_end, end)
            window_id = f"{cursor.year}-Q{q}"
            windows.append((window_id, cursor.isoformat(), w_end.isoformat()))
            cursor = q_end + timedelta(days=1)
    return windows


def run(args: argparse.Namespace) -> int:
    """Execute the pipeline. Returns exit code (0 = success)."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    db_path = args.db_path
    dry_run = args.dry_run
    if dry_run:
        _logger.info("DRY RUN — no writes will be made")

    # Resolve date range
    end_date = date.fromisoformat(args.end_date) if args.end_date else date.today()
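    # Without --start-date, default to the 730 days (~2 years) before end_date.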
    start_date = (
        date.fromisoformat(args.start_date)
        if args.start_date
        else end_date - timedelta(days=730)
    )
    _logger.info(
        "Pipeline run: %s → %s (%s windows), db=%s",
        start_date,
        end_date,
        args.window_size,
        db_path,
    )
    db = MotionDatabase(db_path)

    # ── Phase 1: MP metadata ────────────────────────────────────────────────
    if not args.skip_metadata:
        _logger.info("Phase 1: fetching MP metadata from OData")
        if not dry_run:
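            # Phase modules are imported lazily, inside the branch, so skipped
            # phases never pay their import cost.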
            from pipeline.fetch_mp_metadata import fetch_mp_metadata

            n = fetch_mp_metadata(db_path=db.db_path)
            _logger.info(" mp_metadata: processed=%d", n)
        else:
            _logger.info(" [dry-run] would call fetch_mp_metadata(db)")
    else:
        _logger.info("Phase 1: skipped (--skip-metadata)")

    # ── Phase 2: Extract MP votes ────────────────────────────────────────────
    if not args.skip_extract:
        _logger.info("Phase 2: extracting MP votes from voting_results")
        if not dry_run:
            from pipeline.extract_mp_votes import extract_mp_votes

            result = extract_mp_votes(db_path=db.db_path)
            _logger.info(
                " mp_votes: inserted=%d motions_scanned=%d skipped=%d",
                result["mp_rows_inserted"],
                result["motions_scanned"],
                result["motions_skipped"],
            )
        else:
            _logger.info(" [dry-run] would call extract_mp_votes(db)")
    else:
        _logger.info("Phase 2: skipped (--skip-extract)")

    # ── Phase 3: SVD per window ──────────────────────────────────────────────
    if not args.skip_svd:
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info(
            "Phase 3: SVD for %d windows (k=%d, parallel)", len(windows), args.svd_k
        )
        from pipeline.svd_pipeline import compute_svd_for_window

        if dry_run:
            for window_id, _w_start, _w_end in windows:
                _logger.info(" [dry-run] would run SVD for window %s", window_id)
        else:
            # Compute all windows in parallel (numpy/scipy SVD releases the GIL).
            # IMPORTANT: collect ALL results before writing — DuckDB rejects mixing
            # read-only and read-write connections in the same process.
            # The `with` block waits for all threads to finish before we exit it,
            # ensuring all read-only connections are closed before writes begin.
            futures = {}
            # Guard against max_workers=0 if the window list is ever empty.
            max_workers = max(1, min(len(windows), args.svd_workers or 4))
            with ThreadPoolExecutor(max_workers=max_workers) as pool:
                for window_id, w_start, w_end in windows:
                    fut = pool.submit(
                        compute_svd_for_window,
                        db.db_path,
                        window_id,
                        w_start,
                        w_end,
                        args.svd_k,
                    )
                    futures[fut] = window_id
            # All threads are done here — all read-only connections are closed.
            # Now write results sequentially.
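            # (futures is a plain dict; insertion order is preserved, so results
            # are written oldest window first.)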
            for fut, window_id in futures.items():
                try:
                    result = fut.result()
                except Exception as exc:
                    _logger.error(" window %s raised: %s", window_id, exc)
                    continue
                if result["k_used"] == 0:
                    _logger.info(" window %s: no data, skipped", window_id)
                    continue
                rows = result["mp_rows"] + result["motion_rows"]
                db.batch_store_svd_vectors(window_id, rows)
                _logger.info(
                    " window %s: k_used=%d stored_mp=%d stored_motion=%d",
                    window_id,
                    result["k_used"],
                    len(result["mp_rows"]),
                    len(result["motion_rows"]),
                )
    else:
        _logger.info("Phase 3: skipped (--skip-svd)")

    # ── Phase 4: Text embeddings ──────────────────────────────────────────────
    if not args.skip_text:
        _logger.info("Phase 4: ensuring text embeddings")
        if not dry_run:
            from pipeline.text_pipeline import ensure_text_embeddings

            stored, existing, no_text, errors = ensure_text_embeddings(
                db_path=db_path, model=args.text_model, batch_size=args.text_batch_size
            )
            _logger.info(
                " embeddings: stored=%d existing=%d no_text=%d errors=%d",
                stored,
                existing,
                no_text,
                errors,
            )
        else:
            _logger.info(" [dry-run] would call ensure_text_embeddings")
    else:
        _logger.info("Phase 4: skipped (--skip-text)")

    # ── Phase 5: Fusion per window ────────────────────────────────────────────
    if not args.skip_fusion:
        windows = _generate_windows(start_date, end_date, args.window_size)
        _logger.info("Phase 5: fusing vectors for %d windows", len(windows))
        from pipeline.fusion import fuse_for_window
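        # Fusion runs over the same windows as Phase 3; rows missing either an
        # SVD or a text vector are reported as skipped rather than fused.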
        for window_id, _w_start, _w_end in windows:
            if not dry_run:
                result = fuse_for_window(
                    window_id=window_id,
                    db_path=db_path,
                    model=args.text_model,
                )
                _logger.info(
                    " window %s: fused=%d skipped_no_svd=%d skipped_no_text=%d errors=%d",
                    window_id,
                    result.get("inserted", 0),
                    result.get("skipped_missing_svd", 0),
                    result.get("skipped_missing_text", 0),
                    result.get("errors", 0),
                )
            else:
                _logger.info(" [dry-run] would fuse window %s", window_id)
    else:
        _logger.info("Phase 5: skipped (--skip-fusion)")

    _logger.info("Pipeline complete.")
    return 0


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Parliamentary embedding pipeline orchestrator",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--db-path", default="data/motions.db", help="Path to DuckDB file"
    )
    parser.add_argument("--start-date", default=None, help="Window start YYYY-MM-DD")
    parser.add_argument("--end-date", default=None, help="Window end YYYY-MM-DD")
    parser.add_argument(
        "--window-size",
        choices=["quarterly", "annual"],
        default="quarterly",
        help="Time window granularity",
    )
    parser.add_argument("--svd-k", type=int, default=50, help="SVD dimensions")
    parser.add_argument(
        "--svd-workers",
        type=int,
        default=None,
        help="Parallel workers for SVD (default: min(windows, 4))",
    )
    parser.add_argument(
        "--text-model",
        default=None,
        help="Text embedding model (default: ai_provider default)",
    )
    parser.add_argument(
        "--text-batch-size",
        type=int,
        default=200,
        help="Number of texts per embedding API call (default: 200)",
    )
    parser.add_argument(
        "--skip-metadata", action="store_true", help="Skip MP metadata fetch"
    )
    parser.add_argument(
        "--skip-extract", action="store_true", help="Skip MP vote extraction"
    )
    parser.add_argument("--skip-svd", action="store_true", help="Skip SVD computation")
    parser.add_argument(
        "--skip-text", action="store_true", help="Skip text embedding gap-fill"
    )
    parser.add_argument("--skip-fusion", action="store_true", help="Skip vector fusion")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without writing anything",
    )
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    sys.exit(run(args))