"""Database query primitives for agent operation. Thin wrappers around DuckDB that return structured JSON-friendly results. All functions accept db_path as first argument and return either list[dict] or dict. """ from __future__ import annotations import logging from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) def _connect(db_path: str, read_only: bool = True): import duckdb return duckdb.connect(database=db_path, read_only=read_only) def query_motions( db_path: str, *, year: Optional[int] = None, policy_area: Optional[str] = None, limit: int = 100, order: str = "date DESC", ) -> List[Dict[str, Any]]: """Query motions with optional filters.""" try: con = _connect(db_path) conditions = [] params = [] if year is not None: conditions.append("EXTRACT(YEAR FROM date) = ?") params.append(year) if policy_area is not None: conditions.append("policy_area = ?") params.append(policy_area) where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" sql = f""" SELECT id, title, description, date, policy_area, winning_margin, controversy_score, layman_explanation FROM motions {where_clause} ORDER BY {order} LIMIT ? """ params.append(limit) result = con.execute(sql, params).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_motions failed") return [] def query_votes( db_path: str, motion_id: int, party: Optional[str] = None, ) -> List[Dict[str, Any]]: """Query vote counts for a motion, optionally filtered by party.""" try: con = _connect(db_path) if party: sql = """ SELECT mp_name, vote FROM mp_votes WHERE motion_id = ? AND mp_name IN ( SELECT mp_name FROM mp_metadata WHERE party = ? ) """ result = con.execute(sql, (motion_id, party)).fetchdf().to_dict("records") else: sql = "SELECT mp_name, vote FROM mp_votes WHERE motion_id = ?" result = con.execute(sql, (motion_id,)).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_votes failed") return [] def query_svd_vectors( db_path: str, window_id: str, entity_type: Optional[str] = None, ) -> List[Dict[str, Any]]: """Query SVD vectors for a window.""" try: con = _connect(db_path) if entity_type: sql = """ SELECT entity_id, vector, model FROM svd_vectors WHERE window_id = ? AND entity_type = ? """ result = con.execute(sql, (window_id, entity_type)).fetchdf().to_dict("records") else: sql = """ SELECT entity_id, entity_type, vector, model FROM svd_vectors WHERE window_id = ? """ result = con.execute(sql, (window_id,)).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_svd_vectors failed") return [] def query_party_positions( db_path: str, window_id: str, ) -> List[Dict[str, Any]]: """Query party axis scores for a window.""" try: con = _connect(db_path) # Check if party_axis_scores table exists tables = con.execute( "SELECT table_name FROM information_schema.tables WHERE table_name = 'party_axis_scores'" ).fetchall() if tables: result = con.execute( """ SELECT party, axis, score FROM party_axis_scores WHERE window_id = ? """, (window_id,), ).fetchdf().to_dict("records") else: # Fallback: compute from vectors result = _compute_party_positions_from_vectors(con, window_id) con.close() return result except Exception: logger.exception("query_party_positions failed") return [] def _compute_party_positions_from_vectors(con, window_id: str) -> List[Dict[str, Any]]: """Compute party positions from MP vectors when party_axis_scores doesn't exist.""" rows = con.execute( """ SELECT sv.entity_id, sv.vector, mm.party FROM svd_vectors sv JOIN mp_metadata mm ON sv.entity_id = mm.mp_name WHERE sv.window_id = ? AND sv.entity_type = 'mp' """, (window_id,), ).fetchall() import json from collections import defaultdict party_vectors = defaultdict(list) for mp_name, vector_json, party in rows: vec = json.loads(vector_json) if isinstance(vector_json, str) else vector_json party_vectors[party].append(vec) result = [] for party, vectors in party_vectors.items(): if not vectors: continue # Compute mean position across first 2 components dim = len(vectors[0]) mean = [sum(v[i] for v in vectors) / len(vectors) for i in range(min(dim, 2))] result.append({ "party": party, "axis_1": mean[0] if len(mean) > 0 else 0.0, "axis_2": mean[1] if len(mean) > 1 else 0.0, }) return result def query_pipeline_status(db_path: str) -> Dict[str, Any]: """Return pipeline freshness metrics.""" try: con = _connect(db_path) motion_count = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0] latest = con.execute("SELECT MAX(date) FROM motions").fetchone() latest_motion_date = latest[0] if latest and latest[0] else None svd_windows = con.execute( "SELECT COUNT(DISTINCT window_id) FROM svd_vectors" ).fetchone()[0] embedding_count = con.execute( "SELECT COUNT(*) FROM svd_vectors WHERE entity_type = 'motion'" ).fetchone()[0] con.close() return { "motion_count": motion_count, "latest_motion_date": str(latest_motion_date) if latest_motion_date else None, "svd_window_count": svd_windows, "embedding_count": embedding_count, "healthy": motion_count > 0 and svd_windows > 0, } except Exception: logger.exception("query_pipeline_status failed") return { "motion_count": 0, "latest_motion_date": None, "svd_window_count": 0, "embedding_count": 0, "healthy": False, "error": "Failed to query pipeline status", }