"""Re-run text embeddings, fusion, and similarity for all windows.
|
|
|
|
Clears stale embeddings, re-embeds all motions with available text,
|
|
then fuses SVD + text vectors and rebuilds similarity cache for every
|
|
window that has SVD vectors in the database.
|
|
|
|
Usage:
|
|
.venv/bin/python scripts/rerun_embeddings.py --db-path data/motions.db
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
|
|
try:
|
|
import duckdb
|
|
except Exception:
|
|
duckdb = None
|
|
|
|
from pipeline import text_pipeline
|
|
import importlib
|
|
|
|
# If duckdb is not present at import time (test environments), avoid hard failure
|
|
try:
|
|
importlib.import_module("duckdb")
|
|
except Exception:
|
|
pass
|
|
from pipeline import fusion as fusion_pipeline
|
|
from similarity import compute as similarity_compute
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_all_windows(db_path: str):
|
|
"""Return all distinct window_ids that have SVD vectors."""
|
|
try:
|
|
conn = duckdb.connect(db_path, read_only=True)
|
|
except Exception:
|
|
_logger.exception(
|
|
"Unable to connect to duckdb for _get_all_windows(%s)", db_path
|
|
)
|
|
return []
|
|
|
|
try:
|
|
rows = conn.execute(
|
|
"SELECT DISTINCT window_id FROM svd_vectors ORDER BY window_id"
|
|
).fetchall()
|
|
return [r[0] for r in rows]
|
|
except Exception:
|
|
_logger.exception("Error querying windows from %s", db_path)
|
|
return []
|
|
finally:
|
|
try:
|
|
conn.close()
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _clear_embeddings(db_path: str) -> int:
|
|
"""Delete all rows from embeddings, fused_embeddings, and similarity_cache."""
|
|
try:
|
|
conn = duckdb.connect(db_path)
|
|
except Exception:
|
|
_logger.exception(
|
|
"Unable to connect to duckdb for _clear_embeddings(%s)", db_path
|
|
)
|
|
return 0
|
|
|
|
try:
|
|
emb = conn.execute("DELETE FROM embeddings").rowcount or 0
|
|
fused = conn.execute("DELETE FROM fused_embeddings").rowcount or 0
|
|
sim = conn.execute("DELETE FROM similarity_cache").rowcount or 0
|
|
conn.commit()
|
|
_logger.info(
|
|
"Cleared: %d embeddings, %d fused_embeddings, %d similarity_cache rows",
|
|
emb,
|
|
fused,
|
|
sim,
|
|
)
|
|
return emb + fused + sim
|
|
except Exception:
|
|
_logger.exception("Error clearing embeddings in %s", db_path)
|
|
return 0
|
|
finally:
|
|
try:
|
|
conn.close()
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def rerun_embeddings(
|
|
db_path: str,
|
|
model: str = None,
|
|
retry_missing: bool = False,
|
|
growth_factor: float = 1.5,
|
|
) -> dict:
|
|
"""Full rerun: clear → embed → fuse → similarity for all windows.
|
|
|
|
Returns a summary dict.
|
|
"""
|
|
_logger.info("Starting rerun_embeddings for %s", db_path)
|
|
|
|
# 1. Clear stale data
|
|
cleared = _clear_embeddings(db_path)
|
|
|
|
# 2. Re-embed all motions
|
|
_logger.info("Running text embeddings ...")
|
|
# Call ensure_text_embeddings which historically returned either a 4-tuple
|
|
# (stored, skipped_existing, skipped_no_text, errors) or a 5-tuple that
|
|
# includes failed_ids as the fifth element. Support both shapes for
|
|
# backward-compatibility.
|
|
result = text_pipeline.ensure_text_embeddings(
|
|
db_path=db_path, model=model, growth_factor=growth_factor
|
|
)
|
|
if isinstance(result, tuple) and len(result) == 5:
|
|
stored, skipped_existing, skipped_no_text, emb_errors, failed_ids = result
|
|
elif isinstance(result, tuple) and len(result) == 4:
|
|
stored, skipped_existing, skipped_no_text, emb_errors = result
|
|
failed_ids = []
|
|
else:
|
|
# Fallback: try to unpack defensively
|
|
try:
|
|
stored, skipped_existing, skipped_no_text, emb_errors, failed_ids = result
|
|
except Exception:
|
|
_logger.error(
|
|
"Unexpected return shape from ensure_text_embeddings: %s", result
|
|
)
|
|
stored = skipped_existing = skipped_no_text = emb_errors = 0
|
|
failed_ids = []
|
|
# Optionally retry missing failed ids with smaller batch sizes
|
|
if retry_missing and failed_ids:
|
|
try:
|
|
_logger.info(
|
|
"Retrying %d failed embeddings with smaller batches", len(failed_ids)
|
|
)
|
|
# prefer a helper that can process only specific ids if available
|
|
if hasattr(text_pipeline, "ensure_text_embeddings_for_ids"):
|
|
text_pipeline.ensure_text_embeddings_for_ids(
|
|
db_path=db_path, ids=failed_ids, model=model, batch_size=max(1, 20)
|
|
)
|
|
else:
|
|
# best-effort: call ensure_text_embeddings and let implementation handle limiting
|
|
text_pipeline.ensure_text_embeddings(
|
|
db_path=db_path, model=model, batch_size=max(1, 20)
|
|
)
|
|
except Exception:
|
|
_logger.exception("Retrying missing embeddings failed")
|
|
_logger.info(
|
|
"Text embeddings: stored=%d, skipped_existing=%d, skipped_no_text=%d, errors=%d",
|
|
stored,
|
|
skipped_existing,
|
|
skipped_no_text,
|
|
emb_errors,
|
|
)
|
|
|
|
# 3. Get all windows with SVD vectors
|
|
windows = _get_all_windows(db_path)
|
|
_logger.info("Found %d windows with SVD vectors: %s", len(windows), windows)
|
|
|
|
fusion_summary = {}
|
|
similarity_summary = {}
|
|
|
|
for window_id in windows:
|
|
_logger.info("Processing window %s ...", window_id)
|
|
|
|
# 3a. Fuse
|
|
try:
|
|
result = fusion_pipeline.fuse_for_window(window_id, db_path=db_path)
|
|
fusion_summary[window_id] = result
|
|
_logger.info(" fuse_for_window(%s) -> %s", window_id, result)
|
|
except Exception:
|
|
_logger.exception(" fuse_for_window failed for %s", window_id)
|
|
fusion_summary[window_id] = {"error": True}
|
|
|
|
# 3b. Compute similarities
|
|
try:
|
|
inserted = similarity_compute.compute_similarities(
|
|
vector_type="fused",
|
|
window_id=window_id,
|
|
db_path=db_path,
|
|
)
|
|
similarity_summary[window_id] = inserted
|
|
_logger.info(" compute_similarities(%s) -> %d rows", window_id, inserted)
|
|
except Exception:
|
|
_logger.exception(" compute_similarities failed for %s", window_id)
|
|
similarity_summary[window_id] = -1
|
|
|
|
_logger.info("Finished rerun_embeddings for %s", db_path)
|
|
|
|
return {
|
|
"cleared_rows": cleared,
|
|
"embeddings_stored": stored,
|
|
"embeddings_skipped_no_text": skipped_no_text,
|
|
"embeddings_errors": emb_errors,
|
|
"embeddings_failed_ids": failed_ids,
|
|
"windows_processed": len(windows),
|
|
"fusion_summary": fusion_summary,
|
|
"similarity_summary": similarity_summary,
|
|
}
|
|
|
|
|
|
def _main():
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
|
)
|
|
parser = argparse.ArgumentParser(
|
|
description="Re-run embeddings, fusion, similarity"
|
|
)
|
|
parser.add_argument("--db-path", required=True, help="Path to motions.db")
|
|
parser.add_argument(
|
|
"--model",
|
|
default=None,
|
|
help="Embedding model name (default: text_pipeline default)",
|
|
)
|
|
parser.add_argument(
|
|
"--growth-factor",
|
|
type=float,
|
|
default=1.5,
|
|
help="AIMD growth factor for batch-size tuning (default: 1.5)",
|
|
)
|
|
args = parser.parse_args()
|
|
summary = rerun_embeddings(
|
|
args.db_path, model=args.model, growth_factor=args.growth_factor
|
|
)
|
|
print(f"cleared_rows: {summary['cleared_rows']}")
|
|
print(f"embeddings_stored: {summary['embeddings_stored']}")
|
|
print(f"embeddings_skipped_no_text: {summary['embeddings_skipped_no_text']}")
|
|
print(f"embeddings_errors: {summary['embeddings_errors']}")
|
|
print(f"windows_processed: {summary['windows_processed']}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
_main()
|
|
|