You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
368 lines
11 KiB
368 lines
11 KiB
"""Database query primitives for agent operation.
|
|
|
|
Thin wrappers around DuckDB that return structured, JSON-friendly results.

All public database functions accept ``db_path`` as their first argument and return either ``list[dict]`` or ``dict``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Module-level logger; the functions in this module report failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _connect(db_path: str, read_only: bool = True):
    """Open a DuckDB connection to the database at ``db_path``.

    ``duckdb`` is imported inside the function, i.e. only when a connection
    is actually requested.

    Args:
        db_path: Value handed to ``duckdb.connect`` as ``database``.
        read_only: Passed through to ``duckdb.connect``; defaults to ``True``.

    Returns:
        The connection object returned by ``duckdb.connect``.
    """
    import duckdb

    connection = duckdb.connect(database=db_path, read_only=read_only)
    return connection
|
|
|
|
|
|
def query_motions(
    db_path: str,
    *,
    year: Optional[int] = None,
    policy_area: Optional[str] = None,
    limit: int = 100,
    order: str = "date DESC",
) -> List[Dict[str, Any]]:
    """Query motions with optional filters.

    Args:
        db_path: Path of the DuckDB database file.
        year: If given, keep only motions whose ``date`` falls in this year.
        policy_area: If given, keep only motions with this exact policy area.
        limit: Maximum number of rows to return.
        order: Comma-separated sort terms of the form ``column [ASC|DESC]``.
            Only the columns selected by this query are accepted, because the
            clause cannot be bound as a parameter and has to be spliced into
            the SQL text.

    Returns:
        One dict per motion with the keys ``id``, ``title``, ``description``,
        ``date``, ``policy_area``, ``winning_margin``, ``controversy_score``
        and ``layman_explanation``. An empty list is returned when ``order``
        is rejected or when the query fails; both cases are logged.
    """
    try:
        order_clause = _motions_order_clause(order)
        if order_clause is None:
            logger.warning("query_motions rejected unsupported order %r", order)
            return []

        conditions: List[str] = []
        params: List[Any] = []
        if year is not None:
            conditions.append("EXTRACT(YEAR FROM date) = ?")
            params.append(year)
        if policy_area is not None:
            conditions.append("policy_area = ?")
            params.append(policy_area)

        where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""
        sql = f"""
            SELECT id, title, description, date, policy_area,
                   winning_margin, controversy_score, layman_explanation
            FROM motions
            {where_clause}
            ORDER BY {order_clause}
            LIMIT ?
        """
        params.append(limit)

        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_motions failed")
        return []


def _motions_order_clause(order: str) -> Optional[str]:
    """Return a normalized ORDER BY clause, or ``None`` if ``order`` is not allowed."""
    allowed_columns = {
        "id",
        "title",
        "description",
        "date",
        "policy_area",
        "winning_margin",
        "controversy_score",
        "layman_explanation",
    }
    terms = []
    for raw_term in order.split(","):
        parts = raw_term.split()
        if not parts or len(parts) > 2:
            return None
        column = parts[0].lower()
        direction = parts[1].upper() if len(parts) == 2 else "ASC"
        if column not in allowed_columns or direction not in ("ASC", "DESC"):
            return None
        terms.append(f"{column} {direction}")
    return ", ".join(terms)
|
|
|
|
|
|
def query_votes(
    db_path: str,
    motion_id: int,
    party: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return the individual MP votes recorded for a motion.

    Args:
        db_path: Path of the DuckDB database file.
        motion_id: Motion whose rows are read from ``mp_votes``.
        party: If non-empty, keep only MPs listed in ``mp_metadata`` with this
            party. ``None`` and ``""`` both mean "no party filter".

    Returns:
        One dict per vote with the keys ``mp_name`` and ``vote``; an empty
        list when the query fails (the failure is logged).
    """
    if party:
        sql = """
            SELECT mp_name, vote
            FROM mp_votes
            WHERE motion_id = ? AND mp_name IN (
                SELECT mp_name FROM mp_metadata WHERE party = ?
            )
        """
        params: tuple = (motion_id, party)
    else:
        sql = "SELECT mp_name, vote FROM mp_votes WHERE motion_id = ?"
        params = (motion_id,)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_votes failed")
        return []
