You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/scripts/rerun_embeddings.py

220 lines
7.4 KiB

"""Re-run text embeddings, fusion, and similarity for all windows.
Clears stale embeddings, re-embeds all motions with available text,
then fuses SVD + text vectors and rebuilds similarity cache for every
window that has SVD vectors in the database.
Usage:
.venv/bin/python scripts/rerun_embeddings.py --db-path data/motions.db
"""
import argparse
import logging
try:
import duckdb
except Exception:
duckdb = None
from pipeline import text_pipeline
import importlib
# If duckdb is not present at import time (test environments), avoid hard failure
try:
importlib.import_module("duckdb")
except Exception:
pass
from pipeline import fusion as fusion_pipeline
from similarity import compute as similarity_compute
_logger = logging.getLogger(__name__)
def _get_all_windows(db_path: str):
    """Return all distinct window_ids that have SVD vectors.

    Opens a read-only connection; any connect or query failure is logged
    and an empty list is returned rather than raising.
    """
    try:
        conn = duckdb.connect(db_path, read_only=True)
    except Exception:
        _logger.exception(
            "Unable to connect to duckdb for _get_all_windows(%s)", db_path
        )
        return []
    query = "SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id"
    try:
        records = conn.execute(query).fetchall()
    except Exception:
        _logger.exception("Error querying windows from %s", db_path)
        return []
    else:
        # Each record is a 1-tuple; unwrap to a flat list of window ids.
        return [record[0] for record in records]
    finally:
        try:
            conn.close()
        except Exception:
            pass
def _clear_embeddings(db_path: str) -> int:
    """Delete all rows from embeddings, fused_embeddings, and similarity_cache.

    Returns the total number of rows deleted across the three tables,
    or 0 if the connection or any DELETE fails (errors are logged).
    """
    try:
        conn = duckdb.connect(db_path)
    except Exception:
        _logger.exception(
            "Unable to connect to duckdb for _clear_embeddings(%s)", db_path
        )
        return 0
    try:
        # Same three DELETEs as before, driven by a table list instead of
        # three copy-pasted statements. Table names are hardcoded constants,
        # so the f-string SQL is safe.
        deleted = {}
        for table in ("embeddings", "fused_embeddings", "similarity_cache"):
            deleted[table] = conn.execute(f"DELETE FROM {table}").rowcount or 0
        conn.commit()
        _logger.info(
            "Cleared: %d embeddings, %d fused_embeddings, %d similarity_cache rows",
            deleted["embeddings"],
            deleted["fused_embeddings"],
            deleted["similarity_cache"],
        )
        return sum(deleted.values())
    except Exception:
        _logger.exception("Error clearing embeddings in %s", db_path)
        return 0
    finally:
        try:
            conn.close()
        except Exception:
            pass
# Batch size used when retrying embeddings that failed on the first pass.
_RETRY_BATCH_SIZE = 20


def _unpack_embedding_result(result):
    """Normalize ensure_text_embeddings()'s return value to a 5-tuple.

    Historically it returned either a 4-tuple
    (stored, skipped_existing, skipped_no_text, errors) or a 5-tuple that
    appends failed_ids. Both shapes are supported; anything else is logged
    and collapsed to all-zero counts with no failed ids.
    """
    if isinstance(result, tuple) and len(result) == 5:
        return result
    if isinstance(result, tuple) and len(result) == 4:
        return (*result, [])
    # Fallback: try to unpack defensively (e.g. a list of length 5).
    try:
        stored, skipped_existing, skipped_no_text, errors, failed_ids = result
        return stored, skipped_existing, skipped_no_text, errors, failed_ids
    except Exception:
        _logger.error(
            "Unexpected return shape from ensure_text_embeddings: %s", result
        )
        return 0, 0, 0, 0, []


def _retry_failed_embeddings(db_path: str, model, failed_ids) -> None:
    """Best-effort retry of failed embedding ids with a smaller batch size.

    Any exception is logged and swallowed — a retry failure must not abort
    the overall rerun.
    """
    try:
        _logger.info(
            "Retrying %d failed embeddings with smaller batches", len(failed_ids)
        )
        # Prefer a helper that can process only specific ids if available.
        if hasattr(text_pipeline, "ensure_text_embeddings_for_ids"):
            text_pipeline.ensure_text_embeddings_for_ids(
                db_path=db_path,
                ids=failed_ids,
                model=model,
                batch_size=_RETRY_BATCH_SIZE,
            )
        else:
            # Best-effort: call ensure_text_embeddings and let the
            # implementation handle limiting.
            text_pipeline.ensure_text_embeddings(
                db_path=db_path, model=model, batch_size=_RETRY_BATCH_SIZE
            )
    except Exception:
        _logger.exception("Retrying missing embeddings failed")


