You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
motief/ai_provider.py

375 lines
13 KiB

"""Thin AI provider adapter for OpenRouter-compatible backends.
Provides simple helpers for embeddings and chat completions using requests.
This module is intentionally small and dependency-light to make testing easy.
"""
from __future__ import annotations
import os
import time
import random
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any
import requests
class ProviderError(Exception):
    """Unrecoverable failure while talking to the AI provider.

    Raised for configuration problems (such as a missing API key), for
    transient failures that persisted through every retry, and for responses
    this module cannot interpret. It is not meant to be retried by callers.
    """
def _get_base_url() -> str:
# Support multiple env var names and fall back to OpenRouter default
return os.environ.get(
"OPENROUTER_URL",
os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
)
def _get_api_key() -> str:
# Accept several common env var names for convenience
for name in ("OPENROUTER_API_KEY", "OPENROUTER_KEY", "OPENAI_API_KEY", "API_KEY"):
key = os.environ.get(name)
if key:
return key
raise ProviderError(
"OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is required"
)
def _backoff_delay(attempt: int, base: float = 0.5) -> float:
    """Return the exponential backoff delay for the 1-based ``attempt``.

    The delay is ``base * 2**(attempt - 1)`` plus up to 10% random jitter so
    that concurrent clients do not retry in lockstep.
    """
    delay = base * (2 ** (attempt - 1))
    return delay + random.uniform(0, delay * 0.1)


def _parse_retry_after(raw: str | None) -> int | None:
    """Parse a ``Retry-After`` header value into non-negative whole seconds.

    Accepts either delta-seconds (``"120"``) or an HTTP-date. Returns ``None``
    when the value is missing or unparseable so the caller can fall back to
    exponential backoff. Negative values and dates in the past map to 0.
    """
    if not raw:
        return None
    try:
        # Clamp: time.sleep() raises ValueError on negative durations.
        return max(0, int(raw))
    except (TypeError, ValueError):
        pass
    try:
        when = parsedate_to_datetime(raw)
    except (TypeError, ValueError, IndexError, OverflowError):
        # Depending on the Python version, unparseable dates surface as
        # TypeError or ValueError.
        return None
    if when.tzinfo is None:
        # A "-0000" zone parses to a naive datetime; HTTP-dates are GMT, so
        # treat it as UTC instead of failing on naive/aware subtraction.
        when = when.replace(tzinfo=timezone.utc)
    seconds = (when - datetime.now(tz=timezone.utc)).total_seconds()
    return max(0, int(seconds))


def _post_with_retries(
    path: str, json: dict[str, Any], retries: int = 3
) -> requests.Response:
    """POST ``json`` to ``path`` on the provider, retrying transient failures.

    ``retries`` is the total number of attempts; values below 1 are treated as
    1 so at least one request is always made. Retried conditions:

    * network failures: ``requests.ConnectionError`` and ``requests.Timeout``
      (each request uses a 60 s timeout);
    * HTTP 429, sleeping for the ``Retry-After`` header when it is usable;
    * HTTP 5xx.

    Between attempts it sleeps with exponential backoff (0.5 s, 1 s, 2 s, ...)
    plus jitter. Any other response, including 2xx and non-429 4xx, is returned
    unchanged for the caller to interpret.

    Raises:
        ProviderError: the API key is missing, or the final attempt still hit
            one of the retried conditions above.
    """
    url = _get_base_url().rstrip("/") + path
    headers = {
        "Authorization": f"Bearer {_get_api_key()}",
        "Content-Type": "application/json",
    }
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        is_last = attempt == attempts
        try:
            resp = requests.post(url, json=json, headers=headers, timeout=60)
        except (requests.ConnectionError, requests.Timeout) as exc:
            # ReadTimeout is not a ConnectionError subclass, so Timeout must be
            # listed explicitly or it escapes unwrapped and un-retried.
            if is_last:
                raise ProviderError(
                    f"Connection error when calling provider: {exc}"
                ) from exc
            time.sleep(_backoff_delay(attempt))
            continue
        # getattr tolerates response objects that lack status_code/headers.
        status = getattr(resp, "status_code", 0)
        if status == 429:
            if is_last:
                raise ProviderError(f"Provider returned HTTP {status}")
            resp_headers = getattr(resp, "headers", None)
            retry_after = _parse_retry_after(
                resp_headers.get("Retry-After") if resp_headers else None
            )
            if retry_after is None:
                time.sleep(_backoff_delay(attempt))
            else:
                time.sleep(retry_after)
            continue
        if 500 <= status < 600:
            if is_last:
                raise ProviderError(f"Provider returned HTTP {status}")
            time.sleep(_backoff_delay(attempt))
            continue
        return resp
    # Unreachable: the final attempt always returns or raises above.
    raise ProviderError("Failed to call provider after retries")
