You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
1.3 KiB
49 lines
1.3 KiB
import os
|
|
import types
|
|
|
|
import pytest
|
|
|
|
import ai_provider
|
|
|
|
|
|
class DummyResponse:
    """Minimal stand-in for an HTTP response object used by these tests.

    Exposes only what the tests set up: a ``status_code`` attribute and a
    ``json()`` method returning a canned payload.
    """

    def __init__(self, status_code=200, json_data=None):
        """Store the canned status code and JSON payload.

        Args:
            status_code: Status code to report; defaults to 200.
            json_data: Payload that ``json()`` will return. ``None`` (the
                default) means an empty dict.
        """
        self.status_code = status_code
        # Compare against None explicitly: a falsy-but-valid payload such as
        # [] or 0 must be returned as given, not silently replaced by {}.
        self._json = {} if json_data is None else json_data

    def json(self):
        """Return the canned payload supplied at construction time."""
        return self._json
|
|
|
|
|
|
def test_get_embedding_success(monkeypatch):
    """get_embedding returns the vector carried by the canned response."""
    payload = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    canned = DummyResponse(json_data=payload)

    # The stub accepts exactly url/json/headers/timeout, so a call made with
    # any other argument set fails with TypeError instead of passing.
    monkeypatch.setattr(
        "requests.post",
        lambda url, json, headers, timeout: canned,
    )
    monkeypatch.setenv("OPENROUTER_API_KEY", "sk-test")

    assert ai_provider.get_embedding("hello world") == [0.1, 0.2, 0.3]
|
|
|
|
|
|
def test_chat_completion_success(monkeypatch):
    """chat_completion returns the message content of the canned response."""
    payload = {"choices": [{"message": {"content": "summary"}}]}
    canned = DummyResponse(json_data=payload)

    # The stub accepts exactly url/json/headers/timeout, so a call made with
    # any other argument set fails with TypeError instead of passing.
    monkeypatch.setattr(
        "requests.post",
        lambda url, json, headers, timeout: canned,
    )
    monkeypatch.setenv("OPENROUTER_API_KEY", "sk-test")

    conversation = [{"role": "user", "content": "hi"}]
    assert ai_provider.chat_completion(conversation) == "summary"
|
|
|
|
|
|
def test_missing_api_key_raises(monkeypatch):
    """get_embedding raises ProviderError when OPENROUTER_API_KEY is unset."""
    # raising=False makes the removal a no-op when the variable is already
    # absent, so the test does not depend on the caller's environment.
    monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)

    expected_error = ai_provider.ProviderError
    with pytest.raises(expected_error):
        ai_provider.get_embedding("x")
|
|
|