"""Database query primitives for agent operation. Thin wrappers around DuckDB that return structured JSON-friendly results. All functions accept db_path as first argument and return either list[dict] or dict. """ from __future__ import annotations import logging from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) def _connect(db_path: str, read_only: bool = True): import duckdb return duckdb.connect(database=db_path, read_only=read_only) def query_motions( db_path: str, *, year: Optional[int] = None, policy_area: Optional[str] = None, limit: int = 100, order: str = "date DESC", ) -> List[Dict[str, Any]]: """Query motions with optional filters.""" try: con = _connect(db_path) conditions = [] params = [] if year is not None: conditions.append("EXTRACT(YEAR FROM date) = ?") params.append(year) if policy_area is not None: conditions.append("policy_area = ?") params.append(policy_area) where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" sql = f""" SELECT id, title, description, date, policy_area, winning_margin, controversy_score, layman_explanation FROM motions {where_clause} ORDER BY {order} LIMIT ? """ params.append(limit) result = con.execute(sql, params).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_motions failed") return [] def query_votes( db_path: str, motion_id: int, party: Optional[str] = None, ) -> List[Dict[str, Any]]: """Query vote counts for a motion, optionally filtered by party.""" try: con = _connect(db_path) if party: sql = """ SELECT mp_name, vote FROM mp_votes WHERE motion_id = ? AND mp_name IN ( SELECT mp_name FROM mp_metadata WHERE party = ? ) """ result = con.execute(sql, (motion_id, party)).fetchdf().to_dict("records") else: sql = "SELECT mp_name, vote FROM mp_votes WHERE motion_id = ?" result = con.execute(sql, (motion_id,)).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_votes failed") return [] def query_svd_vectors( db_path: str, window_id: str, entity_type: Optional[str] = None, ) -> List[Dict[str, Any]]: """Query SVD vectors for a window.""" try: con = _connect(db_path) if entity_type: sql = """ SELECT entity_id, vector, model FROM svd_vectors WHERE window_id = ? AND entity_type = ? """ result = con.execute(sql, (window_id, entity_type)).fetchdf().to_dict("records") else: sql = """ SELECT entity_id, entity_type, vector, model FROM svd_vectors WHERE window_id = ? """ result = con.execute(sql, (window_id,)).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_svd_vectors failed") return [] def query_party_positions( db_path: str, window_id: str, ) -> List[Dict[str, Any]]: """Query party axis scores for a window.""" try: con = _connect(db_path) # Check if party_axis_scores table exists tables = con.execute( "SELECT table_name FROM information_schema.tables WHERE table_name = 'party_axis_scores'" ).fetchall() if tables: result = con.execute( """ SELECT party, axis, score FROM party_axis_scores WHERE window_id = ? """, (window_id,), ).fetchdf().to_dict("records") else: # Fallback: compute from vectors result = _compute_party_positions_from_vectors(con, window_id) con.close() return result except Exception: logger.exception("query_party_positions failed") return [] def _compute_party_positions_from_vectors(con, window_id: str) -> List[Dict[str, Any]]: """Compute party positions from MP vectors when party_axis_scores doesn't exist.""" rows = con.execute( """ SELECT sv.entity_id, sv.vector, mm.party FROM svd_vectors sv JOIN mp_metadata mm ON sv.entity_id = mm.mp_name WHERE sv.window_id = ? AND sv.entity_type = 'mp' """, (window_id,), ).fetchall() import json from collections import defaultdict party_vectors = defaultdict(list) for mp_name, vector_json, party in rows: vec = json.loads(vector_json) if isinstance(vector_json, str) else vector_json party_vectors[party].append(vec) result = [] for party, vectors in party_vectors.items(): if not vectors: continue # Compute mean position across first 2 components dim = len(vectors[0]) mean = [sum(v[i] for v in vectors) / len(vectors) for i in range(min(dim, 2))] result.append({ "party": party, "axis_1": mean[0] if len(mean) > 0 else 0.0, "axis_2": mean[1] if len(mean) > 1 else 0.0, }) return result def query_pipeline_status(db_path: str) -> Dict[str, Any]: """Return pipeline freshness metrics.""" try: con = _connect(db_path) motion_count = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0] latest = con.execute("SELECT MAX(date) FROM motions").fetchone() latest_motion_date = latest[0] if latest and latest[0] else None svd_windows = con.execute( "SELECT COUNT(DISTINCT window_id) FROM svd_vectors" ).fetchone()[0] embedding_count = con.execute( "SELECT COUNT(*) FROM svd_vectors WHERE entity_type = 'motion'" ).fetchone()[0] con.close() return { "motion_count": motion_count, "latest_motion_date": str(latest_motion_date) if latest_motion_date else None, "svd_window_count": svd_windows, "embedding_count": embedding_count, "motion_count": motion_count, "svd_window_count": svd_windows, } except Exception: logger.exception("query_pipeline_status failed") return { "motion_count": 0, "latest_motion_date": None, "svd_window_count": 0, "embedding_count": 0, "error": "Failed to query pipeline status", } def query_embeddings( db_path: str, *, motion_id: Optional[int] = None, model: Optional[str] = None, limit: int = 100, ) -> List[Dict[str, Any]]: """Query fused embeddings for motions.""" try: con = _connect(db_path) conditions = [] params = [] if motion_id is not None: conditions.append("motion_id = ?") params.append(motion_id) if model is not None: conditions.append("model = ?") params.append(model) where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" sql = f""" SELECT motion_id, vector, model FROM fused_embeddings {where_clause} LIMIT ? """ params.append(limit) result = con.execute(sql, params).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_embeddings failed") return [] def query_similar_motions( db_path: str, motion_id: int, top_k: int = 10, ) -> List[Dict[str, Any]]: """Query top-k similar motions from similarity cache.""" try: con = _connect(db_path) result = con.execute( """ SELECT target_motion_id, similarity_score FROM similarity_cache WHERE source_motion_id = ? ORDER BY similarity_score DESC LIMIT ? """, (motion_id, top_k), ).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_similar_motions failed") return [] def query_compass_positions( db_path: str, window_id: str, ) -> List[Dict[str, Any]]: """Query 2D PCA compass positions for MPs in a window.""" try: con = _connect(db_path) result = con.execute( """ SELECT sv.entity_id, sv.vector, mm.party FROM svd_vectors sv JOIN mp_metadata mm ON sv.entity_id = mm.mp_name WHERE sv.window_id = ? AND sv.entity_type = 'mp' """, (window_id,), ).fetchdf().to_dict("records") con.close() return result except Exception: logger.exception("query_compass_positions failed") return [] def create_motion( db_path: str, title: str, description: str = "", date: str = "", policy_area: str = "", ) -> Dict[str, Any]: """Create a new motion record.""" try: con = _connect(db_path, read_only=False) con.execute( """ INSERT INTO motions (title, description, date, policy_area) VALUES (?, ?, ?, ?) """, (title, description, date, policy_area), ) con.close() return {"created": True, "title": title} except Exception: logger.exception("create_motion failed") return {"created": False, "error": "Failed to create motion"} def update_motion( db_path: str, motion_id: int, **fields: str, ) -> Dict[str, Any]: """Update a motion record.""" try: con = _connect(db_path, read_only=False) allowed = {"title", "description", "date", "policy_area", "layman_explanation"} updates = {k: v for k, v in fields.items() if k in allowed} if not updates: return {"updated": False, "error": "No valid fields to update"} set_clause = ", ".join(f"{k} = ?" for k in updates) params = list(updates.values()) + [motion_id] con.execute( f"UPDATE motions SET {set_clause} WHERE id = ?", params, ) con.close() return {"updated": True, "motion_id": motion_id, "fields": list(updates.keys())} except Exception: logger.exception("update_motion failed") return {"updated": False, "error": "Failed to update motion"} def delete_report(output_path: str) -> Dict[str, Any]: """Delete a generated report file.""" try: import os if os.path.exists(output_path): os.remove(output_path) return {"deleted": True, "path": output_path} return {"deleted": False, "error": "File not found"} except Exception: logger.exception("delete_report failed") return {"deleted": False, "error": "Failed to delete report"}