You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
375 lines
13 KiB
375 lines
13 KiB
"""Thin AI provider adapter for OpenRouter-compatible backends.
|
|
|
|
Provides simple helpers for embeddings and chat completions using requests.
|
|
This module is intentionally small and dependency-light to make testing easy.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
import random
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime, timezone
|
|
from email.utils import parsedate_to_datetime
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
|
|
class ProviderError(Exception):
    """Raised when a provider call cannot be completed.

    Signals a terminal condition: the failure is either not retryable or is
    caused by configuration, so callers should not expect a retry to help.
    """
|
|
|
|
|
|
def _get_base_url() -> str:
    """Return the provider base URL taken from the environment.

    Lookup order is ``OPENROUTER_URL``, then ``OPENROUTER_BASE_URL``, then the
    public OpenRouter endpoint. A variable that is set but empty is treated as
    unset, so a blank value can never produce an empty (invalid) request URL.
    """
    return (
        os.environ.get("OPENROUTER_URL")
        or os.environ.get("OPENROUTER_BASE_URL")
        or "https://openrouter.ai/api/v1"
    )
|
|
|
|
|
|
def _get_api_key() -> str:
    """Return the first non-empty API key found in the environment.

    Variables are consulted in priority order; unset and empty values are
    skipped.

    Raises:
        ProviderError: if none of the recognised variables holds a value.
    """
    candidates = ("OPENROUTER_API_KEY", "OPENROUTER_KEY", "OPENAI_API_KEY", "API_KEY")
    # filter(None, ...) drops both missing (None) and empty-string values.
    found = next(filter(None, map(os.environ.get, candidates)), None)
    if found is None:
        raise ProviderError(
            "OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is required"
        )
    return found
|
|
|
|
|
|
def _retry_after_seconds(raw: str | None) -> float | None:
    """Convert a ``Retry-After`` header value into a non-negative delay.

    The header may carry either a number of seconds or an HTTP-date. Returns
    ``None`` when the value is missing or unparseable so the caller can fall
    back to exponential backoff.
    """
    if not raw:
        return None

    try:
        # Clamp at zero: time.sleep() raises ValueError on a negative delay.
        return float(max(0, int(raw)))
    except (TypeError, ValueError):
        pass  # not an integer; try the HTTP-date form next

    try:
        when = parsedate_to_datetime(raw)
    except (TypeError, ValueError):
        return None

    if when.tzinfo is None:
        # A "-0000" zone yields a naive datetime; read it as UTC so the
        # subtraction below does not mix naive and aware values.
        when = when.replace(tzinfo=timezone.utc)
    return max(0.0, (when - datetime.now(tz=timezone.utc)).total_seconds())


def _sleep_with_backoff(attempt: int, base: float = 0.5) -> None:
    """Sleep ``base * 2**(attempt - 1)`` seconds plus up to 10% random jitter."""
    delay = base * (2 ** (attempt - 1))
    time.sleep(delay + random.uniform(0, delay * 0.1))


def _post_with_retries(
    path: str, json: dict[str, Any], retries: int = 3
) -> requests.Response:
    """POST ``json`` to the provider, retrying transient failures.

    Connection errors, timeouts, HTTP 429 and HTTP 5xx are treated as
    transient. A 429 honours a valid ``Retry-After`` header; every other
    retry waits with exponential backoff and jitter. Any other response
    (including non-429 4xx) is returned to the caller unchanged.

    Args:
        path: Endpoint path appended to the configured base URL.
        json: Request body, sent as JSON.
        retries: Total number of attempts.

    Returns:
        The first response that is not classified as transient.

    Raises:
        ProviderError: if the API key is missing, or the last attempt still
            fails with a connection error, timeout, 429 or 5xx.
    """
    url = _get_base_url().rstrip("/") + path
    headers = {
        "Authorization": f"Bearer {_get_api_key()}",
        "Content-Type": "application/json",
    }

    for attempt in range(1, retries + 1):
        is_last = attempt == retries

        try:
            resp = requests.post(url, json=json, headers=headers, timeout=60)
        except (requests.ConnectionError, requests.Timeout) as exc:
            if is_last:
                raise ProviderError(
                    f"Connection error when calling provider: {exc}"
                ) from exc
            _sleep_with_backoff(attempt)
            continue

        status = getattr(resp, "status_code", 0)
        if status != 429 and not 500 <= status < 600:
            return resp

        if is_last:
            raise ProviderError(f"Provider returned HTTP {status}")

        # Only a 429 may dictate its own delay; 5xx always uses backoff.
        delay = None
        if status == 429:
            resp_headers = getattr(resp, "headers", None)
            delay = _retry_after_seconds(
                resp_headers.get("Retry-After") if resp_headers else None
            )

        if delay is not None:
            time.sleep(delay)
        else:
            _sleep_with_backoff(attempt)

    # Only reachable when retries < 1, i.e. the loop body never ran.
    raise ProviderError("Failed to call provider after retries")
|
|
|
|
|
|
def get_embedding(text: str, model: str | None = None) -> list[float]:
    """Return an embedding vector for `text` using the configured provider.

    Args:
        text: The string to embed.
        model: Embedding model id. When ``None`` it is resolved from
            ``EMBEDDING_MODEL``, then ``QWEN_EMBEDDING_MODEL``, then the
            built-in Qwen default.

    Returns:
        The embedding as a list of floats. If the response lacks the expected
        shape and ``ALLOW_LOCAL_EMBED_FALLBACK`` is enabled, a deterministic
        local embedding of ``LOCAL_EMBED_DIM`` (default 64) values is returned.

    Raises:
        ProviderError: for a non-string `text`, a request that fails after
            retries, an invalid or unexpectedly shaped response, non-numeric
            embedding values, or a non-integer ``LOCAL_EMBED_DIM``.
    """
    if not isinstance(text, str):
        raise ProviderError("text must be a string")

    # Resolve model: prefer explicit arg, then env vars, then the Qwen default
    if model is None:
        model = (
            os.environ.get("EMBEDDING_MODEL")
            or os.environ.get("QWEN_EMBEDDING_MODEL")
            or "qwen/qwen3-embedding-4b"
        )
    resp = _post_with_retries("/embeddings", json={"model": model, "input": text})

    try:
        data = resp.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    # Expecting {"data": [{"embedding": [...]}, ...]}
    try:
        embedding = data["data"][0]["embedding"]
    except (LookupError, TypeError) as exc:
        # The provider answered with something else (typically an error
        # object); use the local fallback only when explicitly enabled.
        fallback = os.environ.get("ALLOW_LOCAL_EMBED_FALLBACK", "false").lower() in (
            "1",
            "true",
            "yes",
        )
        if fallback:
            raw_dim = os.environ.get("LOCAL_EMBED_DIM", "64")
            try:
                dim = int(raw_dim)
            except ValueError as dim_exc:
                # Keep the documented contract: config problems surface as
                # ProviderError rather than a bare ValueError.
                raise ProviderError(
                    f"LOCAL_EMBED_DIM must be an integer, got {raw_dim!r}"
                ) from dim_exc
            return _local_embedding(text, dim=dim)
        raise ProviderError(f"Unexpected embedding response shape: {data}") from exc

    if not isinstance(embedding, list):
        raise ProviderError("Embedding is not a list")

    try:
        return [float(x) for x in embedding]
    except (TypeError, ValueError) as exc:
        # e.g. a null or nested value inside the vector
        raise ProviderError("Embedding contains non-numeric values") from exc
|
|
|
|
|
|
def get_embeddings_batch(
    texts: list[str], model: str | None = None, batch_size: int = 50
) -> list[list[float]]:
    """Return embedding vectors for multiple texts using batched API calls.

    Texts are sent to the ``/embeddings`` endpoint in chunks of `batch_size`
    as an array input. One embedding is returned per input, in input order.

    Args:
        texts: Strings to embed. An empty list returns ``[]`` without any
            request being made.
        model: Embedding model id. When ``None`` it is resolved from
            ``EMBEDDING_MODEL``, then ``QWEN_EMBEDDING_MODEL``, then the
            built-in Qwen default.
        batch_size: Maximum number of texts per request; must be >= 1.

    Returns:
        A list with one embedding per input text. When a chunk's response has
        no ``data`` key and ``ALLOW_LOCAL_EMBED_FALLBACK`` is enabled, that
        chunk is filled with local embeddings of ``LOCAL_EMBED_DIM`` values.

    Raises:
        ProviderError: for an invalid `batch_size`, a request that fails after
            retries, an invalid or unexpectedly shaped response, a count
            mismatch, non-numeric embedding values, or a non-integer
            ``LOCAL_EMBED_DIM``.
    """
    if not texts:
        return []

    # A zero step makes range() raise a bare ValueError, and a negative one
    # yields no chunks at all, silently returning [] for non-empty input.
    if not isinstance(batch_size, int) or batch_size < 1:
        raise ProviderError("batch_size must be a positive integer")

    if model is None:
        model = (
            os.environ.get("EMBEDDING_MODEL")
            or os.environ.get("QWEN_EMBEDDING_MODEL")
            or "qwen/qwen3-embedding-4b"
        )

    all_embeddings: list[list[float]] = []

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        resp = _post_with_retries("/embeddings", json={"model": model, "input": chunk})

        try:
            data = resp.json()
        except Exception as exc:
            raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

        try:
            items = data["data"]
        except (LookupError, TypeError) as exc:
            # No "data" key: use the local fallback only when explicitly enabled
            fallback = os.environ.get(
                "ALLOW_LOCAL_EMBED_FALLBACK", "false"
            ).lower() in ("1", "true", "yes")
            if fallback:
                raw_dim = os.environ.get("LOCAL_EMBED_DIM", "64")
                try:
                    dim = int(raw_dim)
                except ValueError as dim_exc:
                    raise ProviderError(
                        f"LOCAL_EMBED_DIM must be an integer, got {raw_dim!r}"
                    ) from dim_exc
                all_embeddings.extend(_local_embedding(t, dim=dim) for t in chunk)
                continue
            raise ProviderError(
                f"Unexpected batch embedding response shape: {data}"
            ) from exc

        # Reject e.g. {"data": null} or a list of non-objects up front, so
        # the sort and .get() calls below cannot fail with a raw TypeError.
        if not isinstance(items, list) or not all(
            isinstance(item, dict) for item in items
        ):
            raise ProviderError(f"Unexpected batch embedding response shape: {data}")

        # Sort by the "index" field to restore input order; items without
        # one sort as 0 and keep their relative order (sorted() is stable).
        try:
            items_sorted = sorted(items, key=lambda x: x.get("index", 0))
        except TypeError as exc:
            # "index" values that cannot be compared with each other
            raise ProviderError(
                f"Unexpected batch embedding response shape: {data}"
            ) from exc

        if len(items_sorted) != len(chunk):
            raise ProviderError(
                f"Expected {len(chunk)} embeddings, got {len(items_sorted)}"
            )

        for item in items_sorted:
            emb = item.get("embedding")
            if not isinstance(emb, list):
                raise ProviderError(
                    f"Embedding at index {item.get('index')} is not a list"
                )
            try:
                all_embeddings.append([float(x) for x in emb])
            except (TypeError, ValueError) as exc:
                raise ProviderError(
                    f"Embedding at index {item.get('index')} contains non-numeric values"
                ) from exc

    return all_embeddings
|
|
|
|
|
|
def _local_embedding(text: str, dim: int = 64) -> list[float]:
    """Build a deterministic pseudo-embedding of `text` from SHA-256 digests.

    Produces `dim` floats in [-1, 1]. The values carry no semantic meaning;
    this exists for local testing when provider embeddings are unavailable.
    """
    import hashlib

    digest = hashlib.sha256(text.encode("utf8")).digest()
    size = len(digest)  # a SHA-256 digest always has the same length
    rehash_every = size // 2 + 1
    full_scale = 2**64 - 1

    vector = []
    for step in range(dim):
        # Slide an 8-byte window over the digest one byte at a time; near the
        # end of the digest the window is short, so zero-pad it to 8 bytes.
        offset = step % size
        window = digest[offset : offset + 8].ljust(8, b"\0")
        raw = int.from_bytes(window, "big", signed=False)

        # Map the unsigned 64-bit value onto [-1, 1]
        vector.append((raw / full_scale) * 2.0 - 1.0)

        # Periodically derive a new digest so long vectors do not just cycle
        if (step + 1) % rehash_every == 0:
            digest = hashlib.sha256(digest + window).digest()

    return vector
|
|
|
|
|
|
def chat_completion(messages: list[dict], model: str | None = None) -> str:
    """Run a chat completion and return the first choice's message content.

    Args:
        messages: Conversation as a list of dicts such as
            ``{"role": "user", "content": "..."}``.
        model: Chat model id. When ``None`` it is resolved from ``QWEN_MODEL``,
            then ``CHAT_MODEL``, then the built-in default.

    Raises:
        ProviderError: if `messages` is not a list, the request fails after
            retries, or the response is not valid JSON of the expected shape.
    """
    if not isinstance(messages, list):
        raise ProviderError("messages must be a list of dicts")

    chosen_model = model
    if chosen_model is None:
        # First non-empty environment variable wins; otherwise use the default
        for env_name in ("QWEN_MODEL", "CHAT_MODEL"):
            chosen_model = os.environ.get(env_name)
            if chosen_model:
                break
        else:
            chosen_model = "qwen/qwen-3.2"

    request_body = {"model": chosen_model, "messages": messages}
    response = _post_with_retries("/chat/completions", json=request_body)

    try:
        payload = response.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    # Expected shape: {"choices": [{"message": {"content": "..."}}]}
    try:
        first_choice = payload["choices"][0]
        content = first_choice["message"]["content"]
    except Exception as exc:
        raise ProviderError(
            f"Unexpected chat completion response shape: {payload}"
        ) from exc

    # NOTE(review): str() turns a null content into the literal "None".
    return str(content)
|
|
|
|
|
|
def chat_completion_json(
    messages: list[dict],
    model: str | None = None,
    json_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Run a chat completion in JSON mode and return the parsed object.

    The request always carries a ``response_format`` field. When `json_schema`
    is given it is sent as-is under ``{"type": "json_schema", "json_schema":
    ...}``; otherwise plain ``{"type": "json_object"}`` mode is requested.

    Args:
        messages: Conversation as a list of message dicts.
        model: Chat model id. When ``None`` it is resolved from ``QWEN_MODEL``,
            then ``CHAT_MODEL``, then the built-in default.
        json_schema: Optional value for the ``json_schema`` key of
            ``response_format``.

    Returns:
        The first choice's message content, decoded from JSON.

    Raises:
        ProviderError: if `messages` is not a list, the request fails after
            retries, the response has an unexpected shape, or the content is
            not valid JSON or does not decode to an object.
    """
    import json as _json

    if not isinstance(messages, list):
        raise ProviderError("messages must be a list of dicts")

    if model is None:
        model = (
            os.environ.get("QWEN_MODEL")
            or os.environ.get("CHAT_MODEL")
            or "qwen/qwen-3.2"
        )

    # An explicit schema takes precedence over simple JSON-object mode
    response_format: dict[str, Any] = (
        {"type": "json_object"}
        if json_schema is None
        else {"type": "json_schema", "json_schema": json_schema}
    )
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "response_format": response_format,
    }

    response = _post_with_retries("/chat/completions", json=payload)

    try:
        body = response.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    try:
        content = body["choices"][0]["message"]["content"]
    except Exception as exc:
        raise ProviderError(
            f"Unexpected chat completion response shape: {body}"
        ) from exc

    # The content itself is a JSON document and needs a second decode
    try:
        parsed = _json.loads(content)
    except Exception as exc:
        raise ProviderError(f"Model returned invalid JSON: {exc}") from exc

    if isinstance(parsed, dict):
        return parsed
    raise ProviderError(f"Expected JSON object, got {type(parsed).__name__}")
