You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
motief/agent_tools/database.py

220 lines
6.7 KiB

"""Database query primitives for agent operation.
Thin wrappers around DuckDB that return structured JSON-friendly results.
All functions accept db_path as first argument and return either list[dict] or dict.
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
def _connect(db_path: str, read_only: bool = True):
    """Open a DuckDB connection to the database file at ``db_path``.

    The connection is read-only unless ``read_only`` is set to False.
    ``duckdb`` is imported on first use rather than at module import time.
    The caller is responsible for closing the returned connection.
    """
    import duckdb

    connection = duckdb.connect(database=db_path, read_only=read_only)
    return connection
def query_motions(
    db_path: str,
    *,
    year: Optional[int] = None,
    policy_area: Optional[str] = None,
    limit: int = 100,
    order: str = "date DESC",
) -> List[Dict[str, Any]]:
    """Query motions with optional filters.

    Args:
        db_path: Path to the DuckDB database file.
        year: If given, keep only motions whose ``date`` falls in this year.
        policy_area: If given, keep only motions with exactly this policy area.
        limit: Maximum number of rows to return.
        order: ORDER BY expression. Must be one or more comma-separated
            column names (optionally ``table.column``), each optionally
            followed by ``ASC``/``DESC`` and ``NULLS FIRST``/``NULLS LAST``.
            Anything else is rejected and treated as a failed query.

    Returns:
        One dict per motion with the keys ``id``, ``title``, ``description``,
        ``date``, ``policy_area``, ``winning_margin``, ``controversy_score``
        and ``layman_explanation``. Returns an empty list when the query
        fails; the failure is logged.
    """
    import re

    # ``order`` cannot be bound as a query parameter, so it ends up in the SQL
    # text itself. Restrict it to plain column references so that a caller
    # (for example an agent passing model-generated text) cannot inject SQL.
    identifier = r"[A-Za-z_][A-Za-z0-9_]*"
    order_term = (
        rf"{identifier}(?:\.{identifier})?"
        r"(?:\s+(?:ASC|DESC))?"
        r"(?:\s+NULLS\s+(?:FIRST|LAST))?"
    )
    order_pattern = rf"\s*{order_term}\s*(?:,\s*{order_term}\s*)*"

    try:
        if not re.fullmatch(order_pattern, order, flags=re.IGNORECASE):
            raise ValueError(f"unsupported order expression: {order!r}")

        conditions: List[str] = []
        params: List[Any] = []
        if year is not None:
            conditions.append("EXTRACT(YEAR FROM date) = ?")
            params.append(year)
        if policy_area is not None:
            conditions.append("policy_area = ?")
            params.append(policy_area)
        where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

        sql = f"""
            SELECT id, title, description, date, policy_area,
                   winning_margin, controversy_score, layman_explanation
            FROM motions
            {where_clause}
            ORDER BY {order}
            LIMIT ?
        """
        params.append(limit)

        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close even when the query raises, so the handle is not leaked.
            con.close()
    except Exception:
        logger.exception("query_motions failed")
        return []
def query_votes(
    db_path: str,
    motion_id: int,
    party: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return the individual MP votes recorded for a motion.

    Args:
        db_path: Path to the DuckDB database file.
        motion_id: Identifier of the motion whose votes are requested.
        party: If given (and non-empty), keep only votes from MPs whose
            ``mp_metadata.party`` equals this value.

    Returns:
        One dict per vote with the keys ``mp_name`` and ``vote``. Returns an
        empty list when the query fails; the failure is logged.
    """
    if party:
        sql = """
            SELECT mp_name, vote
            FROM mp_votes
            WHERE motion_id = ? AND mp_name IN (
                SELECT mp_name FROM mp_metadata WHERE party = ?
            )
        """
        params = (motion_id, party)
    else:
        sql = "SELECT mp_name, vote FROM mp_votes WHERE motion_id = ?"
        params = (motion_id,)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close even when the query raises, so the handle is not leaked.
            con.close()
    except Exception:
        logger.exception("query_votes failed")
        return []
