You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/pipeline/ai_provider_wrapper.py

116 lines
3.7 KiB

"""Wrapper around ai_provider to provide retries and smaller-batch fallback.
Returns a list of embedding vectors aligned with inputs. For inputs that
fail permanently the corresponding list entry will be None and an audit event
is appended via database.db.append_audit_event.
"""
from __future__ import annotations
import time
import random
from typing import List, Optional
import ai_provider
from database import db as motion_db
import logging
_logger = logging.getLogger(__name__)
def get_embeddings_with_retry(
    texts: List[str],
    motion_ids: Optional[List[Optional[int]]] = None,
    model: Optional[str] = None,
    batch_size: int = 50,
    retries: int = 3,
    db=None,
    embedder=None,
) -> List[Optional[List[float]]]:
    """Return embeddings aligned with *texts*; entries are ``None`` for failures.

    Strategy:
      - Submit batches of up to ``batch_size`` texts, retrying each batch up
        to ``retries`` times with exponential backoff plus jitter.
      - On persistent batch failure, fall back to per-item attempts
        (batch size 1) so one bad input cannot poison its whole batch.
      - For items that still fail, append an ``embedding_failed`` audit event
        and leave ``None`` in the corresponding result slot.

    Args:
        texts: Input strings to embed.
        motion_ids: Optional ids aligned with ``texts``; used only as the
            ``target_id`` of audit events for failed items.
        model: Optional model name forwarded to the embedder.
        batch_size: Maximum number of texts per embedding call.
        retries: Attempts per batch before giving up (values < 1 are
            treated as a single attempt).
        db: Optional audit sink exposing ``append_audit_event``; defaults to
            ``database.db``.
        embedder: Optional callable ``(texts, model=..., batch_size=...)``;
            defaults to ``ai_provider.get_embeddings_batch``.

    Returns:
        A list the same length as ``texts``: an embedding vector per input,
        or ``None`` where embedding permanently failed.
    """
    if not texts:
        return []
    if motion_ids is None:
        motion_ids = [None for _ in texts]
    # retries < 1 would mean "never even try"; treat it as one attempt so the
    # caller still gets a real error instead of "unknown".
    attempts = max(1, retries)
    results: List[Optional[List[float]]] = [None] * len(texts)
    # Resolve the embedder at call time: prefer the injected one (tests can
    # substitute a fake), otherwise use the project-wide provider.
    _embedder = embedder if embedder is not None else ai_provider.get_embeddings_batch

    def _attempt_batch(chunk_texts, start_index):
        """Try to embed ``chunk_texts`` up to ``attempts`` times.

        Returns ``(embeddings, None)`` on success or ``(None, last_exc)`` on
        persistent failure. A result whose length does not match the input is
        treated as a failure so results are never mis-aligned with inputs.
        """
        backoff = 0.5
        last_exc = None
        for attempt in range(1, attempts + 1):
            try:
                emb_chunk = _embedder(
                    chunk_texts, model=model, batch_size=len(chunk_texts)
                )
                # Defensive: a provider returning the wrong number of vectors
                # would otherwise silently shift every later embedding.
                if emb_chunk is None or len(emb_chunk) != len(chunk_texts):
                    raise ValueError(
                        "embedder returned %s results for %d inputs"
                        % (
                            "no" if emb_chunk is None else len(emb_chunk),
                            len(chunk_texts),
                        )
                    )
                return emb_chunk, None
            except Exception as exc:
                last_exc = exc
                if attempt == attempts:
                    break
                # Exponential backoff with up to 10% jitter so concurrent
                # workers don't retry in lockstep against the provider.
                sleep = backoff * (2 ** (attempt - 1))
                sleep = sleep + random.uniform(0, sleep * 0.1)
                _logger.debug(
                    "Batch embedding attempt %d failed, retrying after %.2fs: %s",
                    attempt,
                    sleep,
                    exc,
                )
                time.sleep(sleep)
        # persistent failure
        _logger.warning(
            "Batch embedding failed for texts starting at %d: %s", start_index, last_exc
        )
        return None, last_exc

    # Process in fixed-size batches; on persistent batch failure fall back to
    # per-item attempts for that batch only.
    i = 0
    n = len(texts)
    while i < n:
        end = min(n, i + batch_size)
        chunk = texts[i:end]
        emb_chunk, _ = _attempt_batch(chunk, i)
        if emb_chunk is not None:
            # success: lengths are validated inside _attempt_batch, so this
            # assignment cannot mis-align.
            for j, emb in enumerate(emb_chunk):
                results[i + j] = emb
            i = end
            continue
        # batch failed -> fallback to per-item attempts
        for j in range(i, end):
            mid = motion_ids[j] if j < len(motion_ids) else None
            single, single_exc = _attempt_batch([texts[j]], j)
            # `is not None` (not truthiness): a legitimate falsy embedding
            # such as [] must still count as success.
            if single is not None:
                results[j] = single[0]
                continue
            # Permanent failure for this item: leave the slot as None and
            # record an audit event. Audit writes are best-effort — a broken
            # audit sink must not abort the embedding run.
            err_text = repr(single_exc) if single_exc is not None else "unknown"
            try:
                _db = db if db is not None else motion_db
                _db.append_audit_event(
                    actor_id=None,
                    action="embedding_failed",
                    target_type="motion",
                    target_id=str(mid) if mid is not None else None,
                    metadata={"error": err_text},
                )
            except Exception:
                _logger.exception("Failed to append audit event for embedding failure")
        i = end
    return results