|
|
|
|
|
|
def chat_completion_json_parallel(
    message_batches: list[list[dict]],
    model: str | None = None,
    json_schema: dict[str, Any] | None = None,
    max_workers: int = 3,
) -> list[dict[str, Any]]:
    """Run several JSON-mode chat completions concurrently.

    Each item in `message_batches` is a separate conversation (list of
    messages) and is passed to ``chat_completion_json`` on a worker thread.

    Args:
        message_batches: Conversations to send. An empty list returns ``[]``
            without creating a thread pool.
        model: Forwarded to ``chat_completion_json``.
        json_schema: Forwarded to ``chat_completion_json``.
        max_workers: Size of the thread pool.

    Returns:
        Parsed JSON dicts in the same order as `message_batches`.

    Raises:
        Whatever the first failing conversation (in input order) raises.
        Conversations that have not started by then are cancelled; ones
        already running are allowed to finish before the error propagates.
    """
    if not message_batches:
        return []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                chat_completion_json, batch, model=model, json_schema=json_schema
            )
            for batch in message_batches
        ]
        try:
            # Collect in submission order so results line up with the inputs
            return [future.result() for future in futures]
        except BaseException:
            # Without this, every queued request would still be sent (and
            # retried) before the pool shut down and the error surfaced.
            for future in futures:
                future.cancel()
            raise
|
|
|