You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
249 lines
8.5 KiB
249 lines
8.5 KiB
"""Thin AI provider adapter for OpenRouter-compatible backends.
|
|
|
|
Provides simple helpers for embeddings and chat completions using requests.
|
|
This module is intentionally small and dependency-light to make testing easy.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
import random
|
|
from datetime import datetime, timezone
|
|
from email.utils import parsedate_to_datetime
|
|
from typing import Any
|
|
|
|
import requests
|
|
|
|
|
|
class ProviderError(Exception):
    """Raised for terminal provider failures.

    Covers both non-retryable provider responses and local configuration
    problems (e.g. a missing API key).
    """
|
|
|
|
|
|
def _get_base_url() -> str:
|
|
# Support multiple env var names and fall back to OpenRouter default
|
|
return os.environ.get(
|
|
"OPENROUTER_URL",
|
|
os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
|
|
)
|
|
|
|
|
|
def _get_api_key() -> str:
|
|
# Accept several common env var names for convenience
|
|
for name in ("OPENROUTER_API_KEY", "OPENROUTER_KEY", "OPENAI_API_KEY", "API_KEY"):
|
|
key = os.environ.get(name)
|
|
if key:
|
|
return key
|
|
raise ProviderError(
|
|
"OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is required"
|
|
)
|
|
|
|
|
|
def _post_with_retries(
    path: str, json: dict[str, Any], retries: int = 3
) -> requests.Response:
    """POST to the provider with a small retry/backoff for transient errors.

    Retries on network errors (requests.ConnectionError), on HTTP 429
    (honoring the Retry-After header when parsable) and on 5xx responses.

    Args:
        path: URL path appended to the configured base URL (e.g. "/embeddings").
        json: JSON-serializable request body.
        retries: Maximum number of attempts before giving up.

    Returns:
        The provider response (any status other than 429/5xx is returned
        as-is; callers inspect the body).

    Raises:
        ProviderError: after exhausting retries on a transient failure, or
            on a terminal connection error.
    """
    # NOTE: the original implementation contained a second, unreachable 429
    # handler after the 5xx check — the first 429 branch always raised or
    # continued, so the duplicate was dead code and has been removed.
    url = _get_base_url().rstrip("/") + path
    headers = {
        "Authorization": f"Bearer {_get_api_key()}",
        "Content-Type": "application/json",
    }

    backoff = 0.5
    for attempt in range(1, retries + 1):
        try:
            resp = requests.post(url, json=json, headers=headers, timeout=10)
        except requests.ConnectionError as exc:
            if attempt == retries:
                raise ProviderError(
                    f"Connection error when calling provider: {exc}"
                ) from exc
            time.sleep(_jittered_backoff(backoff, attempt))
            continue

        status = getattr(resp, "status_code", 0)

        # 429 (Too Many Requests) is transient; respect Retry-After when present.
        if status == 429:
            if attempt == retries:
                raise ProviderError(f"Provider returned HTTP {resp.status_code}")
            retry_after = _parse_retry_after(resp)
            if retry_after is not None:
                time.sleep(retry_after)
            else:
                # Fall back to exponential backoff when Retry-After is
                # missing or unparsable.
                time.sleep(_jittered_backoff(backoff, attempt))
            continue

        # 5xx server errors are treated as transient.
        if 500 <= status < 600:
            if attempt == retries:
                raise ProviderError(f"Provider returned HTTP {resp.status_code}")
            time.sleep(_jittered_backoff(backoff, attempt))
            continue

        return resp

    # Should not reach here
    raise ProviderError("Failed to call provider after retries")


def _jittered_backoff(backoff: float, attempt: int) -> float:
    """Exponential backoff delay for 1-based `attempt`, plus up to 10% jitter."""
    sleep = backoff * (2 ** (attempt - 1))
    return sleep + random.uniform(0, sleep * 0.1)


def _parse_retry_after(resp: requests.Response) -> int | None:
    """Parse a Retry-After header as integer seconds or an HTTP-date.

    Returns None when the header is absent or cannot be interpreted.
    """
    # headers is a case-insensitive mapping on requests' Response; guard for
    # test doubles that lack it entirely.
    raw = resp.headers.get("Retry-After") if getattr(resp, "headers", None) else None
    if not raw:
        return None
    # Try integer seconds first, then an HTTP-date per RFC 9110.
    try:
        return int(raw)
    except (TypeError, ValueError):
        pass
    try:
        dt = parsedate_to_datetime(raw)
        now = datetime.now(tz=dt.tzinfo or timezone.utc)
        return max(0, int((dt - now).total_seconds()))
    except Exception:
        return None
|
|
|
|
|
|
def get_embedding(text: str, model: str | None = None) -> list[float]:
    """Return an embedding vector for `text` using the configured provider.

    Raises ProviderError for configuration or provider-side failures.
    """
    if not isinstance(text, str):
        raise ProviderError("text must be a string")

    # Model resolution order: explicit argument, env vars, then Qwen default.
    if model is None:
        model = next(
            (
                os.environ.get(var)
                for var in ("EMBEDDING_MODEL", "QWEN_EMBEDDING_MODEL")
                if os.environ.get(var)
            ),
            "qwen/qwen3-embedding-4b",
        )

    response = _post_with_retries("/embeddings", json={"model": model, "input": text})

    try:
        payload = response.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    # Expected shape: {"data": [{"embedding": [...]}, ...]}
    try:
        vector = payload["data"][0]["embedding"]
    except Exception as exc:
        # On a malformed/error payload, allow an explicitly opted-in local fallback.
        allow = os.environ.get("ALLOW_LOCAL_EMBED_FALLBACK", "false").lower()
        if allow in ("1", "true", "yes"):
            # Fallback dimensionality comes from the env, defaulting to 64.
            return _local_embedding(text, dim=int(os.environ.get("LOCAL_EMBED_DIM", "64")))
        raise ProviderError(f"Unexpected embedding response shape: {payload}") from exc

    if not isinstance(vector, list):
        raise ProviderError("Embedding is not a list")

    return [float(component) for component in vector]
|
|
|
|
|
|
def _local_embedding(text: str, dim: int = 64) -> list[float]:
|
|
"""Deterministic local fallback embedding based on SHA256.
|
|
|
|
Returns a list of `dim` floats in range [-1, 1]. Not semantically rich but useful
|
|
for local testing when provider embeddings are unavailable.
|
|
"""
|
|
import hashlib
|
|
|
|
h = hashlib.sha256(text.encode("utf8")).digest()
|
|
values = []
|
|
i = 0
|
|
# Expand digest if needed
|
|
while len(values) < dim:
|
|
# take 8 bytes -> 64-bit int
|
|
chunk = h[i % len(h) : (i % len(h)) + 8]
|
|
if len(chunk) < 8:
|
|
chunk = chunk.ljust(8, b"\0")
|
|
val = int.from_bytes(chunk, "big", signed=False)
|
|
# normalize to [-1,1]
|
|
valscale = (val / (2**64 - 1)) * 2.0 - 1.0
|
|
values.append(valscale)
|
|
i += 1
|
|
# re-hash occasionally to get more entropy
|
|
if i % (len(h) // 2 + 1) == 0:
|
|
h = hashlib.sha256(h + chunk).digest()
|
|
|
|
return values[:dim]
|
|
|
|
|
|
def chat_completion(messages: list[dict], model: str | None = None) -> str:
    """Return the assistant's content string for a chat completion request.

    messages should be a list of dicts like {"role": "user", "content": "..."}.
    """
    if not isinstance(messages, list):
        raise ProviderError("messages must be a list of dicts")

    # Resolve the chat model: explicit argument wins, then env vars, then default.
    if model is None:
        for candidate in (os.environ.get("QWEN_MODEL"), os.environ.get("CHAT_MODEL")):
            if candidate:
                model = candidate
                break
        else:
            model = "qwen/qwen-3.2"

    response = _post_with_retries(
        "/chat/completions", json={"model": model, "messages": messages}
    )

    try:
        payload = response.json()
    except Exception as exc:
        raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc

    # Expected shape: {"choices": [{"message": {"content": "..."}}]}
    try:
        reply = payload["choices"][0]["message"]["content"]
    except Exception as exc:
        raise ProviderError(
            f"Unexpected chat completion response shape: {payload}"
        ) from exc

    return str(reply)
|
|
|