"""Generate and store text embeddings for motions that are missing them."""

import logging
import json
from typing import Optional, List, Tuple

import duckdb

from database import MotionDatabase, db as default_db
import ai_provider

_logger = logging.getLogger(__name__)

DEFAULT_MODEL = "qwen/qwen3-embedding-4b"


def _select_text(
    db: MotionDatabase, model: str, limit: Optional[int] = None
) -> List[Tuple[int, Optional[str]]]:
    """Select motions that do not yet have an embedding for `model`.

    Args:
        db: Database wrapper; only its `db_path` is used here.
        model: Embedding model identifier to check against the
            `embeddings` table.
        limit: Optional row cap. Falsy values (including 0) mean
            "no limit" — kept as-is for backward compatibility.

    Returns:
        List of (motion_id, text) tuples. `text` is None when the motion
        has no usable (non-empty) text; selection failures degrade to an
        empty list (best-effort, logged).
    """
    conn = duckdb.connect(db.db_path)
    params: List[object] = [model]
    # Prefer layman_explanation > description > title (keep compatibility
    # with existing tests).
    sql = (
        "SELECT m.id, COALESCE(m.layman_explanation, m.description, m.title) AS text"
        " FROM motions m"
        " LEFT JOIN embeddings e ON e.motion_id = m.id AND e.model = ?"
        " WHERE e.id IS NULL"
    )
    if limit:
        sql += " LIMIT ?"
        params.append(limit)
    try:
        rows = conn.execute(sql, params).fetchall()
        results: List[Tuple[int, Optional[str]]] = []
        for row in rows:
            text_val = row[1]
            # Treat empty / whitespace-only strings the same as NULL.
            text = None if text_val is None else (str(text_val).strip() or None)
            results.append((int(row[0]), text))
        return results
    except Exception as exc:
        # Best-effort: a failed selection means "nothing to process".
        _logger.error("Error selecting motions for embeddings: %s", exc)
        return []
    finally:
        # Single guaranteed close (the original closed on two separate
        # paths and could leak on an unexpected unwind).
        conn.close()


def ensure_text_embeddings(
    db_path: Optional[str] = None, model: Optional[str] = None, batch_size: int = 50
) -> Tuple[int, int, int, int]:
    """Ensure all motions have text embeddings for `model`.

    Uses batched API calls (batch_size texts per HTTP request) for speed.

    Args:
        db_path: Optional path to a database file; when omitted the shared
            default database instance is used.
        model: Embedding model name; defaults to DEFAULT_MODEL.
        batch_size: Number of texts sent per embedding request.

    Returns:
        Tuple (stored_count, skipped_existing, skipped_no_text, errors).
    """
    model = model or DEFAULT_MODEL
    db = MotionDatabase(db_path) if db_path else default_db

    # Motions still missing an embedding for this model.
    to_process = _select_text(db, model)

    # Progress-reporting counts; individual query failures are non-fatal.
    conn = duckdb.connect(db.db_path)
    try:
        try:
            total_motions = conn.execute("SELECT COUNT(*) FROM motions").fetchone()[0]
        except Exception:
            total_motions = 0
        try:
            existing = conn.execute(
                "SELECT COUNT(DISTINCT motion_id) FROM embeddings WHERE model = ?",
                (model,),
            ).fetchone()[0]
        except Exception:
            existing = 0
    finally:
        # Guaranteed cleanup — the original close could be skipped if an
        # unexpected error escaped between connect and close.
        conn.close()

    stored = 0
    skipped_no_text = 0
    errors = 0

    # Separate motions with usable text from those without.
    with_text: List[Tuple[int, str]] = []
    for motion_id, text in to_process:
        if not text:
            _logger.info("Skipping motion %s: no text available", motion_id)
            skipped_no_text += 1
        else:
            with_text.append((motion_id, text))

    _logger.info(
        "Processing %d motions in batches of %d (%d skipped no text, %d already exist)",
        len(with_text),
        batch_size,
        skipped_no_text,
        existing,
    )

    # Process in batches: one provider request per `batch_size` texts.
    for batch_start in range(0, len(with_text), batch_size):
        batch = with_text[batch_start : batch_start + batch_size]
        batch_ids = [mid for mid, _ in batch]
        batch_texts = [txt for _, txt in batch]

        try:
            vecs = ai_provider.get_embeddings_batch(
                batch_texts, model=model, batch_size=batch_size
            )
        except Exception as exc:
            # One failed request marks the whole batch as errored and moves on.
            _logger.error(
                "Batch embedding failed for motions %s..%s: %s",
                batch_ids[0],
                batch_ids[-1],
                exc,
            )
            errors += len(batch)
            continue

        if len(vecs) != len(batch):
            # Provider returned a different number of vectors than texts sent;
            # we cannot safely pair them, so count the batch as errors.
            _logger.error(
                "Batch size mismatch: expected %d, got %d embeddings",
                len(batch),
                len(vecs),
            )
            errors += len(batch)
            continue

        batch_stored = 0
        for (motion_id, _text), vec in zip(batch, vecs):
            if not isinstance(vec, list):
                _logger.warning(
                    "Embedding provider returned non-list for motion %s", motion_id
                )
                errors += 1
                continue
            try:
                res = db.store_embedding(motion_id, model, vec)
                # store_embedding signals success with a positive value.
                if res and res > 0:
                    stored += 1
                    batch_stored += 1
                else:
                    _logger.error(
                        "Failed to store embedding for motion %s (store returned %s)",
                        motion_id,
                        res,
                    )
                    errors += 1
            except Exception as exc:
                _logger.error(
                    "Error storing embedding for motion %s: %s", motion_id, exc
                )
                errors += 1

        _logger.info(
            "Batch %d-%d: stored %d/%d (total: %d/%d)",
            batch_start,
            batch_start + len(batch),
            batch_stored,
            len(batch),
            stored + existing,
            total_motions,
        )

    skipped_existing = int(existing)
    return stored, skipped_existing, skipped_no_text, errors