"""Re-run text embeddings, fusion, and similarity for all windows.

Clears stale embeddings, re-embeds all motions with available text, then fuses
SVD + text vectors and rebuilds similarity cache for every window that has SVD
vectors in the database.

Usage:
    .venv/bin/python scripts/rerun_embeddings.py --db-path data/motions.db
"""

import argparse
import importlib
import logging
from typing import Any, Optional

try:
    import duckdb
except Exception:
    duckdb = None

from pipeline import text_pipeline

# If duckdb is not present at import time (test environments), avoid hard failure
try:
    importlib.import_module("duckdb")
except Exception:
    pass

from pipeline import fusion as fusion_pipeline
from similarity import compute as similarity_compute

_logger = logging.getLogger(__name__)

# Batch size used when retrying motions whose embedding failed on the first pass.
_RETRY_BATCH_SIZE = 20


def _close_quietly(conn: Any) -> None:
    """Close *conn*, never raising.

    Closing is best-effort cleanup; a failure here must not mask the result
    (or the exception) of the operation that used the connection.
    """
    try:
        conn.close()
    except Exception:
        _logger.debug("Ignoring error while closing duckdb connection", exc_info=True)


def _affected_rows(cursor: Any) -> int:
    """Return the non-negative row count reported by *cursor*.

    PEP 249 allows ``rowcount`` to be ``-1`` when the count cannot be
    determined; such values (and ``None``) are treated as 0 so that summed
    totals never go negative.
    """
    count = getattr(cursor, "rowcount", None)
    if isinstance(count, int) and count > 0:
        return count
    return 0


def _unpack_embedding_result(result: Any) -> tuple:
    """Normalize the return value of ``ensure_text_embeddings``.

    The function historically returned either a 4-tuple
    ``(stored, skipped_existing, skipped_no_text, errors)`` or a 5-tuple that
    includes ``failed_ids`` as the fifth element. Both shapes are supported for
    backward-compatibility; any other shape is logged and replaced by zeros.

    Returns:
        ``(stored, skipped_existing, skipped_no_text, errors, failed_ids)``.
    """
    try:
        values = tuple(result)
    except Exception:
        values = ()

    if len(values) == 5:
        return values
    if len(values) == 4:
        return (*values, [])

    _logger.error("Unexpected return shape from ensure_text_embeddings: %s", result)
    return 0, 0, 0, 0, []


def _get_all_windows(db_path: str) -> list:
    """Return all distinct window_ids that have SVD vectors.

    Best-effort: if the database cannot be opened or queried, the error is
    logged and an empty list is returned.
    """
    try:
        conn = duckdb.connect(db_path, read_only=True)
    except Exception:
        _logger.exception(
            "Unable to connect to duckdb for _get_all_windows(%s)", db_path
        )
        return []

    try:
        rows = conn.execute(
            "SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id"
        ).fetchall()
        return [row[0] for row in rows]
    except Exception:
        _logger.exception("Error querying windows from %s", db_path)
        return []
    finally:
        _close_quietly(conn)


def _clear_embeddings(db_path: str) -> int:
    """Delete all rows from embeddings, fused_embeddings, and similarity_cache.

    Returns:
        The total number of deleted rows as reported by the driver, or 0 when
        the database cannot be opened or any statement fails (the error is
        logged, not raised).
    """
    try:
        conn = duckdb.connect(db_path)
    except Exception:
        _logger.exception(
            "Unable to connect to duckdb for _clear_embeddings(%s)", db_path
        )
        return 0

    try:
        emb = _affected_rows(conn.execute("DELETE FROM embeddings"))
        fused = _affected_rows(conn.execute("DELETE FROM fused_embeddings"))
        sim = _affected_rows(conn.execute("DELETE FROM similarity_cache"))
        conn.commit()
        _logger.info(
            "Cleared: %d embeddings, %d fused_embeddings, %d similarity_cache rows",
            emb,
            fused,
            sim,
        )
        return emb + fused + sim
    except Exception:
        _logger.exception("Error clearing embeddings in %s", db_path)
        return 0
    finally:
        _close_quietly(conn)