def rerun_embeddings(
    db_path: str, model: str = None, retry_missing: bool = False
) -> dict:
    """Full rerun: clear → embed → fuse → similarity for all windows.

    Args:
        db_path: Path to the DuckDB motions database.
        model: Embedding model name; None uses text_pipeline's default.
        retry_missing: If True, re-attempt embeddings that failed on the
            first pass using a smaller batch size.

    Returns:
        A summary dict with cleared/stored/error counts, the list of ids
        that failed embedding, and per-window fusion/similarity results.
    """
    _logger.info("Starting rerun_embeddings for %s", db_path)
    # 1. Clear stale data
    cleared = _clear_embeddings(db_path)
    # 2. Re-embed all motions
    _logger.info("Running text embeddings ...")
    result = text_pipeline.ensure_text_embeddings(db_path=db_path, model=model)
    stored, skipped_existing, skipped_no_text, emb_errors, failed_ids = (
        _unpack_embedding_result(result)
    )
    # Optionally retry missing failed ids with smaller batch sizes
    if retry_missing and failed_ids:
        _retry_failed_embeddings(db_path, model, failed_ids)
    _logger.info(
        "Text embeddings: stored=%d, skipped_existing=%d, skipped_no_text=%d, errors=%d",
        stored,
        skipped_existing,
        skipped_no_text,
        emb_errors,
    )
    # 3. Get all windows with SVD vectors
    windows = _get_all_windows(db_path)
    _logger.info("Found %d windows with SVD vectors: %s", len(windows), windows)
    fusion_summary = {}
    similarity_summary = {}
    for window_id in windows:
        _logger.info("Processing window %s ...", window_id)
        # 3a. Fuse SVD + text vectors; a failure is recorded, not fatal.
        try:
            result = fusion_pipeline.fuse_for_window(window_id, db_path=db_path)
            fusion_summary[window_id] = result
            _logger.info("  fuse_for_window(%s) -> %s", window_id, result)
        except Exception:
            _logger.exception("  fuse_for_window failed for %s", window_id)
            fusion_summary[window_id] = {"error": True}
        # 3b. Compute similarities; -1 marks a failed window.
        try:
            inserted = similarity_compute.compute_similarities(
                vector_type="fused",
                window_id=window_id,
                db_path=db_path,
            )
            similarity_summary[window_id] = inserted
            _logger.info("  compute_similarities(%s) -> %d rows", window_id, inserted)
        except Exception:
            _logger.exception("  compute_similarities failed for %s", window_id)
            similarity_summary[window_id] = -1
    _logger.info("Finished rerun_embeddings for %s", db_path)
    return {
        "cleared_rows": cleared,
        "embeddings_stored": stored,
        "embeddings_skipped_no_text": skipped_no_text,
        "embeddings_errors": emb_errors,
        "embeddings_failed_ids": failed_ids,
        "windows_processed": len(windows),
        "fusion_summary": fusion_summary,
        "similarity_summary": similarity_summary,
    }
def _main():
    """CLI entry point: parse args, run the full rerun, print a summary."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    parser = argparse.ArgumentParser(
        description="Re-run embeddings, fusion, similarity"
    )
    parser.add_argument("--db-path", required=True, help="Path to motions.db")
    parser.add_argument(
        "--model",
        default=None,
        help="Embedding model name (default: text_pipeline default)",
    )
    # Previously rerun_embeddings(retry_missing=...) existed but was
    # unreachable from the CLI; expose it as an opt-in flag.
    parser.add_argument(
        "--retry-missing",
        action="store_true",
        help="Retry failed embeddings with smaller batch sizes",
    )
    args = parser.parse_args()
    summary = rerun_embeddings(
        args.db_path, model=args.model, retry_missing=args.retry_missing
    )
    print(f"cleared_rows: {summary['cleared_rows']}")
    print(f"embeddings_stored: {summary['embeddings_stored']}")
    print(f"embeddings_skipped_no_text: {summary['embeddings_skipped_no_text']}")
    print(f"embeddings_errors: {summary['embeddings_errors']}")
    print(f"windows_processed: {summary['windows_processed']}")


if __name__ == "__main__":
    _main()