|
|
|
|
|
|
def query_svd_vectors(
    db_path: str,
    window_id: str,
    entity_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Query SVD vectors for a window.

    Args:
        db_path: Path of the DuckDB database file.
        window_id: Window whose rows are read from ``svd_vectors``.
        entity_type: If non-empty, keep only rows with this entity type.

    Returns:
        One dict per row. With an ``entity_type`` filter the keys are
        ``entity_id``, ``vector`` and ``model``; without it ``entity_type`` is
        included as well. An empty list is returned when the query fails
        (the failure is logged).
    """
    if entity_type:
        sql = """
            SELECT entity_id, vector, model
            FROM svd_vectors
            WHERE window_id = ? AND entity_type = ?
        """
        params: tuple = (window_id, entity_type)
    else:
        sql = """
            SELECT entity_id, entity_type, vector, model
            FROM svd_vectors
            WHERE window_id = ?
        """
        params = (window_id,)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_svd_vectors failed")
        return []
|
|
|
|
|
|
def query_party_positions(
    db_path: str,
    window_id: str,
) -> List[Dict[str, Any]]:
    """Query party axis scores for a window.

    When a ``party_axis_scores`` table exists its rows are returned directly;
    otherwise positions are derived from the MP vectors via
    ``_compute_party_positions_from_vectors``.

    Args:
        db_path: Path of the DuckDB database file.
        window_id: Window to report on.

    Returns:
        A list of dicts. Rows read from ``party_axis_scores`` carry the keys
        ``party``, ``axis`` and ``score``; the fallback produces ``party``,
        ``axis_1`` and ``axis_2``, so callers have to handle both shapes.
        An empty list is returned when anything fails (the failure is logged).
    """
    try:
        con = _connect(db_path)
        try:
            # Probe the catalog first so a missing table selects the fallback
            # instead of raising.
            scores_table = con.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_name = 'party_axis_scores'"
            ).fetchall()

            if scores_table:
                return con.execute(
                    """
                    SELECT party, axis, score
                    FROM party_axis_scores
                    WHERE window_id = ?
                    """,
                    (window_id,),
                ).fetchdf().to_dict("records")

            # Fallback: compute from vectors
            return _compute_party_positions_from_vectors(con, window_id)
        finally:
            # Release the connection even when one of the queries raises.
            con.close()
    except Exception:
        logger.exception("query_party_positions failed")
        return []
|
|
|
|
|
|
def _compute_party_positions_from_vectors(con, window_id: str) -> List[Dict[str, Any]]:
|
|
"""Compute party positions from MP vectors when party_axis_scores doesn't exist."""
|
|
rows = con.execute(
|
|
"""
|
|
SELECT sv.entity_id, sv.vector, mm.party
|
|
FROM svd_vectors sv
|
|
JOIN mp_metadata mm ON sv.entity_id = mm.mp_name
|
|
WHERE sv.window_id = ? AND sv.entity_type = 'mp'
|
|
""",
|
|
(window_id,),
|
|
).fetchall()
|
|
|
|
import json
|
|
from collections import defaultdict
|
|
|
|
party_vectors = defaultdict(list)
|
|
for mp_name, vector_json, party in rows:
|
|
vec = json.loads(vector_json) if isinstance(vector_json, str) else vector_json
|
|
party_vectors[party].append(vec)
|
|
|
|
result = []
|
|
for party, vectors in party_vectors.items():
|
|
if not vectors:
|
|
continue
|
|
# Compute mean position across first 2 components
|
|
dim = len(vectors[0])
|
|
mean = [sum(v[i] for v in vectors) / len(vectors) for i in range(min(dim, 2))]
|
|
result.append({
|
|
"party": party,
|
|
"axis_1": mean[0] if len(mean) > 0 else 0.0,
|
|
"axis_2": mean[1] if len(mean) > 1 else 0.0,
|
|
})
|
|
|
|
return result
|
|
|
|
|
|
def query_pipeline_status(db_path: str) -> Dict[str, Any]:
    """Return pipeline freshness metrics.

    Args:
        db_path: Path of the DuckDB database file.

    Returns:
        A dict with the keys

        * ``motion_count``: number of rows in ``motions``;
        * ``latest_motion_date``: ``str()`` of ``MAX(date)`` over ``motions``,
          or ``None`` when there is no such value;
        * ``svd_window_count``: number of distinct ``window_id`` values in
          ``svd_vectors``;
        * ``embedding_count``: number of ``svd_vectors`` rows whose
          ``entity_type`` is ``'motion'``.

        On failure the counts are ``0``, the date is ``None`` and an extra
        ``error`` key is present; the failure is logged.
    """
    try:
        con = _connect(db_path)
        try:
            motion_count = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0]
            latest = con.execute("SELECT MAX(date) FROM motions").fetchone()
            svd_windows = con.execute(
                "SELECT COUNT(DISTINCT window_id) FROM svd_vectors"
            ).fetchone()[0]
            embedding_count = con.execute(
                "SELECT COUNT(*) FROM svd_vectors WHERE entity_type = 'motion'"
            ).fetchone()[0]
        finally:
            # Release the connection even when one of the queries raises.
            con.close()

        latest_motion_date = latest[0] if latest and latest[0] else None

        return {
            "motion_count": motion_count,
            "latest_motion_date": str(latest_motion_date) if latest_motion_date else None,
            "svd_window_count": svd_windows,
            "embedding_count": embedding_count,
        }
    except Exception:
        logger.exception("query_pipeline_status failed")
        return {
            "motion_count": 0,
            "latest_motion_date": None,
            "svd_window_count": 0,
            "embedding_count": 0,
            "error": "Failed to query pipeline status",
        }
