You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
220 lines
6.7 KiB
220 lines
6.7 KiB
"""Database query primitives for agent operation.
|
|
|
|
Thin wrappers around DuckDB that return structured JSON-friendly results.
|
|
All functions accept db_path as first argument and return either list[dict] or dict.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _connect(db_path: str, read_only: bool = True):
    """Open a DuckDB connection to the database file at ``db_path``.

    ``duckdb`` is imported inside the function rather than at module level,
    so the import only happens when a connection is actually requested.

    Args:
        db_path: Passed to ``duckdb.connect`` as ``database``.
        read_only: Passed to ``duckdb.connect`` as ``read_only``; defaults to
            ``True``.

    Returns:
        Whatever ``duckdb.connect`` returns for these arguments.
    """
    from duckdb import connect as open_database

    return open_database(database=db_path, read_only=read_only)
|
|
|
|
|
|
# Modifier suffixes accepted after a column name in an ORDER BY term.
_ORDER_MODIFIERS = frozenset({
    "",
    "ASC",
    "DESC",
    "NULLS FIRST",
    "NULLS LAST",
    "ASC NULLS FIRST",
    "ASC NULLS LAST",
    "DESC NULLS FIRST",
    "DESC NULLS LAST",
})


def _order_by_clause(order: str) -> str:
    """Validate ``order`` and return text that is safe to place after ORDER BY.

    ``order`` must be a comma-separated list of terms, each a plain ASCII
    column identifier optionally followed by ``ASC``/``DESC`` and/or
    ``NULLS FIRST``/``NULLS LAST`` (case-insensitive). Anything else --
    expressions, quotes, semicolons, sub-queries -- is rejected, because the
    clause is interpolated into the SQL text and cannot be bound as a
    parameter.

    Raises:
        ValueError: If any term does not match the accepted shape.
    """
    terms = []
    for raw_term in order.split(","):
        tokens = raw_term.split()
        if not tokens:
            raise ValueError(f"Empty ORDER BY term in {order!r}")
        column = tokens[0]
        modifiers = " ".join(token.upper() for token in tokens[1:])
        if not (column.isascii() and column.isidentifier()):
            raise ValueError(f"Invalid ORDER BY column: {column!r}")
        if modifiers not in _ORDER_MODIFIERS:
            raise ValueError(f"Invalid ORDER BY modifiers: {raw_term!r}")
        # No direction is added when none was given, so the database's own
        # default ordering still applies.
        terms.append(f"{column} {modifiers}".rstrip())
    return ", ".join(terms)


def query_motions(
    db_path: str,
    *,
    year: Optional[int] = None,
    policy_area: Optional[str] = None,
    limit: int = 100,
    order: str = "date DESC",
) -> List[Dict[str, Any]]:
    """Query motions with optional filters.

    Args:
        db_path: Database location handed to ``_connect``.
        year: If given, keep only motions whose ``date`` falls in this year.
        policy_area: If given, keep only motions with this exact
            ``policy_area``.
        limit: Maximum number of rows to return.
        order: ORDER BY clause, e.g. ``"date DESC"`` or
            ``"controversy_score DESC, date"``. Only plain column names with
            optional ``ASC``/``DESC`` and ``NULLS FIRST``/``NULLS LAST`` are
            accepted; any other value is treated as a failure.

    Returns:
        One dict per motion with the keys ``id``, ``title``, ``description``,
        ``date``, ``policy_area``, ``winning_margin``, ``controversy_score``
        and ``layman_explanation``. On any failure (including a rejected
        ``order``) the error is logged and ``[]`` is returned; the function
        does not raise.
    """
    try:
        # Validated before connecting: this is the only caller-supplied text
        # that ends up in the SQL string itself.
        order_clause = _order_by_clause(order)

        conditions: List[str] = []
        params: List[Any] = []
        if year is not None:
            conditions.append("EXTRACT(YEAR FROM date) = ?")
            params.append(year)
        if policy_area is not None:
            conditions.append("policy_area = ?")
            params.append(policy_area)
        where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""

        sql = f"""
            SELECT id, title, description, date, policy_area,
                   winning_margin, controversy_score, layman_explanation
            FROM motions
            {where_clause}
            ORDER BY {order_clause}
            LIMIT ?
        """
        params.append(limit)

        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close the connection even when the query itself fails.
            con.close()
    except Exception:
        logger.exception("query_motions failed")
        return []
|
|
|
|
|
|
def query_votes(
    db_path: str,
    motion_id: int,
    party: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return the individual MP votes cast on a motion.

    Args:
        db_path: Database location handed to ``_connect``.
        motion_id: Motion whose rows are read from ``mp_votes``.
        party: If truthy, keep only MPs whose name appears in ``mp_metadata``
            with this party. ``None`` or ``""`` means no party filter.

    Returns:
        One dict per matching row with the keys ``mp_name`` and ``vote``.
        On any failure the error is logged and ``[]`` is returned; the
        function does not raise.
    """
    sql = "SELECT mp_name, vote FROM mp_votes WHERE motion_id = ?"
    params: List[Any] = [motion_id]
    if party:
        sql += " AND mp_name IN (SELECT mp_name FROM mp_metadata WHERE party = ?)"
        params.append(party)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close the connection even when the query itself fails.
            con.close()
    except Exception:
        logger.exception("query_votes failed")
        return []