def get_embedding(text: str, model: str | None = None) -> list[float]:
    """Embed a single string through the provider's ``/embeddings`` endpoint.

    The model is ``model`` when given; otherwise ``EMBEDDING_MODEL``, then
    ``QWEN_EMBEDDING_MODEL``, then ``"qwen/qwen3-embedding-4b"``.

    If the decoded JSON has no ``data[0]["embedding"]`` and the environment
    variable ``ALLOW_LOCAL_EMBED_FALLBACK`` is ``1``/``true``/``yes``
    (case-insensitive), a deterministic ``_local_embedding`` of length
    ``LOCAL_EMBED_DIM`` (default 64) is returned instead.

    Raises:
        ProviderError: ``text`` is not a string, the body is not valid JSON,
            the response shape is unexpected (with fallback disabled), or the
            embedding is not a list.
    """
    if not isinstance(text, str):
        raise ProviderError("text must be a string")

    chosen_model = model
    if chosen_model is None:
        chosen_model = (
            os.environ.get("EMBEDDING_MODEL")
            or os.environ.get("QWEN_EMBEDDING_MODEL")
            or "qwen/qwen3-embedding-4b"
        )

    response = _post_with_retries(
        "/embeddings", json={"model": chosen_model, "input": text}
    )
    try:
        payload = response.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    # Expected shape: {"data": [{"embedding": [...]}, ...]}
    try:
        vector = payload["data"][0]["embedding"]
    except Exception as exc:
        flag = os.environ.get("ALLOW_LOCAL_EMBED_FALLBACK", "false").lower()
        if flag not in ("1", "true", "yes"):
            raise ProviderError(
                f"Unexpected embedding response shape: {payload}"
            ) from exc
        local_dim = int(os.environ.get("LOCAL_EMBED_DIM", "64"))
        return _local_embedding(text, dim=local_dim)

    if not isinstance(vector, list):
        raise ProviderError("Embedding is not a list")
    return list(map(float, vector))
def get_embeddings_batch(
    texts: list[str], model: str | None = None, batch_size: int = 50
) -> list[list[float]]:
    """Return one embedding vector per string in ``texts``, preserving order.

    The OpenAI/OpenRouter ``/embeddings`` endpoint accepts an array of inputs,
    so ``texts`` is sent in consecutive chunks of ``batch_size`` items. Items
    within each response are ordered by their ``index`` field; items without
    one sort as index 0, which keeps their response order.

    The model is ``model`` when given; otherwise ``EMBEDDING_MODEL``, then
    ``QWEN_EMBEDDING_MODEL``, then ``"qwen/qwen3-embedding-4b"``.

    When a response lacks a usable ``data`` list and
    ``ALLOW_LOCAL_EMBED_FALLBACK`` is ``1``/``true``/``yes``, that chunk is
    embedded locally with ``_local_embedding`` (length ``LOCAL_EMBED_DIM``,
    default 64).

    Raises:
        ProviderError: ``batch_size`` < 1, a non-string entry in ``texts``,
            invalid JSON, an unexpected response shape, an embedding count that
            does not match the chunk, or a non-list embedding.
    """
    if not texts:
        return []
    if batch_size < 1:
        # range() raises for a step of 0 and silently yields nothing for a
        # negative step, which would return [] for non-empty input.
        raise ProviderError(f"batch_size must be >= 1, got {batch_size}")
    if not all(isinstance(t, str) for t in texts):
        # Same input contract as get_embedding; fail before spending API calls.
        raise ProviderError("texts must be a list of strings")
    if model is None:
        model = (
            os.environ.get("EMBEDDING_MODEL")
            or os.environ.get("QWEN_EMBEDDING_MODEL")
            or "qwen/qwen3-embedding-4b"
        )
    all_embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        resp = _post_with_retries("/embeddings", json={"model": model, "input": chunk})
        try:
            data = resp.json()
        except Exception as exc:
            raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
        items = data.get("data") if isinstance(data, dict) else None
        if not isinstance(items, list):
            # Missing or malformed "data" (e.g. an error payload): optionally
            # fall back to local embeddings for this chunk.
            fallback = os.environ.get(
                "ALLOW_LOCAL_EMBED_FALLBACK", "false"
            ).lower() in ("1", "true", "yes")
            if fallback:
                dim = int(os.environ.get("LOCAL_EMBED_DIM", "64"))
                all_embeddings.extend(_local_embedding(t, dim=dim) for t in chunk)
                continue
            raise ProviderError(f"Unexpected batch embedding response shape: {data}")
        if not all(isinstance(item, dict) for item in items):
            raise ProviderError(f"Unexpected batch embedding response shape: {data}")
        if len(items) != len(chunk):
            raise ProviderError(f"Expected {len(chunk)} embeddings, got {len(items)}")
        # Sort by index to guarantee order (API spec says index field is present)
        for item in sorted(items, key=lambda x: x.get("index", 0)):
            emb = item.get("embedding")
            if not isinstance(emb, list):
                raise ProviderError(
                    f"Embedding at index {item.get('index')} is not a list"
                )
            all_embeddings.append([float(x) for x in emb])
    return all_embeddings
