You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/pipeline/text_pipeline.py

175 lines
5.2 KiB

import logging
import json
from typing import Optional, List, Tuple
import duckdb
from database import MotionDatabase, db as default_db
import ai_provider
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Embedding model used by ensure_text_embeddings() when the caller passes none.
DEFAULT_MODEL = "qwen/qwen3-embedding-4b"
def _select_text(
    db: MotionDatabase, model: str, limit: Optional[int] = None
) -> List[Tuple[int, Optional[str]]]:
    """Select motions that do not yet have an embedding for `model`.

    The text of a motion is the first non-blank value among
    ``layman_explanation``, ``description`` and ``title`` (in that order).
    Values are stripped of surrounding whitespace; NULL, empty and
    whitespace-only values are skipped so that a blank higher-priority
    column does not hide usable text in a lower-priority one.

    Args:
        db: Database handle; only its ``db_path`` attribute is used here.
        model: Embedding model name used to detect existing embeddings.
        limit: Maximum number of motions to return. ``None`` (or any falsy
            value such as ``0``) means "no limit".

    Returns:
        List of ``(motion_id, text)`` tuples, where ``text`` is ``None`` when
        the motion has no usable text. If the query fails the error is logged
        and an empty list is returned.
    """
    conn = duckdb.connect(db.db_path)
    params: list = [model]
    # The fallback between the text columns is done in Python rather than with
    # SQL COALESCE: COALESCE only skips NULL, so an empty-string
    # layman_explanation would otherwise mask a perfectly good description.
    sql = (
        "SELECT m.id, m.layman_explanation, m.description, m.title"
        " FROM motions m"
        " LEFT JOIN embeddings e ON e.motion_id = m.id AND e.model = ?"
        " WHERE e.id IS NULL"
    )
    if limit:
        sql += " LIMIT ?"
        params.append(limit)
    try:
        rows = conn.execute(sql, params).fetchall()
        results: List[Tuple[int, Optional[str]]] = []
        for row in rows:
            text: Optional[str] = None
            # row[1:] holds the candidate texts in priority order.
            for candidate in row[1:]:
                if candidate is None:
                    continue
                stripped = str(candidate).strip()
                if stripped:
                    text = stripped
                    break
            results.append((int(row[0]), text))
        return results
    except Exception as exc:
        _logger.error("Error selecting motions for embeddings: %s", exc)
        return []
    finally:
        # Single cleanup point for both the success and the failure path.
        try:
            conn.close()
        except Exception as exc:
            # Best effort: a failing close must not discard the query result.
            _logger.debug("Ignoring error while closing connection: %s", exc)
def _store_batch_embeddings(
    db: "MotionDatabase",
    model: str,
    batch: List[Tuple[int, str]],
    vecs: list,
) -> Tuple[int, int]:
    """Store one batch of embeddings and return ``(stored, errors)``."""
    stored = 0
    errors = 0
    for (motion_id, _text), vec in zip(batch, vecs):
        if not isinstance(vec, list):
            _logger.warning(
                "Embedding provider returned non-list for motion %s", motion_id
            )
            errors += 1
            continue
        try:
            res = db.store_embedding(motion_id, model, vec)
            if res and res > 0:
                stored += 1
            else:
                _logger.error(
                    "Failed to store embedding for motion %s (store returned %s)",
                    motion_id,
                    res,
                )
                errors += 1
        except Exception as exc:
            _logger.error(
                "Error storing embedding for motion %s: %s", motion_id, exc
            )
            errors += 1
    return stored, errors


def ensure_text_embeddings(
    db_path: Optional[str] = None, model: Optional[str] = None, batch_size: int = 50
) -> Tuple[int, int, int, int]:
    """Ensure all motions have text embeddings for `model`.

    Motions without an embedding for `model` are selected, those without any
    usable text are skipped, and the rest are embedded in batches
    (``batch_size`` texts per provider call) and stored one by one. A failing
    batch or a failing store is logged and counted as an error; processing
    continues with the next batch/motion.

    Args:
        db_path: Path of the database to use. When falsy, the module's
            default database object is used.
        model: Embedding model name. Defaults to ``DEFAULT_MODEL``.
        batch_size: Number of texts sent per embedding request; must be >= 1.

    Returns:
        Tuple ``(stored_count, skipped_existing, skipped_no_text, errors)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate up front: a zero step makes range() raise only after the DB
    # work is done, and a negative step would silently process nothing.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    model = model or DEFAULT_MODEL
    db = MotionDatabase(db_path) if db_path else default_db

    # Motions that still need an embedding for this model.
    to_process = _select_text(db, model)

    # Totals are used for progress logging and for the skipped_existing
    # return value; a failing count falls back to 0 instead of aborting.
    conn = duckdb.connect(db.db_path)
    try:
        try:
            total_motions = conn.execute("SELECT COUNT(*) FROM motions").fetchone()[0]
        except Exception as exc:
            _logger.warning("Could not count motions: %s", exc)
            total_motions = 0
        try:
            existing = conn.execute(
                "SELECT COUNT(DISTINCT motion_id) FROM embeddings WHERE model = ?",
                (model,),
            ).fetchone()[0]
        except Exception as exc:
            _logger.warning("Could not count existing embeddings: %s", exc)
            existing = 0
    finally:
        conn.close()

    stored = 0
    skipped_no_text = 0
    errors = 0

    # Separate motions with text from those without.
    with_text: List[Tuple[int, str]] = []
    for motion_id, text in to_process:
        if not text:
            _logger.info("Skipping motion %s: no text available", motion_id)
            skipped_no_text += 1
        else:
            with_text.append((motion_id, text))

    _logger.info(
        "Processing %d motions in batches of %d (%d skipped no text, %d already exist)",
        len(with_text),
        batch_size,
        skipped_no_text,
        existing,
    )

    for batch_start in range(0, len(with_text), batch_size):
        batch = with_text[batch_start : batch_start + batch_size]
        batch_ids = [mid for mid, _ in batch]
        batch_texts = [txt for _, txt in batch]
        try:
            vecs = ai_provider.get_embeddings_batch(
                batch_texts, model=model, batch_size=batch_size
            )
        except Exception as exc:
            _logger.error(
                "Batch embedding failed for motions %s..%s: %s",
                batch_ids[0],
                batch_ids[-1],
                exc,
            )
            errors += len(batch)
            continue

        # Without a 1:1 mapping the vectors cannot be attributed to motions,
        # so the whole batch is counted as failed.
        if len(vecs) != len(batch):
            _logger.error(
                "Batch size mismatch: expected %d, got %d embeddings",
                len(batch),
                len(vecs),
            )
            errors += len(batch)
            continue

        batch_stored, batch_errors = _store_batch_embeddings(db, model, batch, vecs)
        stored += batch_stored
        errors += batch_errors
        _logger.info(
            "Batch %d-%d: stored %d/%d (total: %d/%d)",
            batch_start,
            batch_start + len(batch),
            batch_stored,
            len(batch),
            stored + existing,
            total_motions,
        )

    skipped_existing = int(existing)
    return stored, skipped_existing, skipped_no_text, errors