You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
motief/ai_provider.py

289 lines
9.8 KiB

"""Thin AI provider adapter for OpenRouter-compatible backends.
Provides simple helpers for embeddings and chat completions using requests.
This module is intentionally small and dependency-light to make testing easy.
"""
from __future__ import annotations
import os
import time
import random
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any
import requests
class ProviderError(Exception):
    """Raised for terminal provider failures.

    Covers configuration problems (e.g. missing API key) and provider
    errors that should not be retried.
    """
def _get_base_url() -> str:
# Support multiple env var names and fall back to OpenRouter default
return os.environ.get(
"OPENROUTER_URL",
os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
)
def _get_api_key() -> str:
# Accept several common env var names for convenience
for name in ("OPENROUTER_API_KEY", "OPENROUTER_KEY", "OPENAI_API_KEY", "API_KEY"):
key = os.environ.get(name)
if key:
return key
raise ProviderError(
"OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is required"
)
def _post_with_retries(
path: str, json: dict[str, Any], retries: int = 3
) -> requests.Response:
"""POST to the provider with a small retry/backoff for transient errors.
Retries on network errors (requests.ConnectionError) and 5xx responses.
"""
url = _get_base_url().rstrip("/") + path
headers = {
"Authorization": f"Bearer {_get_api_key()}",
"Content-Type": "application/json",
}
backoff = 0.5
for attempt in range(1, retries + 1):
try:
resp = requests.post(url, json=json, headers=headers, timeout=10)
except requests.ConnectionError as exc:
if attempt == retries:
raise ProviderError(
f"Connection error when calling provider: {exc}"
) from exc
sleep = backoff * (2 ** (attempt - 1))
sleep = sleep + random.uniform(0, sleep * 0.1)
time.sleep(sleep)
continue
# Treat 429 (Too Many Requests) as transient and respect Retry-After when present
if getattr(resp, "status_code", 0) == 429:
if attempt == retries:
raise ProviderError(f"Provider returned HTTP {resp.status_code}")
retry_after = None
# headers are case-insensitive mapping on requests' Response
raw = (
resp.headers.get("Retry-After")
if getattr(resp, "headers", None)
else None
)
if raw:
# Try integer seconds first, then HTTP-date
try:
retry_after = int(raw)
except Exception:
try:
dt = parsedate_to_datetime(raw)
now = datetime.now(tz=dt.tzinfo or timezone.utc)
secs = (dt - now).total_seconds()
retry_after = max(0, int(secs))
except Exception:
retry_after = None
if retry_after is not None:
time.sleep(retry_after)
continue
# fallback to exponential backoff when Retry-After missing/invalid
sleep = backoff * (2 ** (attempt - 1))
sleep = sleep + random.uniform(0, sleep * 0.1)
time.sleep(sleep)
continue
# Treat 5xx as transient
status = getattr(resp, "status_code", 0)
if 500 <= status < 600:
if attempt == retries:
raise ProviderError(f"Provider returned HTTP {resp.status_code}")
sleep = backoff * (2 ** (attempt - 1))
sleep = sleep + random.uniform(0, sleep * 0.1)
time.sleep(sleep)
continue
return resp
# Should not reach here
raise ProviderError("Failed to call provider after retries")
def get_embedding(text: str, model: str | None = None) -> list[float]:
"""Return an embedding vector for `text` using the configured provider.
Raises ProviderError for configuration or provider-side failures.
"""
if not isinstance(text, str):
raise ProviderError("text must be a string")
# Resolve model: prefer explicit arg, then env vars, then sensible Qwen default
if model is None:
model = (
os.environ.get("EMBEDDING_MODEL")
or os.environ.get("QWEN_EMBEDDING_MODEL")
or "qwen/qwen3-embedding-4b"
)
resp = _post_with_retries("/embeddings", json={"model": model, "input": text})
try:
data = resp.json()
except Exception as exc:
raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
# Expecting {"data": [{"embedding": [...]}, ...]}
try:
embedding = data["data"][0]["embedding"]
except Exception as exc:
# If provider returns an error JSON, allow a local fallback when explicitly enabled
fallback = os.environ.get("ALLOW_LOCAL_EMBED_FALLBACK", "false").lower() in (
"1",
"true",
"yes",
)
if fallback:
# choose fallback dim via env or default
dim = int(os.environ.get("LOCAL_EMBED_DIM", "64"))
return _local_embedding(text, dim=dim)
raise ProviderError(f"Unexpected embedding response shape: {data}") from exc
if not isinstance(embedding, list):
raise ProviderError("Embedding is not a list")
return [float(x) for x in embedding]
def get_embeddings_batch(
texts: list[str], model: str | None = None, batch_size: int = 50
) -> list[list[float]]:
"""Return embedding vectors for multiple texts using batched API calls.
The OpenAI/OpenRouter /embeddings endpoint accepts an array of inputs.
This sends texts in chunks of `batch_size` and returns one embedding per input,
preserving order. Raises ProviderError on failure.
"""
if not texts:
return []
if model is None:
model = (
os.environ.get("EMBEDDING_MODEL")
or os.environ.get("QWEN_EMBEDDING_MODEL")
or "qwen/qwen3-embedding-4b"
)
all_embeddings: list[list[float]] = []
for start in range(0, len(texts), batch_size):
chunk = texts[start : start + batch_size]
resp = _post_with_retries("/embeddings", json={"model": model, "input": chunk})
try:
data = resp.json()
except Exception as exc:
raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
try:
items = data["data"]
except Exception as exc:
# Check local fallback
fallback = os.environ.get(
"ALLOW_LOCAL_EMBED_FALLBACK", "false"
).lower() in ("1", "true", "yes")
if fallback:
dim = int(os.environ.get("LOCAL_EMBED_DIM", "64"))
all_embeddings.extend(_local_embedding(t, dim=dim) for t in chunk)
continue
raise ProviderError(
f"Unexpected batch embedding response shape: {data}"
) from exc
# Sort by index to guarantee order (API spec says index field is present)
items_sorted = sorted(items, key=lambda x: x.get("index", 0))
if len(items_sorted) != len(chunk):
raise ProviderError(
f"Expected {len(chunk)} embeddings, got {len(items_sorted)}"
)
for item in items_sorted:
emb = item.get("embedding")
if not isinstance(emb, list):
raise ProviderError(
f"Embedding at index {item.get('index')} is not a list"
)
all_embeddings.append([float(x) for x in emb])
return all_embeddings
def _local_embedding(text: str, dim: int = 64) -> list[float]:
"""Deterministic local fallback embedding based on SHA256.
Returns a list of `dim` floats in range [-1, 1]. Not semantically rich but useful
for local testing when provider embeddings are unavailable.
"""
import hashlib
h = hashlib.sha256(text.encode("utf8")).digest()
values = []
i = 0
# Expand digest if needed
while len(values) < dim:
# take 8 bytes -> 64-bit int
chunk = h[i % len(h) : (i % len(h)) + 8]
if len(chunk) < 8:
chunk = chunk.ljust(8, b"\0")
val = int.from_bytes(chunk, "big", signed=False)
# normalize to [-1,1]
valscale = (val / (2**64 - 1)) * 2.0 - 1.0
values.append(valscale)
i += 1
# re-hash occasionally to get more entropy
if i % (len(h) // 2 + 1) == 0:
h = hashlib.sha256(h + chunk).digest()
return values[:dim]
def chat_completion(messages: list[dict], model: str | None = None) -> str:
"""Return the assistant's content string for a chat completion request.
messages should be a list of dicts like {"role": "user", "content": "..."}.
"""
if not isinstance(messages, list):
raise ProviderError("messages must be a list of dicts")
# Resolve chat model: prefer explicit arg, then env var QWEN_MODEL, then a sensible default
if model is None:
model = (
os.environ.get("QWEN_MODEL")
or os.environ.get("CHAT_MODEL")
or "qwen/qwen-3.2"
)
resp = _post_with_retries(
"/chat/completions", json={"model": model, "messages": messages}
)
try:
data = resp.json()
except Exception as exc:
raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
# Expecting {"choices": [{"message": {"content": "..."}}]}
try:
content = data["choices"][0]["message"]["content"]
except Exception as exc:
raise ProviderError(
f"Unexpected chat completion response shape: {data}"
) from exc
return str(content)