diff --git a/agent_tools/__init__.py b/agent_tools/__init__.py index 5c1ad3e..d824617 100644 --- a/agent_tools/__init__.py +++ b/agent_tools/__init__.py @@ -1 +1,119 @@ -"""Agent tools for Stemwijzer — atomic primitives for agent operation.""" +"""Agent tools for Stemwijzer — atomic primitives for agent operation. + +Import individual modules or use `list_tools()` for runtime discovery. +""" + +from __future__ import annotations + +from agent_tools.analysis import ( + analyze_axis_stability, + analyze_party_shift, + validate_svd_labels, +) +from agent_tools.content import ( + check_embedding_quality, + suggest_svd_label, + validate_layman_explanations, + validate_motion_coverage, +) +from agent_tools.context import ( + append_context_note, + build_context, + render_context_markdown, +) +from agent_tools.database import ( + create_motion, + delete_report, + query_compass_positions, + query_embeddings, + query_motions, + query_party_positions, + query_pipeline_status, + query_similar_motions, + query_svd_vectors, + query_votes, + update_motion, +) +from agent_tools.pipeline import ( + pipeline_check_health, + pipeline_get_logs, + pipeline_run_full, + pipeline_run_stage, + pipeline_validate_output, +) +from agent_tools.reports import generate_report + +__all__ = [ + # Database + "query_motions", + "query_votes", + "query_svd_vectors", + "query_party_positions", + "query_pipeline_status", + "query_embeddings", + "query_similar_motions", + "query_compass_positions", + "create_motion", + "update_motion", + "delete_report", + # Pipeline + "pipeline_run_stage", + "pipeline_run_full", + "pipeline_check_health", + "pipeline_get_logs", + "pipeline_validate_output", + # Analysis + "analyze_party_shift", + "analyze_axis_stability", + "validate_svd_labels", + # Content + "validate_motion_coverage", + "validate_layman_explanations", + "suggest_svd_label", + "check_embedding_quality", + # Reports + "generate_report", + # Context + "build_context", + "render_context_markdown", + "append_context_note", + # Discovery + "list_tools", +] + + +def list_tools() -> list[dict[str, str]]: + """Return a list of all available agent tools with signatures and descriptions. + + Useful for runtime capability discovery and prompt injection. + """ + return [ + {"name": "query_motions", "signature": "query_motions(db_path, limit=100, policy_area=None, start_date=None, end_date=None)", "description": "Query motions from the database with optional filters."}, + {"name": "query_votes", "signature": "query_votes(db_path, motion_id=None, party=None)", "description": "Query vote counts or individual votes."}, + {"name": "query_svd_vectors", "signature": "query_svd_vectors(db_path, window_id, entity_type='motion')", "description": "Query SVD vectors for a window and entity type."}, + {"name": "query_party_positions", "signature": "query_party_positions(db_path, window_id='current_parliament')", "description": "Query party axis positions for a window."}, + {"name": "query_pipeline_status", "signature": "query_pipeline_status(db_path)", "description": "Query pipeline freshness and coverage metrics."}, + {"name": "query_embeddings", "signature": "query_embeddings(db_path, motion_id=None, model=None, limit=100)", "description": "Query text/fused embeddings."}, + {"name": "query_similar_motions", "signature": "query_similar_motions(db_path, motion_id, top_k=10)", "description": "Query similar motions from similarity cache."}, + {"name": "query_compass_positions", "signature": "query_compass_positions(db_path, window_id='current_parliament')", "description": "Query 2D compass positions for parties/MPs."}, + {"name": "create_motion", "signature": "create_motion(db_path, title, description, date, policy_area='General', voting_results='[]')", "description": "Insert a new motion into the database."}, + {"name": "update_motion", "signature": "update_motion(db_path, motion_id, **fields)", "description": "Update fields of an existing motion."}, + {"name": "delete_report", "signature": "delete_report(output_path)", "description": "Delete a generated report file."}, + {"name": "pipeline_run_stage", "signature": "pipeline_run_stage(db_path, stage, window_id, dry_run=False)", "description": "Run a single pipeline stage."}, + {"name": "pipeline_run_full", "signature": "pipeline_run_full(db_path, dry_run=False)", "description": "Run the full pipeline end-to-end."}, + {"name": "pipeline_check_health", "signature": "pipeline_check_health(db_path)", "description": "Run health checks and return report."}, + {"name": "pipeline_get_logs", "signature": "pipeline_get_logs(stage, lines=50)", "description": "Retrieve recent log output for a stage."}, + {"name": "pipeline_validate_output", "signature": "pipeline_validate_output(db_path, stage)", "description": "Validate that a stage produced expected output."}, + {"name": "analyze_party_shift", "signature": "analyze_party_shift(db_path, party, window_start, window_end)", "description": "Compute party position shift between two windows."}, + {"name": "analyze_axis_stability", "signature": "analyze_axis_stability(db_path, component, windows)", "description": "Compute axis stability across windows."}, + {"name": "validate_svd_labels", "signature": "validate_svd_labels(db_path, component)", "description": "Compare SVD theme labels to actual party positions."}, + {"name": "validate_motion_coverage", "signature": "validate_motion_coverage(db_path, start_date, end_date)", "description": "Check motion coverage for a date range."}, + {"name": "validate_layman_explanations", "signature": "validate_layman_explanations(db_path, sample_size=50)", "description": "Sample motions and check explanation quality."}, + {"name": "suggest_svd_label", "signature": "suggest_svd_label(db_path, component, top_n=10)", "description": "Suggest a label based on top/bottom motions."}, + {"name": "check_embedding_quality", "signature": "check_embedding_quality(db_path, window_id, healthy_threshold=0.8)", "description": "Check embedding coverage for a window."}, + {"name": "generate_report", "signature": "generate_report(db_path, report_type, parameters, output_path)", "description": "Generate a markdown report."}, + {"name": "build_context", "signature": "build_context(db_path)", "description": "Build runtime context dict for the agent."}, + {"name": "render_context_markdown", "signature": "render_context_markdown(db_path)", "description": "Render context as markdown for prompt injection."}, + {"name": "append_context_note", "signature": "append_context_note(note)", "description": "Append a note to the accumulated agent knowledge."}, + {"name": "list_tools", "signature": "list_tools()", "description": "Return a list of all available agent tools."}, + ] diff --git a/agent_tools/content.py b/agent_tools/content.py index 9d03942..665b9db 100644 --- a/agent_tools/content.py +++ b/agent_tools/content.py @@ -157,9 +157,14 @@ def suggest_svd_label( def check_embedding_quality( db_path: str, window_id: str, + healthy_threshold: float = 0.8, ) -> Dict[str, Any]: """Check embedding coverage and quality for a window. + Args: + healthy_threshold: Coverage ratio above which embeddings are considered healthy. + Defaults to 0.8; override via prompt for different quality bars. + Returns coverage stats for fused embeddings. """ try: @@ -176,7 +181,8 @@ def check_embedding_quality( "total_motions": total_motions, "with_embeddings": with_embeddings, "coverage": coverage, - "healthy": coverage > 0.8, + "healthy": coverage > healthy_threshold, + "healthy_threshold": healthy_threshold, } except Exception as e: logger.exception("check_embedding_quality failed") diff --git a/agent_tools/database.py b/agent_tools/database.py index ffefb6e..3319e56 100644 --- a/agent_tools/database.py +++ b/agent_tools/database.py @@ -206,7 +206,8 @@ def query_pipeline_status(db_path: str) -> Dict[str, Any]: "latest_motion_date": str(latest_motion_date) if latest_motion_date else None, "svd_window_count": svd_windows, "embedding_count": embedding_count, - "healthy": motion_count > 0 and svd_windows > 0, + "motion_count": motion_count, + "svd_window_count": svd_windows, } except Exception: logger.exception("query_pipeline_status failed") @@ -215,6 +216,153 @@ def query_pipeline_status(db_path: str) -> Dict[str, Any]: "latest_motion_date": None, "svd_window_count": 0, "embedding_count": 0, - "healthy": False, "error": "Failed to query pipeline status", } + + +def query_embeddings( + db_path: str, + *, + motion_id: Optional[int] = None, + model: Optional[str] = None, + limit: int = 100, +) -> List[Dict[str, Any]]: + """Query fused embeddings for motions.""" + try: + con = _connect(db_path) + conditions = [] + params = [] + + if motion_id is not None: + conditions.append("motion_id = ?") + params.append(motion_id) + if model is not None: + conditions.append("model = ?") + params.append(model) + + where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" + sql = f""" + SELECT motion_id, vector, model + FROM fused_embeddings + {where_clause} + LIMIT ? + """ + params.append(limit) + + result = con.execute(sql, params).fetchdf().to_dict("records") + con.close() + return result + except Exception: + logger.exception("query_embeddings failed") + return [] + + +def query_similar_motions( + db_path: str, + motion_id: int, + top_k: int = 10, +) -> List[Dict[str, Any]]: + """Query top-k similar motions from similarity cache.""" + try: + con = _connect(db_path) + result = con.execute( + """ + SELECT target_motion_id, similarity_score + FROM similarity_cache + WHERE source_motion_id = ? + ORDER BY similarity_score DESC + LIMIT ? + """, + (motion_id, top_k), + ).fetchdf().to_dict("records") + con.close() + return result + except Exception: + logger.exception("query_similar_motions failed") + return [] + + +def query_compass_positions( + db_path: str, + window_id: str, +) -> List[Dict[str, Any]]: + """Query 2D PCA compass positions for MPs in a window.""" + try: + con = _connect(db_path) + result = con.execute( + """ + SELECT sv.entity_id, sv.vector, mm.party + FROM svd_vectors sv + JOIN mp_metadata mm ON sv.entity_id = mm.mp_name + WHERE sv.window_id = ? AND sv.entity_type = 'mp' + """, + (window_id,), + ).fetchdf().to_dict("records") + con.close() + return result + except Exception: + logger.exception("query_compass_positions failed") + return [] + + +def create_motion( + db_path: str, + title: str, + description: str = "", + date: str = "", + policy_area: str = "", +) -> Dict[str, Any]: + """Create a new motion record.""" + try: + con = _connect(db_path, read_only=False) + con.execute( + """ + INSERT INTO motions (title, description, date, policy_area) + VALUES (?, ?, ?, ?) + """, + (title, description, date, policy_area), + ) + con.close() + return {"created": True, "title": title} + except Exception: + logger.exception("create_motion failed") + return {"created": False, "error": "Failed to create motion"} + + +def update_motion( + db_path: str, + motion_id: int, + **fields: str, +) -> Dict[str, Any]: + """Update a motion record.""" + try: + con = _connect(db_path, read_only=False) + allowed = {"title", "description", "date", "policy_area", "layman_explanation"} + updates = {k: v for k, v in fields.items() if k in allowed} + if not updates: + return {"updated": False, "error": "No valid fields to update"} + + set_clause = ", ".join(f"{k} = ?" for k in updates) + params = list(updates.values()) + [motion_id] + con.execute( + f"UPDATE motions SET {set_clause} WHERE id = ?", + params, + ) + con.close() + return {"updated": True, "motion_id": motion_id, "fields": list(updates.keys())} + except Exception: + logger.exception("update_motion failed") + return {"updated": False, "error": "Failed to update motion"} + + +def delete_report(output_path: str) -> Dict[str, Any]: + """Delete a generated report file.""" + try: + import os + if os.path.exists(output_path): + os.remove(output_path) + return {"deleted": True, "path": output_path} + return {"deleted": False, "error": "File not found"} + except Exception: + logger.exception("delete_report failed") + return {"deleted": False, "error": "Failed to delete report"} diff --git a/tests/agent_tools/test_content_tools.py b/tests/agent_tools/test_content_tools.py index ac07dab..43edc43 100644 --- a/tests/agent_tools/test_content_tools.py +++ b/tests/agent_tools/test_content_tools.py @@ -42,3 +42,12 @@ class TestCheckEmbeddingQuality: result = check_embedding_quality(tmp_duckdb_path, window_id="current_parliament") assert isinstance(result, dict) assert "coverage" in result or "error" in result + + def test_parameterized_threshold(self, tmp_duckdb_path): + from agent_tools.content import check_embedding_quality + + result = check_embedding_quality( + tmp_duckdb_path, window_id="current_parliament", healthy_threshold=0.5 + ) + assert isinstance(result, dict) + assert result.get("healthy_threshold") == 0.5 diff --git a/tests/agent_tools/test_database_tools.py b/tests/agent_tools/test_database_tools.py index 074f709..34efee5 100644 --- a/tests/agent_tools/test_database_tools.py +++ b/tests/agent_tools/test_database_tools.py @@ -73,3 +73,49 @@ class TestQueryPipelineStatus: assert "motion_count" in result assert "latest_motion_date" in result assert "svd_window_count" in result + + +class TestCrudTools: + def test_create_motion_returns_id(self, tmp_duckdb_path): + from agent_tools.database import create_motion + + result = create_motion( + tmp_duckdb_path, + title="Test Motion", + description="A test motion", + date="2024-06-01", + policy_area="Test", + ) + assert isinstance(result, dict) + assert "motion_id" in result or "error" in result + + def test_update_motion_changes_field(self, tmp_duckdb_path): + from agent_tools.database import create_motion, update_motion + + created = create_motion( + tmp_duckdb_path, + title="Original", + description="Original desc", + date="2024-06-01", + ) + if "error" in created: + pytest.skip("create_motion not supported by schema") + motion_id = created["motion_id"] + + result = update_motion( + tmp_duckdb_path, + motion_id=motion_id, + title="Updated", + ) + assert isinstance(result, dict) + assert "updated" in result or "error" in result + + def test_delete_report_removes_file(self, tmp_path): + from agent_tools.database import delete_report + + report_path = tmp_path / "test_report.md" + report_path.write_text("# Test Report\n") + + result = delete_report(str(report_path)) + assert result.get("deleted") is True + assert not report_path.exists() diff --git a/tests/agent_tools/test_package.py b/tests/agent_tools/test_package.py new file mode 100644 index 0000000..8ed6656 --- /dev/null +++ b/tests/agent_tools/test_package.py @@ -0,0 +1,39 @@ +"""Tests for agent_tools package-level utilities.""" + +import pytest + + +class TestListTools: + def test_returns_tool_list(self): + from agent_tools import list_tools + + result = list_tools() + assert isinstance(result, list) + assert len(result) > 0 + + names = {t["name"] for t in result} + assert "query_motions" in names + assert "pipeline_check_health" in names + assert "generate_report" in names + assert "list_tools" in names + + def test_each_tool_has_required_fields(self): + from agent_tools import list_tools + + result = list_tools() + for tool in result: + assert "name" in tool + assert "signature" in tool + assert "description" in tool + + +class TestAllExports: + def test_query_motions_importable(self): + from agent_tools import query_motions + + assert callable(query_motions) + + def test_list_tools_importable(self): + from agent_tools import list_tools + + assert callable(list_tools)