|
|
|
|
|
|
def query_embeddings(
    db_path: str,
    *,
    motion_id: Optional[int] = None,
    model: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """Query fused embeddings for motions.

    Args:
        db_path: Path of the DuckDB database file.
        motion_id: If given, keep only rows for this motion.
        model: If given, keep only rows produced by this model.
        limit: Maximum number of rows to return. The query has no ORDER BY,
            so which rows survive the limit is up to the database.

    Returns:
        One dict per row of ``fused_embeddings`` with the keys ``motion_id``,
        ``vector`` and ``model``; an empty list when the query fails (the
        failure is logged).
    """
    conditions: List[str] = []
    params: List[Any] = []
    if motion_id is not None:
        conditions.append("motion_id = ?")
        params.append(motion_id)
    if model is not None:
        conditions.append("model = ?")
        params.append(model)

    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""
    sql = f"""
        SELECT motion_id, vector, model
        FROM fused_embeddings
        {where_clause}
        LIMIT ?
    """
    params.append(limit)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_embeddings failed")
        return []
|
|
|
|
|
|
def query_similar_motions(
    db_path: str,
    motion_id: int,
    top_k: int = 10,
) -> List[Dict[str, Any]]:
    """Query top-k similar motions from similarity cache.

    Args:
        db_path: Path of the DuckDB database file.
        motion_id: Source motion looked up in ``similarity_cache``.
        top_k: Maximum number of rows to return.

    Returns:
        Dicts with the keys ``target_motion_id`` and ``similarity_score``,
        ordered by descending score; an empty list when the query fails (the
        failure is logged).
    """
    sql = """
        SELECT target_motion_id, similarity_score
        FROM similarity_cache
        WHERE source_motion_id = ?
        ORDER BY similarity_score DESC
        LIMIT ?
    """
    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, (motion_id, top_k)).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_similar_motions failed")
        return []
|
|
|
|
|
|
def query_compass_positions(
    db_path: str,
    window_id: str,
) -> List[Dict[str, Any]]:
    """Return the stored vector and party of every MP in a window.

    The vectors are returned exactly as stored in ``svd_vectors``; this
    function performs no projection itself.
    NOTE(review): the previous docstring promised "2D PCA compass positions";
    presumably the reduction happens in the caller - confirm.

    Args:
        db_path: Path of the DuckDB database file.
        window_id: Window whose ``entity_type = 'mp'`` rows are returned.

    Returns:
        One dict per MP with the keys ``entity_id``, ``vector`` and ``party``.
        MPs without a matching ``mp_metadata`` row are dropped by the inner
        join. An empty list is returned when the query fails (the failure is
        logged).
    """
    sql = """
        SELECT sv.entity_id, sv.vector, mm.party
        FROM svd_vectors sv
        JOIN mp_metadata mm ON sv.entity_id = mm.mp_name
        WHERE sv.window_id = ? AND sv.entity_type = 'mp'
    """
    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, (window_id,)).fetchdf().to_dict("records")
        finally:
            # Release the connection even when the query itself raises.
            con.close()
    except Exception:
        logger.exception("query_compass_positions failed")
        return []