def query_svd_vectors(
    db_path: str,
    window_id: str,
    entity_type: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Query SVD vectors for a window.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose vectors are requested.
        entity_type: If given (and non-empty), keep only rows with this
            ``entity_type``.

    Returns:
        One dict per stored vector. Without an ``entity_type`` filter the
        keys are ``entity_id``, ``entity_type``, ``vector`` and ``model``;
        with a filter the ``entity_type`` key is omitted because it is the
        same for every row. Returns an empty list when the query fails; the
        failure is logged.
    """
    if entity_type:
        sql = """
            SELECT entity_id, vector, model
            FROM svd_vectors
            WHERE window_id = ? AND entity_type = ?
        """
        params = (window_id, entity_type)
    else:
        sql = """
            SELECT entity_id, entity_type, vector, model
            FROM svd_vectors
            WHERE window_id = ?
        """
        params = (window_id,)

    try:
        con = _connect(db_path)
        try:
            return con.execute(sql, params).fetchdf().to_dict("records")
        finally:
            # Close even when the query raises, so the handle is not leaked.
            con.close()
    except Exception:
        logger.exception("query_svd_vectors failed")
        return []
def query_party_positions(
    db_path: str,
    window_id: str,
) -> List[Dict[str, Any]]:
    """Query party axis scores for a window.

    If a ``party_axis_scores`` table exists, its rows for ``window_id`` are
    returned. Otherwise the positions are derived from the MP vectors via
    ``_compute_party_positions_from_vectors``.

    Args:
        db_path: Path to the DuckDB database file.
        window_id: Window whose party positions are requested.

    Returns:
        A list of dicts. NOTE: the shape depends on which path was taken.
        The table path yields one dict per (party, axis) with the keys
        ``party``, ``axis`` and ``score``; the fallback yields one dict per
        party with the keys ``party``, ``axis_1`` and ``axis_2``. Returns an
        empty list when the query fails; the failure is logged.
    """
    try:
        con = _connect(db_path)
        try:
            # Check if the party_axis_scores table exists.
            tables = con.execute(
                "SELECT table_name FROM information_schema.tables WHERE table_name = 'party_axis_scores'"
            ).fetchall()
            if tables:
                return con.execute(
                    """
                    SELECT party, axis, score
                    FROM party_axis_scores
                    WHERE window_id = ?
                    """,
                    (window_id,),
                ).fetchdf().to_dict("records")
            # Fallback: compute from vectors.
            return _compute_party_positions_from_vectors(con, window_id)
        finally:
            # Close even when a query raises, so the handle is not leaked.
            con.close()
    except Exception:
        logger.exception("query_party_positions failed")
        return []
def _compute_party_positions_from_vectors(con, window_id: str) -> List[Dict[str, Any]]:
    """Compute party positions from MP vectors when party_axis_scores doesn't exist.

    Reads every ``svd_vectors`` row of entity type ``'mp'`` for ``window_id``,
    joins it to ``mp_metadata`` on the MP name to obtain the party, and
    averages the first two vector components per party.

    Args:
        con: Open database connection exposing ``execute(sql, params)`` whose
            result exposes ``fetchall()``.
        window_id: Window whose MP vectors are averaged.

    Returns:
        One dict per party with the keys ``party``, ``axis_1`` and ``axis_2``.
        An axis that is not available for every vector of a party is reported
        as ``0.0``.
    """
    import json
    from collections import defaultdict

    rows = con.execute(
        """
        SELECT sv.entity_id, sv.vector, mm.party
        FROM svd_vectors sv
        JOIN mp_metadata mm ON sv.entity_id = mm.mp_name
        WHERE sv.window_id = ? AND sv.entity_type = 'mp'
        """,
        (window_id,),
    ).fetchall()

    party_vectors = defaultdict(list)
    for _entity_id, raw_vector, party in rows:
        # Vectors may arrive as a JSON-encoded string or as an already
        # decoded sequence.
        vec = json.loads(raw_vector) if isinstance(raw_vector, str) else raw_vector
        if vec is None:
            # A NULL vector carries no position; skipping it keeps one bad
            # row from failing the whole computation.
            continue
        party_vectors[party].append(vec)

    result = []
    for party, vectors in party_vectors.items():
        # Average only the components every vector of the party has (at most
        # the first two), so vectors of differing length cannot raise.
        n_axes = min(2, min(len(v) for v in vectors))
        count = len(vectors)
        mean = [sum(v[i] for v in vectors) / count for i in range(n_axes)]
        result.append({
            "party": party,
            "axis_1": mean[0] if len(mean) > 0 else 0.0,
            "axis_2": mean[1] if len(mean) > 1 else 0.0,
        })
    return result
def query_pipeline_status(db_path: str) -> Dict[str, Any]:
    """Return pipeline freshness metrics.

    Args:
        db_path: Path to the DuckDB database file.

    Returns:
        A dict with the keys ``motion_count`` (rows in ``motions``),
        ``latest_motion_date`` (string form of the newest ``motions.date``,
        or None), ``svd_window_count`` (distinct ``window_id`` values in
        ``svd_vectors``), ``embedding_count`` (``svd_vectors`` rows with
        entity type ``'motion'``) and ``healthy`` (True when there is at
        least one motion and one SVD window). When any query fails the
        counts are zero, ``healthy`` is False and an ``error`` key is added;
        the failure is logged.
    """
    try:
        con = _connect(db_path)
        try:
            motion_count = con.execute("SELECT COUNT(*) FROM motions").fetchone()[0]
            latest = con.execute("SELECT MAX(date) FROM motions").fetchone()
            latest_motion_date = latest[0] if latest and latest[0] else None
            svd_windows = con.execute(
                "SELECT COUNT(DISTINCT window_id) FROM svd_vectors"
            ).fetchone()[0]
            embedding_count = con.execute(
                "SELECT COUNT(*) FROM svd_vectors WHERE entity_type = 'motion'"
            ).fetchone()[0]
        finally:
            # Close even when a query raises, so the handle is not leaked.
            con.close()
        return {
            "motion_count": motion_count,
            "latest_motion_date": str(latest_motion_date) if latest_motion_date else None,
            "svd_window_count": svd_windows,
            "embedding_count": embedding_count,
            "healthy": motion_count > 0 and svd_windows > 0,
        }
    except Exception:
        logger.exception("query_pipeline_status failed")
        return {
            "motion_count": 0,
            "latest_motion_date": None,
            "svd_window_count": 0,
            "embedding_count": 0,
            "healthy": False,
            "error": "Failed to query pipeline status",
        }