def rerun_embeddings(
    db_path: str,
    model: Optional[str] = None,
    retry_missing: bool = False,
    growth_factor: float = 1.5,
) -> dict:
    """Full rerun: clear → embed → fuse → similarity for all windows.

    Args:
        db_path: Path to the DuckDB database file.
        model: Embedding model name forwarded to ``text_pipeline``; ``None``
            is passed through unchanged.
        retry_missing: When true and the embedding step reports failed ids,
            make one more embedding attempt with ``batch_size`` set to
            ``_RETRY_BATCH_SIZE``.
        growth_factor: Forwarded to ``text_pipeline.ensure_text_embeddings``.

    Returns:
        A summary dict with the keys ``cleared_rows``, ``embeddings_stored``,
        ``embeddings_skipped_no_text``, ``embeddings_errors``,
        ``embeddings_failed_ids``, ``windows_processed``, ``fusion_summary``
        and ``similarity_summary``. Per-window failures are recorded as
        ``{"error": True}`` (fusion) and ``-1`` (similarity) rather than raised.
    """
    _logger.info("Starting rerun_embeddings for %s", db_path)

    # 1. Clear stale data
    cleared = _clear_embeddings(db_path)

    # 2. Re-embed all motions
    _logger.info("Running text embeddings ...")
    embed_result = text_pipeline.ensure_text_embeddings(
        db_path=db_path, model=model, growth_factor=growth_factor
    )
    (
        stored,
        skipped_existing,
        skipped_no_text,
        emb_errors,
        failed_ids,
    ) = _unpack_embedding_result(embed_result)

    # Optionally retry missing failed ids with smaller batch sizes.
    # NOTE(review): the retry's return value is not folded back into the
    # counters below, so the summary reflects the first pass only.
    if retry_missing and failed_ids:
        try:
            _logger.info(
                "Retrying %d failed embeddings with smaller batches", len(failed_ids)
            )
            # prefer a helper that can process only specific ids if available
            if hasattr(text_pipeline, "ensure_text_embeddings_for_ids"):
                text_pipeline.ensure_text_embeddings_for_ids(
                    db_path=db_path,
                    ids=failed_ids,
                    model=model,
                    batch_size=_RETRY_BATCH_SIZE,
                )
            else:
                # best-effort: call ensure_text_embeddings and let the
                # implementation handle limiting
                text_pipeline.ensure_text_embeddings(
                    db_path=db_path, model=model, batch_size=_RETRY_BATCH_SIZE
                )
        except Exception:
            _logger.exception("Retrying missing embeddings failed")

    _logger.info(
        "Text embeddings: stored=%d, skipped_existing=%d, skipped_no_text=%d, errors=%d",
        stored,
        skipped_existing,
        skipped_no_text,
        emb_errors,
    )

    # 3. Get all windows with SVD vectors
    windows = _get_all_windows(db_path)
    _logger.info("Found %d windows with SVD vectors: %s", len(windows), windows)

    fusion_summary = {}
    similarity_summary = {}

    for window_id in windows:
        _logger.info("Processing window %s ...", window_id)

        # 3a. Fuse
        try:
            fused = fusion_pipeline.fuse_for_window(window_id, db_path=db_path)
            fusion_summary[window_id] = fused
            _logger.info(" fuse_for_window(%s) -> %s", window_id, fused)
        except Exception:
            _logger.exception(" fuse_for_window failed for %s", window_id)
            fusion_summary[window_id] = {"error": True}

        # 3b. Compute similarities (attempted even when fusion failed)
        try:
            inserted = similarity_compute.compute_similarities(
                vector_type="fused",
                window_id=window_id,
                db_path=db_path,
            )
            similarity_summary[window_id] = inserted
            _logger.info(" compute_similarities(%s) -> %d rows", window_id, inserted)
        except Exception:
            _logger.exception(" compute_similarities failed for %s", window_id)
            similarity_summary[window_id] = -1

    _logger.info("Finished rerun_embeddings for %s", db_path)
    return {
        "cleared_rows": cleared,
        "embeddings_stored": stored,
        "embeddings_skipped_no_text": skipped_no_text,
        "embeddings_errors": emb_errors,
        "embeddings_failed_ids": failed_ids,
        "windows_processed": len(windows),
        "fusion_summary": fusion_summary,
        "similarity_summary": similarity_summary,
    }


def _main():
    """Command-line entry point: parse arguments, run the rerun, print a summary."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    parser = argparse.ArgumentParser(
        description="Re-run embeddings, fusion, similarity"
    )
    parser.add_argument("--db-path", required=True, help="Path to motions.db")
    parser.add_argument(
        "--model",
        default=None,
        help="Embedding model name (default: text_pipeline default)",
    )
    parser.add_argument(
        "--growth-factor",
        type=float,
        default=1.5,
        help="AIMD growth factor for batch-size tuning (default: 1.5)",
    )
    # Opt-in so that existing invocations behave exactly as before.
    parser.add_argument(
        "--retry-missing",
        action="store_true",
        help="Retry motions whose embedding failed, using smaller batches",
    )
    args = parser.parse_args()

    summary = rerun_embeddings(
        args.db_path,
        model=args.model,
        retry_missing=args.retry_missing,
        growth_factor=args.growth_factor,
    )
    print(f"cleared_rows: {summary['cleared_rows']}")
    print(f"embeddings_stored: {summary['embeddings_stored']}")
    print(f"embeddings_skipped_no_text: {summary['embeddings_skipped_no_text']}")
    print(f"embeddings_errors: {summary['embeddings_errors']}")
    print(f"windows_processed: {summary['windows_processed']}")


if __name__ == "__main__":
    _main()