You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.6 KiB
80 lines
2.6 KiB
import json
|
|
import pytest
|
|
|
|
# duckdb is an optional dependency in some environments; skip test if not available
|
|
duckdb = pytest.importorskip("duckdb")
|
|
|
|
from database import MotionDatabase
|
|
|
|
|
|
def test_ensure_text_embeddings_monkeypatch(tmp_path, monkeypatch):
    """Check that ensure_text_embeddings only embeds motions lacking a vector.

    The test seeds a temporary DuckDB database with three motions and one
    pre-existing embedding (for the first motion), replaces
    ``ai_provider.get_embeddings_batch`` with a deterministic stub, and then
    asserts that the pipeline reports two stored, one skipped-as-existing,
    zero skipped-for-missing-text and zero errors, leaving three
    16-element vectors stored for the model.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
        monkeypatch: pytest fixture used to stub the embedding provider.
    """
    model_name = "test-model"
    dim = 16

    # Prepare a throwaway database file inside the per-test directory.
    # NOTE(review): the inserts below rely on MotionDatabase having created
    # the `motions` table -- presumably done in its constructor; confirm.
    db_path = str(tmp_path / "motions.db")
    db = MotionDatabase(db_path)

    # The `with` block guarantees the connection is closed even if seeding
    # fails, and that it is closed before the pipeline reopens the file.
    with duckdb.connect(db.db_path) as conn:
        # Create the embeddings table (a migration would normally do this).
        # DuckDB has no AUTOINCREMENT keyword, so the id column draws its
        # default from an explicit sequence.
        conn.execute("CREATE SEQUENCE IF NOT EXISTS embeddings_id_seq START 1")
        conn.execute(
            "CREATE TABLE IF NOT EXISTS embeddings (id INTEGER PRIMARY KEY DEFAULT nextval('embeddings_id_seq'), motion_id INTEGER, model TEXT, vector JSON, created_at TIMESTAMP)"
        )

        # Insert three motions: (t1, d1, u1, ex1) ... (t3, d3, u3, ex3).
        for n in (1, 2, 3):
            conn.execute(
                "INSERT INTO motions (title, description, url, layman_explanation) VALUES (?, ?, ?, ?)",
                (f"t{n}", f"d{n}", f"u{n}", f"ex{n}"),
            )

        # Fetch the generated ids in insertion order.
        ids = [
            row[0]
            for row in conn.execute("SELECT id FROM motions ORDER BY id").fetchall()
        ]

        # Give the first motion an existing embedding so that the pipeline
        # has exactly one row to skip.
        conn.execute(
            "INSERT INTO embeddings (motion_id, model, vector) VALUES (?, ?, ?)",
            (ids[0], model_name, json.dumps([0.1] * dim)),
        )

    # Stub ai_provider.get_embeddings_batch (used by the batched pipeline)
    # so that no real embedding backend is called.
    def fake_get_embeddings_batch(texts, model=None, batch_size=50):
        return [[0.1] * dim for _ in texts]

    monkeypatch.setattr("ai_provider.get_embeddings_batch", fake_get_embeddings_batch)

    # Import only after the stub is installed.
    # NOTE(review): this ordering presumably matters if the pipeline binds
    # get_embeddings_batch at import time -- keep it unless confirmed otherwise.
    from pipeline.text_pipeline import ensure_text_embeddings

    stored, skipped_existing, skipped_no_text, errors, failed_ids = (
        ensure_text_embeddings(db_path=db_path, model=model_name)
    )

    assert stored == 2
    assert skipped_existing == 1
    assert skipped_no_text == 0
    assert errors == 0

    # Verify that every motion now has a vector of the expected length.
    with duckdb.connect(db.db_path) as conn:
        rows = conn.execute(
            "SELECT vector FROM embeddings WHERE model = ? ORDER BY motion_id",
            (model_name,),
        ).fetchall()

    assert len(rows) == 3
    for row in rows:
        assert len(json.loads(row[0])) == dim
|
|
|