You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
7.8 KiB
255 lines
7.8 KiB
import logging
|
|
import json
|
|
from typing import Optional, List, Tuple
|
|
|
|
try:
|
|
import duckdb
|
|
except Exception:
|
|
duckdb = None
|
|
|
|
from database import MotionDatabase, db as default_db
|
|
import pipeline.ai_provider_wrapper as ai_wrapper
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
DEFAULT_MODEL = "qwen/qwen3-embedding-4b"
|
|
|
|
|
|
def _select_text(
    db: MotionDatabase, model: str, limit: Optional[int] = None
) -> List[Tuple[int, Optional[str]]]:
    """Select motions that do not yet have an embedding for `model`.

    Args:
        db: Database wrapper providing the DuckDB file path (`db.db_path`).
        model: Embedding model identifier matched against `embeddings.model`.
        limit: Optional cap on the number of rows returned. A falsy value
            (None or 0) means no limit.

    Returns:
        List of (motion_id, text) tuples. `text` is the first non-null of
        layman_explanation, body_text, description, title — stripped, with
        empty strings normalized to None. Returns [] when duckdb is not
        installed or on any query error (best-effort: errors are logged,
        never raised).
    """
    if duckdb is None:
        # duckdb failed to import at module load; embeddings are disabled.
        return []

    # prefer layman_explanation > body_text > description > title
    # (body_text is the second-priority fallback so motion HTML is used
    # when available)
    sql = (
        "SELECT m.id, COALESCE(m.layman_explanation, m.body_text, m.description, m.title) AS text"
        " FROM motions m"
        " LEFT JOIN embeddings e ON e.motion_id = m.id AND e.model = ?"
        " WHERE e.id IS NULL"
    )
    params: list = [model]
    if limit:
        sql += " LIMIT ?"
        params.append(limit)

    conn = duckdb.connect(db.db_path)
    try:
        rows = conn.execute(sql, params).fetchall()
        results: List[Tuple[int, Optional[str]]] = []
        for row in rows:
            text_val = row[1]
            # treat empty/whitespace-only strings as "no text"
            if text_val is None:
                text = None
            else:
                text = str(text_val).strip() or None
            results.append((int(row[0]), text))
        return results
    except Exception as exc:
        _logger.error("Error selecting motions for embeddings: %s", exc)
        return []
    finally:
        # single close path for both success and failure (the original
        # closed in two places; `finally` guarantees it exactly once)
        try:
            conn.close()
        except Exception:
            pass
|
|
|
|
|
|
def ensure_text_embeddings(
    db_path: Optional[str] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    db=None,
    embedder=None,
) -> Tuple[int, int, int, int]:
    """Ensure all motions have text embeddings for `model`.

    Uses batched API calls (batch_size texts per HTTP request) for speed.

    Args:
        db_path: Optional path to the DuckDB file; ignored when `db` is given.
        model: Embedding model name; defaults to DEFAULT_MODEL.
        batch_size: Number of texts sent per embedding request.
        db: Optional pre-constructed database object (overrides db_path).
        embedder: Optional embedding callable passed through to the AI wrapper.

    Returns:
        Tuple (stored_count, skipped_existing, skipped_no_text, errors).
        NOTE: the annotation previously claimed a 5-tuple; the function has
        always returned 4 values (failed_ids is intentionally not returned,
        for backward compatibility with older callers).
    """
    model = model or DEFAULT_MODEL
    if db is None:
        db = MotionDatabase(db_path) if db_path else default_db

    # motions missing an embedding for this model
    to_process = _select_text(db, model)

    # count totals / already-embedded motions for progress reporting only
    if duckdb is None:
        total_motions = 0
        existing = 0
    else:
        conn = duckdb.connect(db.db_path)
        try:
            try:
                total_motions = conn.execute(
                    "SELECT COUNT(*) FROM motions"
                ).fetchone()[0]
            except Exception:
                total_motions = 0
            try:
                existing = conn.execute(
                    "SELECT COUNT(DISTINCT motion_id) FROM embeddings WHERE model = ?",
                    (model,),
                ).fetchone()[0]
            except Exception:
                existing = 0
        finally:
            # guarantee the connection is released even if fetchone raises
            # something outside the guarded queries
            conn.close()

    stored = 0
    skipped_no_text = 0
    errors = 0
    failed_ids: list = []

    # Separate motions with text from those without
    with_text: List[Tuple[int, str]] = []
    for motion_id, text in to_process:
        if not text:
            _logger.info("Skipping motion %s: no text available", motion_id)
            skipped_no_text += 1
        else:
            with_text.append((motion_id, text))

    _logger.info(
        "Processing %d motions in batches of %d (%d skipped no text, %d already exist)",
        len(with_text),
        batch_size,
        skipped_no_text,
        existing,
    )

    # Process in batches; one HTTP request per batch
    for batch_start in range(0, len(with_text), batch_size):
        batch = with_text[batch_start : batch_start + batch_size]
        batch_ids = [mid for mid, _ in batch]
        batch_texts = [txt for _, txt in batch]

        vecs = ai_wrapper.get_embeddings_with_retry(
            batch_texts,
            motion_ids=batch_ids,
            model=model,
            batch_size=batch_size,
            embedder=embedder,
        )

        batch_stored = 0
        for (motion_id, _text), vec in zip(batch, vecs):
            # provider may return None / error markers in place of a vector
            if not isinstance(vec, list):
                _logger.warning(
                    "Embedding provider returned non-list for motion %s", motion_id
                )
                errors += 1
                failed_ids.append(motion_id)
                continue

            try:
                res = db.store_embedding(motion_id, model, vec)
                if res and res > 0:
                    stored += 1
                    batch_stored += 1
                else:
                    _logger.error(
                        "Failed to store embedding for motion %s (store returned %s)",
                        motion_id,
                        res,
                    )
                    errors += 1
                    failed_ids.append(motion_id)
            except Exception as exc:
                _logger.error(
                    "Error storing embedding for motion %s: %s", motion_id, exc
                )
                errors += 1
                failed_ids.append(motion_id)

        _logger.info(
            "Batch %d-%d: stored %d/%d (total: %d/%d)",
            batch_start,
            batch_start + len(batch),
            batch_stored,
            len(batch),
            stored + existing,
            total_motions,
        )

    skipped_existing = int(existing)
    # Historically some callers expected a 4-tuple; return the primary
    # metrics (stored, skipped_existing, skipped_no_text, errors).
    # The list of failed_ids is intentionally not returned here to remain
    # backward-compatible with older callers.
    return stored, skipped_existing, skipped_no_text, errors