|
|
|
|
|
|
def query_svd_vectors(
    db_path: str,
    window_id: str,
    entity_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Query SVD vectors for a window.

    Args:
        db_path: Database location handed to ``_connect``.
        window_id: Window whose rows are read from ``svd_vectors``.
        entity_type: If truthy, keep only rows with this ``entity_type``.
            ``None`` or ``""`` means no filter.

    Returns:
        One dict per row. With an ``entity_type`` filter the keys are
        ``entity_id``, ``vector`` and ``model``; without it ``entity_type``
        is included as well. On any failure the error is logged and ``[]``
        is returned; the function does not raise.
    """
    if entity_type:
        # The filtered column is omitted from the output when filtering.
        sql = (
            "SELECT entity_id, vector, model "
            "FROM svd_vectors "
            "WHERE window_id = ? AND entity_type = ?"
        )
        params: List[Any] = [window_id, entity_type]
    else:
        sql = (
            "SELECT entity_id, entity_type, vector, model "
            "FROM svd_vectors "
            "WHERE window_id = ?"
        )
        params = [window_id]

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close the connection even when the query itself fails.
            con.close()
    except Exception:
        logger.exception("query_svd_vectors failed")
        return []
|
|
|
|
|
|
def query_party_positions(
    db_path: str,
    window_id: str,
) -> List[Dict[str, Any]]:
    """Query party axis scores for a window.

    If a ``party_axis_scores`` table exists its rows for ``window_id`` are
    returned; otherwise the result is delegated to
    ``_compute_party_positions_from_vectors``.

    Args:
        db_path: Database location handed to ``_connect``.
        window_id: Window to report on.

    Returns:
        From the table: one dict per row with the keys ``party``, ``axis``
        and ``score``. From the fallback: whatever
        ``_compute_party_positions_from_vectors`` returns (in this file, one
        dict per party with ``party``, ``axis_1`` and ``axis_2``) -- note the
        two shapes differ. On any failure the error is logged and ``[]`` is
        returned; the function does not raise.
    """
    try:
        con = _connect(db_path)
        try:
            # Check if party_axis_scores table exists
            has_scores_table = con.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_name = 'party_axis_scores'"
            ).fetchone() is not None

            if not has_scores_table:
                # Fallback: compute from vectors
                return _compute_party_positions_from_vectors(con, window_id)

            return con.execute(
                """
                SELECT party, axis, score
                FROM party_axis_scores
                WHERE window_id = ?
                """,
                (window_id,),
            ).fetchdf().to_dict("records")
        finally:
            # Close the connection even when a query or the fallback fails.
            con.close()
    except Exception:
        logger.exception("query_party_positions failed")
        return []
|
|
|
|
|
|
def _compute_party_positions_from_vectors(con, window_id: str) -> List[Dict[str, Any]]:
    """Compute party positions from MP vectors when party_axis_scores doesn't exist.

    Reads ``svd_vectors`` rows with ``entity_type = 'mp'`` for ``window_id``,
    joined to ``mp_metadata`` on ``entity_id = mp_name``, groups the vectors
    by party and averages their first two components.

    Args:
        con: Open connection; only ``con.execute(sql, params).fetchall()`` is
            used.
        window_id: Window whose MP vectors are aggregated.

    Returns:
        One dict per party, in the order parties are first seen, with the
        keys ``party``, ``axis_1`` and ``axis_2``. A component that the
        party's shortest vector does not have is reported as ``0.0``. Rows
        whose vector is NULL are skipped, so a party with no usable vector
        is omitted.

    Raises:
        Whatever ``con.execute`` or ``json.loads`` raise; errors are left to
        the caller.
    """
    import json
    from collections import defaultdict

    rows = con.execute(
        """
        SELECT sv.entity_id, sv.vector, mm.party
        FROM svd_vectors sv
        JOIN mp_metadata mm ON sv.entity_id = mm.mp_name
        WHERE sv.window_id = ? AND sv.entity_type = 'mp'
        """,
        (window_id,),
    ).fetchall()

    party_vectors = defaultdict(list)
    for _entity_id, raw_vector, party in rows:
        # Vectors stored as text are decoded as JSON; anything else is used as-is.
        vec = json.loads(raw_vector) if isinstance(raw_vector, str) else raw_vector
        if vec is None:
            # A NULL (or JSON null) vector cannot contribute to the mean.
            continue
        party_vectors[party].append(vec)

    result = []
    for party, vectors in party_vectors.items():
        # Use the shortest vector's length so ragged input cannot raise
        # IndexError; only the first two components are reported.
        n_axes = min(2, min(len(v) for v in vectors))
        mean = [sum(v[i] for v in vectors) / len(vectors) for i in range(n_axes)]
        result.append({
            "party": party,
            "axis_1": mean[0] if n_axes > 0 else 0.0,
            "axis_2": mean[1] if n_axes > 1 else 0.0,
        })

    return result
|
|
|
|
|
|
def query_pipeline_status(db_path: str) -> Dict[str, Any]:
    """Return pipeline freshness metrics.

    Args:
        db_path: Database location handed to ``_connect``.

    Returns:
        A dict with:

        * ``motion_count`` -- number of rows in ``motions``.
        * ``latest_motion_date`` -- ``str()`` of ``MAX(date)`` from
          ``motions``, or ``None`` when there is none.
        * ``svd_window_count`` -- distinct ``window_id`` values in
          ``svd_vectors``.
        * ``embedding_count`` -- ``svd_vectors`` rows with
          ``entity_type = 'motion'``.
        * ``healthy`` -- ``True`` when both ``motion_count`` and
          ``svd_window_count`` are positive.

        On any failure the error is logged and a dict with zeroed counts,
        ``healthy=False`` and an extra ``error`` message is returned; the
        function does not raise.
    """
    try:
        con = _connect(db_path)
        try:
            motion_count = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0]

            latest_row = con.execute("SELECT MAX(date) FROM motions").fetchone()
            latest_motion_date = latest_row[0] if latest_row and latest_row[0] else None

            svd_windows = con.execute(
                "SELECT COUNT(DISTINCT window_id) FROM svd_vectors"
            ).fetchone()[0]

            embedding_count = con.execute(
                "SELECT COUNT(*) FROM svd_vectors WHERE entity_type = 'motion'"
            ).fetchone()[0]
        finally:
            # Close the connection even when one of the queries fails.
            con.close()

        return {
            "motion_count": motion_count,
            "latest_motion_date": str(latest_motion_date) if latest_motion_date else None,
            "svd_window_count": svd_windows,
            "embedding_count": embedding_count,
            "healthy": motion_count > 0 and svd_windows > 0,
        }
    except Exception:
        logger.exception("query_pipeline_status failed")
        return {
            "motion_count": 0,
            "latest_motion_date": None,
            "svd_window_count": 0,
            "embedding_count": 0,
            "healthy": False,
            "error": "Failed to query pipeline status",
        }
|
|
|