"""Thin AI provider adapter for OpenRouter-compatible backends. Provides simple helpers for embeddings and chat completions using requests. This module is intentionally small and dependency-light to make testing easy. """ from __future__ import annotations import os import time import random from datetime import datetime, timezone from email.utils import parsedate_to_datetime from typing import Any import requests class ProviderError(Exception): """Terminal provider error (non-retryable or configuration issues).""" def _get_base_url() -> str: # Support multiple env var names and fall back to OpenRouter default return os.environ.get( "OPENROUTER_URL", os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"), ) def _get_api_key() -> str: # Accept several common env var names for convenience for name in ("OPENROUTER_API_KEY", "OPENROUTER_KEY", "OPENAI_API_KEY", "API_KEY"): key = os.environ.get(name) if key: return key raise ProviderError( "OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is required" ) def _post_with_retries( path: str, json: dict[str, Any], retries: int = 3 ) -> requests.Response: """POST to the provider with a small retry/backoff for transient errors. Retries on network errors (requests.ConnectionError) and 5xx responses. """ url = _get_base_url().rstrip("/") + path headers = { "Authorization": f"Bearer {_get_api_key()}", "Content-Type": "application/json", } backoff = 0.5 for attempt in range(1, retries + 1): try: resp = requests.post(url, json=json, headers=headers, timeout=10) except requests.ConnectionError as exc: if attempt == retries: raise ProviderError( f"Connection error when calling provider: {exc}" ) from exc sleep = backoff * (2 ** (attempt - 1)) sleep = sleep + random.uniform(0, sleep * 0.1) time.sleep(sleep) continue # Treat 429 (Too Many Requests) as transient and respect Retry-After when present if getattr(resp, "status_code", 0) == 429: if attempt == retries: raise ProviderError(f"Provider returned HTTP {resp.status_code}") retry_after = None # headers are case-insensitive mapping on requests' Response raw = ( resp.headers.get("Retry-After") if getattr(resp, "headers", None) else None ) if raw: # Try integer seconds first, then HTTP-date try: retry_after = int(raw) except Exception: try: dt = parsedate_to_datetime(raw) now = datetime.now(tz=dt.tzinfo or timezone.utc) secs = (dt - now).total_seconds() retry_after = max(0, int(secs)) except Exception: retry_after = None if retry_after is not None: time.sleep(retry_after) continue # fallback to exponential backoff when Retry-After missing/invalid sleep = backoff * (2 ** (attempt - 1)) sleep = sleep + random.uniform(0, sleep * 0.1) time.sleep(sleep) continue # Treat 5xx as transient status = getattr(resp, "status_code", 0) if 500 <= status < 600: if attempt == retries: raise ProviderError(f"Provider returned HTTP {resp.status_code}") sleep = backoff * (2 ** (attempt - 1)) sleep = sleep + random.uniform(0, sleep * 0.1) time.sleep(sleep) continue # Treat 429 (rate limiting) as transient and respect Retry-After header when present if status == 429: if attempt == retries: raise ProviderError(f"Provider returned HTTP {resp.status_code}") retry_after = None try: # header may be present as int seconds or as string retry_after = resp.headers.get("Retry-After") except Exception: retry_after = None if retry_after is not None: try: sleep = float(retry_after) except Exception: # fallback to exponential backoff if header unparsable sleep = backoff * (2 ** (attempt - 1)) else: sleep = backoff * (2 ** (attempt - 1)) sleep = sleep + random.uniform(0, sleep * 0.1) time.sleep(sleep) continue return resp # Should not reach here raise ProviderError("Failed to call provider after retries") def get_embedding(text: str, model: str | None = None) -> list[float]: """Return an embedding vector for `text` using the configured provider. Raises ProviderError for configuration or provider-side failures. """ if not isinstance(text, str): raise ProviderError("text must be a string") # Resolve model: prefer explicit arg, then env vars, then sensible Qwen default if model is None: model = ( os.environ.get("EMBEDDING_MODEL") or os.environ.get("QWEN_EMBEDDING_MODEL") or "qwen/qwen3-embedding-4b" ) resp = _post_with_retries("/embeddings", json={"model": model, "input": text}) try: data = resp.json() except Exception as exc: raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc # Expecting {"data": [{"embedding": [...]}, ...]} try: embedding = data["data"][0]["embedding"] except Exception as exc: # If provider returns an error JSON, allow a local fallback when explicitly enabled fallback = os.environ.get("ALLOW_LOCAL_EMBED_FALLBACK", "false").lower() in ( "1", "true", "yes", ) if fallback: # choose fallback dim via env or default dim = int(os.environ.get("LOCAL_EMBED_DIM", "64")) return _local_embedding(text, dim=dim) raise ProviderError(f"Unexpected embedding response shape: {data}") from exc if not isinstance(embedding, list): raise ProviderError("Embedding is not a list") return [float(x) for x in embedding] def _local_embedding(text: str, dim: int = 64) -> list[float]: """Deterministic local fallback embedding based on SHA256. Returns a list of `dim` floats in range [-1, 1]. Not semantically rich but useful for local testing when provider embeddings are unavailable. """ import hashlib h = hashlib.sha256(text.encode("utf8")).digest() values = [] i = 0 # Expand digest if needed while len(values) < dim: # take 8 bytes -> 64-bit int chunk = h[i % len(h) : (i % len(h)) + 8] if len(chunk) < 8: chunk = chunk.ljust(8, b"\0") val = int.from_bytes(chunk, "big", signed=False) # normalize to [-1,1] valscale = (val / (2**64 - 1)) * 2.0 - 1.0 values.append(valscale) i += 1 # re-hash occasionally to get more entropy if i % (len(h) // 2 + 1) == 0: h = hashlib.sha256(h + chunk).digest() return values[:dim] def chat_completion(messages: list[dict], model: str | None = None) -> str: """Return the assistant's content string for a chat completion request. messages should be a list of dicts like {"role": "user", "content": "..."}. """ if not isinstance(messages, list): raise ProviderError("messages must be a list of dicts") # Resolve chat model: prefer explicit arg, then env var QWEN_MODEL, then a sensible default if model is None: model = ( os.environ.get("QWEN_MODEL") or os.environ.get("CHAT_MODEL") or "qwen/qwen-3.2" ) resp = _post_with_retries( "/chat/completions", json={"model": model, "messages": messages} ) try: data = resp.json() except Exception as exc: raise ProviderError(f"Invalid JSON response from provider: {exc}") from exc # Expecting {"choices": [{"message": {"content": "..."}}]} try: content = data["choices"][0]["message"]["content"] except Exception as exc: raise ProviderError( f"Unexpected chat completion response shape: {data}" ) from exc return str(content)