|
|
|
|
|
|
def create_motion(
    db_path: str,
    title: str,
    description: str = "",
    date: str = "",
    policy_area: str = "",
) -> Dict[str, Any]:
    """Create a new motion record.

    Args:
        db_path: Path of the DuckDB database file; opened writable.
        title: Value for the ``title`` column.
        description: Value for the ``description`` column.
        date: Value for the ``date`` column.
            NOTE(review): the default ``""`` is inserted as-is; if ``date`` is
            a DATE-typed column this presumably fails to convert - confirm
            against the schema.
        policy_area: Value for the ``policy_area`` column.

    Returns:
        ``{"created": True, "title": title}`` when the INSERT did not raise,
        otherwise ``{"created": False, "error": ...}`` (the failure is logged).
    """
    try:
        con = _connect(db_path, read_only=False)
        try:
            con.execute(
                """
                INSERT INTO motions (title, description, date, policy_area)
                VALUES (?, ?, ?, ?)
                """,
                (title, description, date, policy_area),
            )
        finally:
            # Always release the writable connection, even when the INSERT
            # raises.
            con.close()
        return {"created": True, "title": title}
    except Exception:
        logger.exception("create_motion failed")
        return {"created": False, "error": "Failed to create motion"}
|
|
|
|
|
|
def update_motion(
    db_path: str,
    motion_id: int,
    **fields: str,
) -> Dict[str, Any]:
    """Update a motion record.

    Args:
        db_path: Path of the DuckDB database file; opened writable.
        motion_id: ``id`` of the motion to update.
        **fields: Column/value pairs. Only ``title``, ``description``,
            ``date``, ``policy_area`` and ``layman_explanation`` are applied;
            any other key is ignored.

    Returns:
        ``{"updated": True, "motion_id": ..., "fields": [...]}`` when the
        UPDATE did not raise (whether a row actually matched ``motion_id`` is
        not checked), or ``{"updated": False, "error": ...}`` when no valid
        field was supplied or the statement failed (failures are logged).
    """
    allowed = {"title", "description", "date", "policy_area", "layman_explanation"}
    updates = {k: v for k, v in fields.items() if k in allowed}
    if not updates:
        # Validate before connecting so this early return cannot leave a
        # writable connection open.
        return {"updated": False, "error": "No valid fields to update"}

    # Column names come from the allow-list above, so splicing them into the
    # statement is safe; the values themselves are bound as parameters.
    set_clause = ", ".join(f"{column} = ?" for column in updates)
    params = list(updates.values()) + [motion_id]

    try:
        con = _connect(db_path, read_only=False)
        try:
            con.execute(
                f"UPDATE motions SET {set_clause} WHERE id = ?",
                params,
            )
        finally:
            # Always release the writable connection, even when the UPDATE
            # raises.
            con.close()
        return {"updated": True, "motion_id": motion_id, "fields": list(updates.keys())}
    except Exception:
        logger.exception("update_motion failed")
        return {"updated": False, "error": "Failed to update motion"}
|
|
|
|
|
|
def delete_report(output_path: str) -> Dict[str, Any]:
    """Remove a generated report file from disk.

    NOTE(review): the path is removed exactly as given; nothing here restricts
    it to a reports directory.

    Args:
        output_path: Path of the file to remove.

    Returns:
        ``{"deleted": True, "path": output_path}`` after a successful removal,
        ``{"deleted": False, "error": "File not found"}`` when the path does
        not exist, and ``{"deleted": False, "error": "Failed to delete report"}``
        when the removal raises (the failure is logged).
    """
    try:
        import os

        # Guard clause: nothing to remove.
        if not os.path.exists(output_path):
            return {"deleted": False, "error": "File not found"}

        os.remove(output_path)
        outcome = {"deleted": True, "path": output_path}
        return outcome
    except Exception:
        logger.exception("delete_report failed")
        return {"deleted": False, "error": "Failed to delete report"}
|
|
|