You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/database.py

1022 lines
36 KiB

# database.py (final working version)
"""DuckDB-backed persistence layer for motions, MP votes, embeddings and caches."""
import json
import logging
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

# duckdb is optional: when missing, MotionDatabase falls back to JSON files
# for the embeddings / similarity-cache stores (see _init_database).
try:
    import duckdb
except Exception:  # pragma: no cover - environment may not have duckdb installed
    duckdb = None

from config import config

_logger = logging.getLogger(__name__)
class MotionDatabase:
    """Storage facade for motions, votes, embeddings and similarity caches.

    Uses DuckDB when available; otherwise operates in a lightweight
    JSON-file-backed mode (only the embedding / similarity-cache APIs
    have a file-mode implementation).
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open (or lazily create) the database at *db_path*.

        ``db_path`` defaults to ``config.DATABASE_PATH``; the default is
        resolved at call time (not at class-definition time, as a
        ``db_path=config.DATABASE_PATH`` default would be), so runtime
        changes to config are honored.
        """
        self.db_path = config.DATABASE_PATH if db_path is None else db_path
        # If duckdb is not available, operate in lightweight file-backed mode.
        self._file_mode = duckdb is None
        self._init_database()
def _init_database(self):
"""Initialize database with required tables"""
# Create directory if it doesn't exist
import os
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
# If duckdb isn't available in this environment, create lightweight
# JSON-backed files to allow tests to run without the duckdb dependency.
if duckdb is None:
# create simple JSON files representing embeddings and similarity cache
emb_file = f"{self.db_path}.embeddings.json"
sim_file = f"{self.db_path}.similarity_cache.json"
for p in (emb_file, sim_file):
if not os.path.exists(p):
with open(p, "w", encoding="utf-8") as fh:
fh.write("[]")
return
conn = duckdb.connect(self.db_path)
# Create sequence for auto-incrementing IDs
try:
conn.execute("CREATE SEQUENCE IF NOT EXISTS motions_id_seq START 1")
except:
pass
# Create tables with proper ID handling
conn.execute("""
CREATE TABLE IF NOT EXISTS motions (
id INTEGER DEFAULT nextval('motions_id_seq'),
title TEXT NOT NULL,
description TEXT,
date DATE,
policy_area TEXT,
voting_results JSON,
winning_margin FLOAT,
controversy_score FLOAT,
layman_explanation TEXT,
url TEXT UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS user_sessions (
session_id TEXT PRIMARY KEY,
user_votes JSON,
completed_motions INTEGER DEFAULT 0,
total_motions INTEGER DEFAULT 10,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS party_results (
session_id TEXT,
party_name TEXT,
agreement_percentage FLOAT,
agreed_motions JSON,
disagreed_motions JSON,
PRIMARY KEY (session_id, party_name)
)
""")
# New pipeline tables
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS mp_votes_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS mp_votes (
id INTEGER DEFAULT nextval('mp_votes_id_seq'),
motion_id INTEGER NOT NULL,
mp_name TEXT NOT NULL,
party TEXT,
vote TEXT NOT NULL,
date DATE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS mp_metadata (
mp_name TEXT PRIMARY KEY,
party TEXT,
van DATE,
tot_en_met DATE,
persoon_id TEXT
)
""")
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS svd_vectors_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS svd_vectors (
id INTEGER DEFAULT nextval('svd_vectors_id_seq'),
window_id TEXT NOT NULL,
entity_type TEXT NOT NULL,
entity_id TEXT NOT NULL,
vector JSON NOT NULL,
model TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS fused_embeddings_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS fused_embeddings (
id INTEGER DEFAULT nextval('fused_embeddings_id_seq'),
motion_id INTEGER NOT NULL,
window_id TEXT NOT NULL,
vector JSON NOT NULL,
svd_dims INTEGER NOT NULL,
text_dims INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
# Embeddings table for raw text embeddings
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER DEFAULT nextval('embeddings_id_seq'),
motion_id INTEGER NOT NULL,
model TEXT,
vector JSON NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
# Similarity cache table for precomputed neighbors
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS similarity_cache (
id INTEGER DEFAULT nextval('similarity_cache_id_seq'),
source_motion_id INTEGER NOT NULL,
target_motion_id INTEGER NOT NULL,
score REAL NOT NULL,
vector_type TEXT NOT NULL,
window_id TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
# Embeddings table and sequence (stores vectors as JSON)
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
id INTEGER DEFAULT nextval('embeddings_id_seq'),
motion_id INTEGER NOT NULL,
model TEXT,
vector JSON NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
# Similarity cache and sequence (stores only ids and score, no vectors)
conn.execute("""
CREATE SEQUENCE IF NOT EXISTS similarity_cache_id_seq START 1
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS similarity_cache (
id INTEGER DEFAULT nextval('similarity_cache_id_seq'),
source_motion_id INTEGER NOT NULL,
target_motion_id INTEGER NOT NULL,
vector_type TEXT NOT NULL,
window_id TEXT,
score FLOAT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
)
""")
conn.close()
def reset_database(self):
"""Development helper: drop known tables and re-run initialization.
WARNING: intended for dev/test only. This will remove tables and recreate schema.
"""
conn = duckdb.connect(self.db_path)
try:
# Drop known tables if they exist
for t in ("party_results", "user_sessions", "motions"):
try:
conn.execute(f"DROP TABLE IF EXISTS {t}")
except Exception:
pass
# Recreate schema
conn.close()
self._init_database()
finally:
try:
conn.close()
except Exception:
pass
def insert_motion(self, motion_data: Dict) -> bool:
"""Insert a new motion into database"""
try:
conn = duckdb.connect(self.db_path)
# Check if motion already exists by URL to avoid duplicates
existing = conn.execute(
"""
SELECT COUNT(*) FROM motions WHERE url = ?
""",
(motion_data["url"],),
).fetchone()
if existing and existing[0] > 0:
conn.close()
return False # Motion already exists
# Insert motion - id will be auto-generated by sequence
conn.execute(
"""
INSERT INTO motions
(title, description, date, policy_area, voting_results,
winning_margin, controversy_score, url, externe_identifier, body_text, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
(
motion_data["title"],
motion_data["description"] or "",
motion_data["date"],
motion_data["policy_area"],
json.dumps(motion_data["voting_results"]),
motion_data["winning_margin"],
1 - motion_data["winning_margin"], # controversy score
motion_data["url"],
motion_data.get("externe_identifier"),
motion_data.get("body_text"),
),
)
conn.close()
# Also insert mp_vote rows for individual MPs if party data is available.
# This only runs for brand-new motions (existing motions are rejected above),
# so there is no risk of duplicates — no existence check needed here.
mp_vote_parties = motion_data.get("mp_vote_parties", {})
voting_results_raw = motion_data.get("voting_results", {})
if mp_vote_parties:
conn2 = duckdb.connect(self.db_path)
row = conn2.execute(
"SELECT id FROM motions WHERE url = ? LIMIT 1",
(motion_data["url"],),
).fetchone()
conn2.close()
motion_id = row[0] if row else None
if motion_id is not None:
motion_date = motion_data.get("date", "")
for mp_name, party in mp_vote_parties.items():
vote = voting_results_raw.get(mp_name, "afwezig")
self.insert_mp_vote(
motion_id=motion_id,
mp_name=mp_name,
party=party,
vote=vote,
date=motion_date,
)
return True
except Exception as e:
print(f"Error inserting motion: {e}")
if "conn" in locals():
conn.close()
return False
def batch_insert_motions(self, motions_data: List[Dict]) -> Tuple[int, int]:
"""Batch-insert motions and their mp_votes using a single DuckDB connection.
Returns (inserted_count, duplicate_count).
"""
if not motions_data:
return 0, 0
try:
conn = duckdb.connect(self.db_path)
# 1. Find which URLs already exist — single query
urls = [m["url"] for m in motions_data]
placeholders = ", ".join("?" * len(urls))
existing_urls = set(
row[0]
for row in conn.execute(
f"SELECT url FROM motions WHERE url IN ({placeholders})", urls
).fetchall()
)
new_motions = [m for m in motions_data if m["url"] not in existing_urls]
duplicates = len(motions_data) - len(new_motions)
if not new_motions:
conn.close()
return 0, duplicates
# 2. Bulk-insert motions
motion_rows = [
(
m["title"],
m["description"] or "",
m["date"],
m["policy_area"],
json.dumps(m["voting_results"]),
m["winning_margin"],
1 - m["winning_margin"],
m["url"],
m.get("externe_identifier"),
m.get("body_text"),
)
for m in new_motions
]
conn.executemany(
"""
INSERT INTO motions
(title, description, date, policy_area, voting_results,
winning_margin, controversy_score, url, externe_identifier,
body_text, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
motion_rows,
)
# 3. Fetch the newly-assigned IDs in one query
new_urls = [m["url"] for m in new_motions]
np = ", ".join("?" * len(new_urls))
url_to_id = {
row[1]: row[0]
for row in conn.execute(
f"SELECT id, url FROM motions WHERE url IN ({np})", new_urls
).fetchall()
}
# 4. Bulk-insert mp_votes
vote_rows = []
for m in new_motions:
motion_id = url_to_id.get(m["url"])
if motion_id is None:
continue
mp_vote_parties = m.get("mp_vote_parties", {})
voting_results_raw = m.get("voting_results", {})
motion_date = m.get("date", "")
for mp_name, party in mp_vote_parties.items():
vote = voting_results_raw.get(mp_name, "afwezig")
vote_rows.append((motion_id, mp_name, party, vote, motion_date))
if vote_rows:
conn.executemany(
"""
INSERT INTO mp_votes (motion_id, mp_name, party, vote, date, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
vote_rows,
)
conn.close()
return len(new_motions), duplicates
except Exception as e:
_logger.error(f"Error in batch_insert_motions: {e}")
try:
conn.close()
except Exception:
pass
raise
def get_filtered_motions(
self,
policy_area: str = "Alle",
min_margin: float = 0.2,
max_margin: float = 0.8,
limit: int = 100,
) -> List[Dict]:
"""Get motions filtered by criteria"""
conn = duckdb.connect(self.db_path)
query = """
SELECT * FROM motions
WHERE winning_margin BETWEEN ? AND ?
AND layman_explanation IS NOT NULL
AND layman_explanation != ''
"""
params = [min_margin, max_margin]
if policy_area != "Alle":
query += " AND policy_area = ?"
params.append(policy_area)
query += " ORDER BY controversy_score DESC LIMIT ?"
params.append(limit)
try:
result = conn.execute(query, params).fetchall()
columns = [desc[0] for desc in conn.description]
conn.close()
return [dict(zip(columns, row)) for row in result]
except Exception as e:
print(f"Error querying motions: {e}")
conn.close()
return []
def create_session(self, total_motions: int = 10) -> str:
"""Create new user session"""
session_id = str(uuid.uuid4())
conn = duckdb.connect(self.db_path)
conn.execute(
"""
INSERT INTO user_sessions (session_id, user_votes, total_motions)
VALUES (?, '{}', ?)
""",
(session_id, total_motions),
)
conn.close()
return session_id
def update_user_vote(self, session_id: str, motion_id: int, vote: str):
"""Update user vote for a motion"""
conn = duckdb.connect(self.db_path)
# Get current votes
current_votes = conn.execute(
"""
SELECT user_votes FROM user_sessions WHERE session_id = ?
""",
(session_id,),
).fetchone()
if current_votes:
votes_dict = json.loads(current_votes[0])
votes_dict[str(motion_id)] = vote
conn.execute(
"""
UPDATE user_sessions
SET user_votes = ?,
completed_motions = ?,
last_updated = CURRENT_TIMESTAMP
WHERE session_id = ?
""",
(json.dumps(votes_dict), len(votes_dict), session_id),
)
conn.close()
def calculate_party_matches(self, session_id: str) -> List[Dict]:
"""Calculate party agreement percentages"""
conn = duckdb.connect(self.db_path)
# Get user votes and motion data
user_data = conn.execute(
"""
SELECT user_votes FROM user_sessions WHERE session_id = ?
""",
(session_id,),
).fetchone()
if not user_data:
return []
user_votes = json.loads(user_data[0])
motion_ids = list(user_votes.keys())
if not motion_ids:
return []
# Get motion voting results
placeholders = ",".join(["?" for _ in motion_ids])
motions = conn.execute(
f"""
SELECT id, voting_results FROM motions
WHERE id IN ({placeholders})
""",
motion_ids,
).fetchall()
conn.close()
# Calculate agreements
party_scores = {}
for motion_id, voting_results_json in motions:
voting_results = json.loads(voting_results_json)
user_vote = user_votes[str(motion_id)]
if user_vote == "Geen stem": # Skip abstentions
continue
for party, party_vote in voting_results.items():
# Skip individual MP names (contain comma, e.g. "Yesilgöz-Zegerius, D.")
# Party/fractie names never contain a comma.
if "," in party:
continue
if party not in party_scores:
party_scores[party] = {"agreed": 0, "total": 0}
party_scores[party]["total"] += 1
# Check agreement
if (user_vote == "Voor" and party_vote == "voor") or (
user_vote == "Tegen" and party_vote == "tegen"
):
party_scores[party]["agreed"] += 1
# Convert to percentages and sort
results = []
for party, scores in party_scores.items():
if scores["total"] > 0:
agreement_pct = (scores["agreed"] / scores["total"]) * 100
results.append(
{
"party": party,
"agreement_percentage": round(agreement_pct, 1),
"agreed_motions": scores["agreed"],
"total_motions": scores["total"],
}
)
return sorted(results, key=lambda x: x["agreement_percentage"], reverse=True)
def store_embedding(self, motion_id: int, model: str, vector: List[float]) -> int:
"""Store an embedding for a motion. Returns inserted row id or -1 on failure."""
try:
conn = duckdb.connect(self.db_path)
# Use explicit nextval for id since older tables may lack DEFAULT
conn.execute(
"INSERT INTO embeddings (id, motion_id, model, vector, created_at) VALUES (nextval('embeddings_id_seq'), ?, ?, ?, CURRENT_TIMESTAMP)",
(motion_id, model, json.dumps(vector)),
)
row = conn.execute("SELECT currval('embeddings_id_seq')").fetchone()
conn.close()
if row and row[0] is not None:
return int(row[0])
return -1
except Exception as e:
_logger.error("Error storing embedding: %s", e)
try:
conn.close()
except Exception:
pass
return -1
def search_similar(
self, query_vector: List[float], top_k: int = 5, model: Optional[str] = None
) -> List[Dict]:
"""Naive in-Python cosine similarity search over stored embeddings.
Returns list of dicts with keys: id, motion_id, model, score, created_at
"""
try:
conn = duckdb.connect(self.db_path)
if model:
rows = conn.execute(
"SELECT id, motion_id, model, vector, created_at FROM embeddings WHERE model = ?",
(model,),
).fetchall()
else:
rows = conn.execute(
"SELECT id, motion_id, model, vector, created_at FROM embeddings"
).fetchall()
conn.close()
results = []
import math
for r in rows:
id_, motion_id, mdl, vector_json, created_at = r
try:
vec = json.loads(vector_json)
except Exception:
continue
# cosine similarity
try:
dot = sum(float(a) * float(b) for a, b in zip(query_vector, vec))
na = math.sqrt(sum(float(a) * float(a) for a in query_vector))
nb = math.sqrt(sum(float(b) * float(b) for b in vec))
score = dot / (na * nb) if na and nb else 0.0
except Exception:
score = 0.0
results.append(
{
"id": id_,
"motion_id": motion_id,
"model": mdl,
"score": score,
"created_at": created_at,
}
)
results.sort(key=lambda x: x["score"], reverse=True)
return results[:top_k]
except Exception as e:
print(f"Error searching embeddings: {e}")
try:
conn.close()
except Exception:
pass
return []
def mp_votes_exists_for_motion(self, motion_id: int) -> bool:
try:
conn = duckdb.connect(self.db_path)
row = conn.execute(
"SELECT COUNT(*) FROM mp_votes WHERE motion_id = ?",
(motion_id,),
).fetchone()
conn.close()
return bool(row and row[0] > 0)
except Exception as e:
_logger.error(f"Error checking mp_votes existence: {e}")
try:
conn.close()
except Exception:
pass
return False
def insert_mp_vote(
self,
motion_id: int,
mp_name: str,
vote: str,
date: Optional[str] = None,
party: Optional[str] = None,
) -> int:
try:
conn = duckdb.connect(self.db_path)
conn.execute(
"""
INSERT INTO mp_votes (motion_id, mp_name, party, vote, date, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
(motion_id, mp_name, party, vote, date),
)
row = conn.execute("SELECT max(id) FROM mp_votes").fetchone()
conn.close()
if row and row[0] is not None:
return int(row[0])
return -1
except Exception as e:
_logger.error(f"Error inserting mp_vote: {e}")
try:
conn.close()
except Exception:
pass
return -1
def upsert_mp_metadata(
self,
mp_name: str,
party: Optional[str],
van: Optional[str],
tot_en_met: Optional[str],
persoon_id: Optional[str],
) -> None:
try:
conn = duckdb.connect(self.db_path)
exists = conn.execute(
"SELECT COUNT(*) FROM mp_metadata WHERE mp_name = ?", (mp_name,)
).fetchone()
if exists and exists[0] > 0:
# Only update if this record is newer (higher Van date) than the stored one,
# preferring active memberships (TotEnMet IS NULL) over ended ones.
conn.execute(
"""
UPDATE mp_metadata SET party = ?, van = ?, tot_en_met = ?, persoon_id = ?
WHERE mp_name = ?
AND (
-- prefer active over ended
(? IS NULL AND tot_en_met IS NOT NULL)
-- or same active status but newer start date
OR (? IS NULL AND tot_en_met IS NULL AND CAST(? AS DATE) > CAST(van AS DATE))
OR (? IS NOT NULL AND tot_en_met IS NOT NULL AND CAST(? AS DATE) > CAST(van AS DATE))
)
""",
(
party,
van,
tot_en_met,
persoon_id,
mp_name,
tot_en_met, # prefer active
tot_en_met,
van, # both active, newer
tot_en_met,
van,
), # both ended, newer
)
else:
conn.execute(
"""
INSERT INTO mp_metadata (mp_name, party, van, tot_en_met, persoon_id)
VALUES (?, ?, ?, ?, ?)
""",
(mp_name, party, van, tot_en_met, persoon_id),
)
conn.close()
except Exception as e:
_logger.error(f"Error upserting mp_metadata: {e}")
try:
conn.close()
except Exception:
pass
def store_svd_vector(
self,
window_id: str,
entity_type: str,
entity_id: str,
vector: List[float],
model: Optional[str] = None,
) -> int:
try:
conn = duckdb.connect(self.db_path)
conn.execute(
"""
INSERT INTO svd_vectors (window_id, entity_type, entity_id, vector, model, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
(window_id, entity_type, entity_id, json.dumps(vector), model),
)
row = conn.execute("SELECT max(id) FROM svd_vectors").fetchone()
conn.close()
if row and row[0] is not None:
return int(row[0])
return -1
except Exception as e:
_logger.error(f"Error storing svd_vector: {e}")
try:
conn.close()
except Exception:
pass
return -1
def batch_store_svd_vectors(
self,
window_id: str,
rows: List[Tuple], # each: (entity_type, entity_id, vector_list, model_or_None)
) -> int:
"""Batch-upsert SVD vectors for a window using a single connection.
Deletes all existing rows for the window first, then inserts the new batch.
Returns number of rows inserted.
"""
if not rows:
return 0
try:
conn = duckdb.connect(self.db_path)
conn.execute("DELETE FROM svd_vectors WHERE window_id = ?", (window_id,))
insert_rows = [
(window_id, entity_type, entity_id, json.dumps(vector), model)
for entity_type, entity_id, vector, model in rows
]
conn.executemany(
"""
INSERT INTO svd_vectors
(window_id, entity_type, entity_id, vector, model, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
insert_rows,
)
conn.close()
return len(insert_rows)
except Exception as e:
_logger.error(f"Error in batch_store_svd_vectors: {e}")
try:
conn.close()
except Exception:
pass
raise
def store_fused_embedding(
self,
motion_id: int,
window_id: str,
vector: List[float],
svd_dims: int,
text_dims: int,
) -> int:
try:
conn = duckdb.connect(self.db_path)
# Delete any existing row for this (motion_id, window_id) to prevent duplicates
conn.execute(
"DELETE FROM fused_embeddings WHERE motion_id = ? AND window_id = ?",
(motion_id, window_id),
)
conn.execute(
"""
INSERT INTO fused_embeddings (motion_id, window_id, vector, svd_dims, text_dims, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
(motion_id, window_id, json.dumps(vector), svd_dims, text_dims),
)
row = conn.execute("SELECT max(id) FROM fused_embeddings").fetchone()
conn.close()
if row and row[0] is not None:
return int(row[0])
return -1
except Exception as e:
_logger.error(f"Error storing fused_embedding: {e}")
try:
conn.close()
except Exception:
pass
return -1
def store_similarity_batch(self, rows: List[Dict]) -> int:
"""Insert multiple similarity_cache rows. Returns number inserted."""
if not rows:
return 0
inserted = 0
# File-backed fallback when duckdb is not available
if duckdb is None:
sim_file = f"{self.db_path}.similarity_cache.json"
try:
with open(sim_file, "r+", encoding="utf-8") as fh:
data = json.load(fh)
# assign incremental ids
max_id = max((item.get("id", 0) for item in data), default=0)
for r in rows:
max_id += 1
entry = {
"id": max_id,
"source_motion_id": int(r["source_motion_id"]),
"target_motion_id": int(r["target_motion_id"]),
"score": float(r["score"]),
"vector_type": r["vector_type"],
"window_id": r.get("window_id"),
}
data.append(entry)
inserted += 1
fh.seek(0)
json.dump(data, fh)
fh.truncate()
return inserted
except Exception as e:
_logger.error(f"Error writing similarity cache file: {e}")
return inserted
try:
conn = duckdb.connect(self.db_path)
for r in rows:
try:
conn.execute(
"""
INSERT INTO similarity_cache (source_motion_id, target_motion_id, score, vector_type, window_id, created_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
""",
(
r["source_motion_id"],
r["target_motion_id"],
float(r["score"]),
r["vector_type"],
r.get("window_id"),
),
)
inserted += 1
except Exception as e:
_logger.error(f"Error inserting similarity row {r}: {e}")
conn.close()
return inserted
except Exception as e:
_logger.error(f"Error in store_similarity_batch: {e}")
try:
conn.close()
except Exception:
pass
return inserted
def get_cached_similarities(
self,
source_motion_id: int,
vector_type: str,
window_id: Optional[str] = None,
top_k: int = 10,
) -> List[Dict]:
"""Retrieve cached similarities for a source motion.
Returns list of dicts with keys: target_motion_id, score, created_at, id
"""
# File-backed fallback
if duckdb is None:
sim_file = f"{self.db_path}.similarity_cache.json"
try:
with open(sim_file, "r", encoding="utf-8") as fh:
data = json.load(fh)
rows = [
r
for r in data
if int(r.get("source_motion_id")) == int(source_motion_id)
and r.get("vector_type") == vector_type
and (window_id is None or r.get("window_id") == window_id)
]
# sort by score desc
rows.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return rows[:top_k]
except Exception as e:
_logger.error(f"Error reading similarity cache file: {e}")
return []
try:
conn = duckdb.connect(self.db_path)
params = [source_motion_id, vector_type]
query = (
"SELECT id, target_motion_id, score, created_at FROM similarity_cache"
" WHERE source_motion_id = ? AND vector_type = ?"
)
if window_id is not None:
query += " AND window_id = ?"
params.append(window_id)
query += " ORDER BY score DESC LIMIT ?"
params.append(top_k)
rows = conn.execute(query, params).fetchall()
columns = [desc[0] for desc in conn.description]
conn.close()
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
_logger.error(f"Error fetching cached similarities: {e}")
try:
conn.close()
except Exception:
pass
return []
def clear_similarity_cache(
self, vector_type: str, window_id: Optional[str] = None
) -> int:
"""Delete cached similarity rows matching vector_type and optional window_id. Returns count deleted."""
try:
# File-backed fallback
if duckdb is None:
sim_file = f"{self.db_path}.similarity_cache.json"
try:
with open(sim_file, "r+", encoding="utf-8") as fh:
data = json.load(fh)
before = len(data)
data = [
r
for r in data
if not (
r.get("vector_type") == vector_type
and (
window_id is None or r.get("window_id") == window_id
)
)
]
deleted = before - len(data)
fh.seek(0)
json.dump(data, fh)
fh.truncate()
return deleted
except Exception as e:
_logger.error(f"Error clearing similarity cache file: {e}")
return 0
conn = duckdb.connect(self.db_path)
params = [vector_type]
count_q = "SELECT COUNT(*) FROM similarity_cache WHERE vector_type = ?"
del_q = "DELETE FROM similarity_cache WHERE vector_type = ?"
if window_id is not None:
count_q += " AND window_id = ?"
del_q += " AND window_id = ?"
params.append(window_id)
row = conn.execute(count_q, params).fetchone()
to_delete = int(row[0]) if row and row[0] is not None else 0
if to_delete > 0:
conn.execute(del_q, params)
conn.close()
return to_delete
except Exception as e:
_logger.error(f"Error clearing similarity_cache: {e}")
try:
conn.close()
except Exception:
pass
return 0
# Module-level singleton: the rest of the application imports this shared
# instance rather than constructing its own MotionDatabase.
db = MotionDatabase()