def _local_embedding(text: str, dim: int = 64) -> list[float]:
    """Deterministic local fallback embedding based on SHA256.

    Returns a list of `dim` floats in range [-1, 1]. Not semantically rich but useful
    for local testing when provider embeddings are unavailable.

    The same `text` and `dim` always produce the same vector; a `dim` of 0 or
    less yields an empty list.

    NOTE(review): any change to this algorithm changes every fallback vector;
    keep it stable if such vectors may have been persisted.
    """
    import hashlib
    # Initial state: SHA-256 digest (32 bytes) of the UTF-8 encoded text.
    h = hashlib.sha256(text.encode("utf8")).digest()
    values = []
    i = 0
    # Expand digest if needed
    while len(values) < dim:
        # take 8 bytes -> 64-bit int
        # The window start wraps modulo the digest length; windows starting
        # after byte 24 run off the end and come back shorter than 8 bytes.
        chunk = h[i % len(h) : (i % len(h)) + 8]
        if len(chunk) < 8:
            # Right-pad short windows with zero bytes so every value uses 8 bytes.
            chunk = chunk.ljust(8, b"\0")
        val = int.from_bytes(chunk, "big", signed=False)
        # normalize to [-1,1]
        valscale = (val / (2**64 - 1)) * 2.0 - 1.0
        values.append(valscale)
        i += 1
        # re-hash occasionally to get more entropy
        # With a 32-byte digest this fires every 17 values, chaining the
        # current digest with the last chunk.
        if i % (len(h) // 2 + 1) == 0:
            h = hashlib.sha256(h + chunk).digest()
    return values[:dim]
def chat_completion(messages: list[dict], model: str | None = None) -> str:
    """Return the assistant's reply text for a chat completion request.

    ``messages`` is a list of dicts like ``{"role": "user", "content": "..."}``.
    The model is ``model`` when given; otherwise ``QWEN_MODEL``, then
    ``CHAT_MODEL``, then ``"qwen/qwen-3.2"``.

    Raises:
        ProviderError: ``messages`` is not a list of dicts; the response is not
            valid JSON or lacks ``choices[0].message.content``; or that content
            is null. Null content would otherwise be turned into the literal
            string ``"None"`` by ``str()``. OpenAI-style APIs send null content,
            for example, on tool-call-only replies.
    """
    if not isinstance(messages, list) or not all(
        isinstance(m, dict) for m in messages
    ):
        raise ProviderError("messages must be a list of dicts")
    # Resolve chat model: prefer explicit arg, then env var QWEN_MODEL, then a sensible default
    if model is None:
        model = (
            os.environ.get("QWEN_MODEL")
            or os.environ.get("CHAT_MODEL")
            or "qwen/qwen-3.2"
        )
    resp = _post_with_retries(
        "/chat/completions", json={"model": model, "messages": messages}
    )
    try:
        data = resp.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
    # Expecting {"choices": [{"message": {"content": "..."}}]}
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ProviderError(
            f"Unexpected chat completion response shape: {data}"
        ) from exc
    if content is None:
        raise ProviderError(f"Chat completion returned no content: {data}")
    return str(content)
def _response_format(json_schema: dict[str, Any] | None) -> dict[str, Any]:
    """Build the OpenAI-compatible ``response_format`` field for JSON output.

    ``None`` selects plain JSON-object mode. A dict that already has a
    ``"schema"`` key is treated as a complete ``json_schema`` wrapper
    (``{"name": ..., "schema": ..., "strict": ...}``) and sent unchanged. Any
    other dict is treated as a bare JSON Schema and wrapped under the name
    ``"response"``, because the OpenAI-style wrapper expects ``name`` and
    ``schema`` fields.
    """
    if json_schema is None:
        return {"type": "json_object"}
    if "schema" in json_schema:
        spec = json_schema
    else:
        spec = {"name": "response", "schema": json_schema}
    return {"type": "json_schema", "json_schema": spec}


def _strip_json_fence(content: str) -> str:
    """Return ``content`` without surrounding whitespace or a Markdown code fence.

    Some models wrap JSON output in a fenced block (```json ... ```) even in
    JSON mode. The opening fence line, including any language tag, and the
    closing fence are removed so the body can be parsed.
    """
    text = content.strip()
    if not text.startswith("```"):
        return text
    newline = text.find("\n")
    if newline == -1:
        # Single-line fence; leave it for json.loads to reject.
        return text
    body = text[newline + 1 :].rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def chat_completion_json(
    messages: list[dict],
    model: str | None = None,
    json_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Run a chat completion in JSON mode and return the parsed JSON object.

    Sends the OpenAI-compatible ``response_format`` field. With
    ``json_schema=None`` it requests ``{"type": "json_object"}``. Otherwise it
    requests ``{"type": "json_schema", "json_schema": ...}``, where a bare JSON
    Schema is wrapped as ``{"name": "response", "schema": <schema>}`` and a dict
    already containing ``"schema"`` is passed through. A reply wrapped in a
    Markdown code fence is unwrapped before parsing.

    The model is ``model`` when given; otherwise ``QWEN_MODEL``, then
    ``CHAT_MODEL``, then ``"qwen/qwen-3.2"``.

    Raises:
        ProviderError: ``messages`` is not a list of dicts; the response is not
            valid JSON or lacks ``choices[0].message.content``; the content is
            not valid JSON; or it is valid JSON but not an object.
    """
    import json as _json

    if not isinstance(messages, list) or not all(
        isinstance(m, dict) for m in messages
    ):
        raise ProviderError("messages must be a list of dicts")
    if model is None:
        model = (
            os.environ.get("QWEN_MODEL")
            or os.environ.get("CHAT_MODEL")
            or "qwen/qwen-3.2"
        )
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "response_format": _response_format(json_schema),
    }
    resp = _post_with_retries("/chat/completions", json=payload)
    try:
        data = resp.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ProviderError(
            f"Unexpected chat completion response shape: {data}"
        ) from exc
    if isinstance(content, str):
        content = _strip_json_fence(content)
    try:
        parsed = _json.loads(content)
    except Exception as exc:
        # Covers malformed JSON (ValueError) and null/non-text content (TypeError).
        raise ProviderError(f"Model returned invalid JSON: {exc}") from exc
    if not isinstance(parsed, dict):
        raise ProviderError(f"Expected JSON object, got {type(parsed).__name__}")
    return parsed
def chat_completion_json_parallel(
    message_batches: list[list[dict]],
    model: str | None = None,
    json_schema: dict[str, Any] | None = None,
    max_workers: int = 3,
) -> list[dict[str, Any]]:
    """Run ``chat_completion_json`` for each conversation on a thread pool.

    Each item of ``message_batches`` is an independent conversation (a list of
    message dicts). All of them share ``model`` and ``json_schema``. Parsed
    results are returned in the same order as the input. The pool size is
    ``max_workers``, clamped to at least 1 and at most the number of
    conversations.

    If a request fails, the first failure in input order is re-raised. Before
    it propagates, requests that have not started yet are cancelled, and
    requests already running are allowed to finish.
    """
    if not message_batches:
        return []
    workers = max(1, min(max_workers, len(message_batches)))

    def _fetch_one(messages: list[dict]) -> dict[str, Any]:
        return chat_completion_json(messages, model=model, json_schema=json_schema)

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = [executor.submit(_fetch_one, batch) for batch in message_batches]
        results: list[dict[str, Any]] = []
        try:
            for future in futures:
                results.append(future.result())
        except BaseException:
            # Results will be discarded anyway; don't spend API quota on
            # requests that are still queued.
            for future in futures:
                future.cancel()
            raise
    return results