|
|
|
|
|
|
def ensure_text_embeddings_for_ids(
    db_path: Optional[str] = None,
    ids: Optional[list] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    db=None,
    embedder=None,
) -> Tuple[int, int, int, int, list]:
    """Ensure embeddings for a specific list of motion ids.

    This helper selects the motion texts for the supplied ids and reuses the
    same embedding logic as ensure_text_embeddings.

    Args:
        db_path: Optional path to the DuckDB file; ignored when `db` is given.
        ids: Motion ids to embed. Empty/None short-circuits to zeros.
        model: Embedding model name; defaults to DEFAULT_MODEL.
        batch_size: Number of texts sent per embedding request.
        db: Optional pre-constructed database object (overrides db_path).
        embedder: Optional embedding callable passed through to the AI wrapper.

    Returns:
        Tuple (stored, skipped_existing, skipped_no_text, errors, failed_ids).
        skipped_existing is always 0 here (existing embeddings are not
        re-checked for an explicit id list).
    """
    model = model or DEFAULT_MODEL
    if db is None:
        db = MotionDatabase(db_path) if db_path else default_db

    if not ids:
        return 0, 0, 0, 0, []

    # Fetch texts for given ids
    if duckdb is None:
        return 0, 0, 0, 0, []
    conn = duckdb.connect(db.db_path)
    try:
        placeholders = ",".join("?" for _ in ids)
        rows = conn.execute(
            f"SELECT id, COALESCE(layman_explanation, body_text, description, title) AS text FROM motions WHERE id IN ({placeholders})",
            ids,
        ).fetchall()
    finally:
        conn.close()

    # normalize: str() guards against non-string column values (the
    # original `(r[1] or "")` raised on e.g. numeric text columns);
    # whitespace-only text becomes None
    to_process = [
        (int(r[0]), (str(r[1]).strip() or None) if r[1] is not None else None)
        for r in rows
    ]

    stored = 0
    errors = 0
    failed_ids: list = []

    with_text = [(mid, txt) for mid, txt in to_process if txt]
    # BUGFIX: previously skipped_no_text was initialized but never
    # incremented, so callers always saw 0 even when motions lacked text.
    skipped_no_text = len(to_process) - len(with_text)

    for batch_start in range(0, len(with_text), batch_size):
        batch = with_text[batch_start : batch_start + batch_size]
        batch_ids = [mid for mid, _ in batch]
        batch_texts = [txt for _, txt in batch]

        vecs = ai_wrapper.get_embeddings_with_retry(
            batch_texts,
            motion_ids=batch_ids,
            model=model,
            batch_size=batch_size,
            embedder=embedder,
        )

        for (motion_id, _text), vec in zip(batch, vecs):
            if not isinstance(vec, list):
                errors += 1
                failed_ids.append(motion_id)
                continue
            # guard the store call so one failure doesn't abort the run
            # (consistent with ensure_text_embeddings)
            try:
                res = db.store_embedding(motion_id, model, vec)
            except Exception as exc:
                _logger.error(
                    "Error storing embedding for motion %s: %s", motion_id, exc
                )
                errors += 1
                failed_ids.append(motion_id)
                continue
            if res and res > 0:
                stored += 1
            else:
                errors += 1
                failed_ids.append(motion_id)

    return stored, 0, skipped_no_text, errors, failed_ids
|
|
|