344 lines
15 KiB
Python
344 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""test_retrieve.py — hermetic tests for scripts/retrieve.py and scripts/rerank.py.
|
|
|
|
No network, no ollama, no LLM calls. Tests cover:
|
|
- import_sibling resolves hyphenated module names
|
|
- chunk_snippet truncation behavior
|
|
- rerank.cosine math correctness
|
|
- rerank.rerank() no-op behavior when ollama is unreachable
|
|
- retrieve.py exit 10 (not provisioned) when chunks/index are missing
|
|
- dedupe-by-page logic via integration smoke test on synthetic fixtures
|
|
|
|
Usage:
|
|
python3 tests/test_retrieve.py
|
|
"""
|
|
import importlib.util
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import unittest.mock
|
|
from pathlib import Path
|
|
|
|
ROOT = Path(__file__).resolve().parent.parent
|
|
RETRIEVE = ROOT / "scripts" / "retrieve.py"
|
|
RERANK = ROOT / "scripts" / "rerank.py"
|
|
BM25 = ROOT / "scripts" / "bm25-index.py"
|
|
|
|
|
|
def import_script(name, path):
|
|
spec = importlib.util.spec_from_file_location(name, path)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
return mod
|
|
|
|
|
|
retrieve = import_script("retrieve", RETRIEVE)
|
|
rerank = import_script("rerank", RERANK)
|
|
bm25 = import_script("bm25", BM25)
|
|
|
|
|
|
class Fail(SystemExit):
|
|
pass
|
|
|
|
|
|
def assert_eq(label, expected, actual):
|
|
if expected != actual:
|
|
raise Fail(f"FAIL {label}: expected {expected!r}, got {actual!r}")
|
|
print(f"OK {label}")
|
|
|
|
|
|
def assert_true(label, cond, hint=""):
|
|
if not cond:
|
|
raise Fail(f"FAIL {label}{(': ' + hint) if hint else ''}")
|
|
print(f"OK {label}")
|
|
|
|
|
|
def assert_close(label, expected, actual, eps=1e-6):
|
|
if abs(expected - actual) > eps:
|
|
raise Fail(f"FAIL {label}: expected ~{expected}, got {actual}")
|
|
print(f"OK {label}")
|
|
|
|
|
|
# ─── import_sibling ──────────────────────────────────────────────────────────
|
|
def test_import_sibling_resolves_hyphenated_names():
|
|
"""retrieve.import_sibling('bm25_index', 'bm25-index.py') must succeed."""
|
|
mod = retrieve.import_sibling("bm25_index", "bm25-index.py")
|
|
assert_true("import_sibling returns module", mod is not None)
|
|
assert_true("module has tokenize()", callable(getattr(mod, "tokenize", None)))
|
|
|
|
|
|
# ─── chunk_snippet ───────────────────────────────────────────────────────────
|
|
def test_chunk_snippet_short():
|
|
"""Short chunks should pass through unchanged."""
|
|
out = retrieve.chunk_snippet({"raw_text": "short text"}, max_chars=200)
|
|
assert_eq("chunk_snippet short pass-through", "short text", out)
|
|
|
|
|
|
def test_chunk_snippet_truncates_with_ellipsis():
|
|
"""Long chunks should be truncated with an ellipsis."""
|
|
long_text = "x" * 500
|
|
out = retrieve.chunk_snippet({"raw_text": long_text}, max_chars=100)
|
|
assert_true("snippet length under cap", len(out) <= 110, hint=f"len={len(out)}")
|
|
assert_true("snippet ends with ellipsis", out.endswith("…"))
|
|
|
|
|
|
# ─── rerank.cosine() ─────────────────────────────────────────────────────────
|
|
def test_cosine_identical():
|
|
assert_close("cosine identical vectors", 1.0, rerank.cosine([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))
|
|
|
|
|
|
def test_cosine_orthogonal():
|
|
assert_close("cosine orthogonal", 0.0, rerank.cosine([1.0, 0.0], [0.0, 1.0]))
|
|
|
|
|
|
def test_cosine_anti_parallel():
|
|
assert_close("cosine anti-parallel", -1.0, rerank.cosine([1.0, 0.0], [-1.0, 0.0]))
|
|
|
|
|
|
def test_cosine_length_mismatch():
|
|
"""Mismatched vector lengths should return 0.0 (defensive, not crash)."""
|
|
assert_close("cosine length mismatch", 0.0, rerank.cosine([1.0], [1.0, 2.0]))
|
|
|
|
|
|
def test_cosine_zero_vector():
|
|
assert_close("cosine zero vector", 0.0, rerank.cosine([0.0, 0.0], [1.0, 2.0]))
|
|
|
|
|
|
# ─── rerank.rerank() no-op fallback ──────────────────────────────────────────
|
|
def test_rerank_noop_when_ollama_unreachable():
|
|
"""When ollama is not reachable, rerank should pass candidates through with
|
|
rerank_source='noop-no-ollama'. We force this by patching ollama_alive."""
|
|
with unittest.mock.patch.object(rerank, "ollama_alive", return_value=(False, [])):
|
|
candidates = [
|
|
{"chunk_id": "c-001:0", "score": 7.5, "path": "fake/p1.json"},
|
|
{"chunk_id": "c-002:0", "score": 5.1, "path": "fake/p2.json"},
|
|
]
|
|
out = rerank.rerank("query", candidates, top_k=5)
|
|
assert_eq("rerank no-op preserves order", ["c-001:0", "c-002:0"],
|
|
[c["chunk_id"] for c in out])
|
|
assert_true("rerank no-op tags source",
|
|
all(c.get("rerank_source") == "noop-no-ollama" for c in out))
|
|
assert_true("rerank no-op copies score to rerank_score",
|
|
all(c["rerank_score"] == c["score"] for c in out))
|
|
|
|
|
|
def test_rerank_noop_when_model_missing():
|
|
"""When ollama is up but model isn't pulled, rerank should still no-op."""
|
|
with unittest.mock.patch.object(rerank, "ollama_alive", return_value=(True, ["other-model"])):
|
|
candidates = [{"chunk_id": "c-001:0", "score": 5.0, "path": "x"}]
|
|
out = rerank.rerank("query", candidates, top_k=5)
|
|
assert_eq("rerank no-op for missing model", "noop-no-model", out[0]["rerank_source"])
|
|
|
|
|
|
def test_rerank_truncates_to_top_k():
|
|
with unittest.mock.patch.object(rerank, "ollama_alive", return_value=(False, [])):
|
|
candidates = [{"chunk_id": f"c-{i:03}:0", "score": float(i), "path": "x"} for i in range(10)]
|
|
out = rerank.rerank("query", candidates, top_k=3)
|
|
assert_eq("rerank truncates to top_k", 3, len(out))
|
|
|
|
|
|
# ─── retrieve.py CLI: exit 10 when not provisioned ────────────────────────────
|
|
def test_retrieve_exits_10_without_index():
|
|
"""End-to-end CLI test: with no .vault-meta/bm25/index.json, retrieve.py must exit 10."""
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
# Build a minimal vault layout under tmpdir
|
|
sandbox = Path(tmpdir)
|
|
(sandbox / "scripts").mkdir()
|
|
(sandbox / ".vault-meta").mkdir()
|
|
# Copy retrieve.py and its dependencies into the sandbox
|
|
import shutil
|
|
for f in ["retrieve.py", "bm25-index.py", "rerank.py"]:
|
|
shutil.copy(ROOT / "scripts" / f, sandbox / "scripts" / f)
|
|
os.chmod(sandbox / "scripts" / f, 0o755)
|
|
# Run retrieve.py — should exit 10 because no bm25 index exists
|
|
result = subprocess.run(
|
|
[sys.executable, str(sandbox / "scripts" / "retrieve.py"), "test query"],
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=10,
|
|
)
|
|
assert_eq("retrieve.py exit 10 when not provisioned", 10, result.returncode)
|
|
assert_true("retrieve.py prints friendly error",
|
|
"no BM25 index" in result.stderr,
|
|
hint=result.stderr[:200])
|
|
|
|
|
|
# ─── Integration smoke test: end-to-end with synthetic data ──────────────────
|
|
def test_end_to_end_with_synthetic_chunks():
|
|
"""Build a minimal vault with 2 chunks, index it, run retrieve, verify output."""
|
|
import hashlib
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
sandbox = Path(tmpdir)
|
|
(sandbox / "scripts").mkdir()
|
|
meta = sandbox / ".vault-meta"
|
|
chunks_dir = meta / "chunks"
|
|
bm25_dir = meta / "bm25"
|
|
chunks_dir.mkdir(parents=True)
|
|
bm25_dir.mkdir(parents=True)
|
|
# Copy scripts
|
|
import shutil
|
|
for f in ["retrieve.py", "bm25-index.py", "rerank.py"]:
|
|
shutil.copy(ROOT / "scripts" / f, sandbox / "scripts" / f)
|
|
os.chmod(sandbox / "scripts" / f, 0o755)
|
|
# Write 2 synthetic chunks
|
|
def chunk(addr, idx, text):
|
|
return {
|
|
"schema_version": 1,
|
|
"page_path": f"wiki/fake/{addr}.md",
|
|
"page_address": addr,
|
|
"chunk_index": idx,
|
|
"raw_text": text,
|
|
"contextualized_text": text,
|
|
"prefix": "",
|
|
"prefix_source": "synthetic",
|
|
"char_count": len(text),
|
|
"body_hash": "sha256:" + hashlib.sha256(text.encode()).hexdigest(),
|
|
"page_body_hash": "sha256:0",
|
|
"created_at": "2026-05-17T00:00:00Z",
|
|
}
|
|
(chunks_dir / "c-000001").mkdir()
|
|
(chunks_dir / "c-000002").mkdir()
|
|
(chunks_dir / "c-000001" / "chunk-000.json").write_text(
|
|
json.dumps(chunk("c-000001", 0, "compounding wiki vault pattern by karpathy")))
|
|
(chunks_dir / "c-000002" / "chunk-000.json").write_text(
|
|
json.dumps(chunk("c-000002", 0, "obsidian cli transport detection")))
|
|
# Build index via subprocess (uses the sandbox's META_DIR? no — it uses the
|
|
# script's hard-coded paths relative to its location. Since we copied the
|
|
# script into sandbox/scripts/, VAULT_ROOT will compute to `sandbox`.)
|
|
result = subprocess.run(
|
|
[sys.executable, str(sandbox / "scripts" / "bm25-index.py"), "build"],
|
|
capture_output=True, text=True, timeout=10)
|
|
assert_eq("bm25 build rc=0", 0, result.returncode)
|
|
# Run retrieve
|
|
result = subprocess.run(
|
|
[sys.executable, str(sandbox / "scripts" / "retrieve.py"),
|
|
"karpathy wiki", "--top", "2", "--no-rerank"],
|
|
capture_output=True, text=True, timeout=10)
|
|
assert_eq("retrieve rc=0", 0, result.returncode)
|
|
out = json.loads(result.stdout)
|
|
assert_eq("retrieve.strategy is bm25-only", "bm25-only", out["strategy"])
|
|
assert_true("retrieve returns at least 1 candidate", len(out["candidates"]) >= 1)
|
|
# c-000001 should rank above c-000002 for "karpathy wiki"
|
|
first = out["candidates"][0]
|
|
assert_eq("top hit is c-000001", "c-000001", first["page_address"])
|
|
|
|
|
|
# ─── M8 closure: --explain and --no-rerank flag coverage ─────────────────────
|
|
def test_explain_flag_adds_diagnostics_block():
|
|
"""v1.7.2 / closes audit M8: --explain must include an 'explain' diagnostics block."""
|
|
import hashlib
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
sandbox = Path(tmpdir)
|
|
(sandbox / "scripts").mkdir()
|
|
meta = sandbox / ".vault-meta"
|
|
chunks_dir = meta / "chunks"
|
|
bm25_dir = meta / "bm25"
|
|
chunks_dir.mkdir(parents=True)
|
|
bm25_dir.mkdir(parents=True)
|
|
import shutil
|
|
for f in ["retrieve.py", "bm25-index.py", "rerank.py"]:
|
|
shutil.copy(ROOT / "scripts" / f, sandbox / "scripts" / f)
|
|
os.chmod(sandbox / "scripts" / f, 0o755)
|
|
# 2 synthetic chunks
|
|
(chunks_dir / "c-000010").mkdir()
|
|
(chunks_dir / "c-000010" / "chunk-000.json").write_text(json.dumps({
|
|
"schema_version": 1, "page_path": "wiki/fake/c-000010.md",
|
|
"page_address": "c-000010", "chunk_index": 0,
|
|
"raw_text": "hybrid retrieval pipeline",
|
|
"contextualized_text": "hybrid retrieval pipeline",
|
|
"prefix": "", "prefix_source": "synthetic",
|
|
"char_count": 25,
|
|
"body_hash": "sha256:" + hashlib.sha256(b"hybrid retrieval pipeline").hexdigest(),
|
|
"page_body_hash": "sha256:0",
|
|
"created_at": "2026-05-17T00:00:00Z",
|
|
}))
|
|
# Build index
|
|
subprocess.run([sys.executable, str(sandbox / "scripts" / "bm25-index.py"), "build"],
|
|
capture_output=True, timeout=10, check=True)
|
|
# Run with --explain --no-rerank
|
|
result = subprocess.run(
|
|
[sys.executable, str(sandbox / "scripts" / "retrieve.py"),
|
|
"hybrid", "--top", "1", "--no-rerank", "--explain"],
|
|
capture_output=True, text=True, timeout=10)
|
|
assert_eq("retrieve --explain --no-rerank rc=0", 0, result.returncode)
|
|
out = json.loads(result.stdout)
|
|
assert_true("--explain produces 'explain' key",
|
|
"explain" in out, hint=f"keys={list(out.keys())}")
|
|
explain = out.get("explain", {})
|
|
assert_true("--explain reports BM25 candidate count",
|
|
"bm25_candidates" in explain or "bm25" in str(explain).lower(),
|
|
hint=f"explain={explain}")
|
|
|
|
|
|
def test_no_rerank_flag_strategy_bm25_only():
|
|
"""v1.7.2 / closes audit M8: --no-rerank must produce strategy='bm25-only'."""
|
|
import hashlib
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
sandbox = Path(tmpdir)
|
|
(sandbox / "scripts").mkdir()
|
|
meta = sandbox / ".vault-meta"
|
|
chunks_dir = meta / "chunks"
|
|
bm25_dir = meta / "bm25"
|
|
chunks_dir.mkdir(parents=True)
|
|
bm25_dir.mkdir(parents=True)
|
|
import shutil
|
|
for f in ["retrieve.py", "bm25-index.py", "rerank.py"]:
|
|
shutil.copy(ROOT / "scripts" / f, sandbox / "scripts" / f)
|
|
os.chmod(sandbox / "scripts" / f, 0o755)
|
|
(chunks_dir / "c-000020").mkdir()
|
|
(chunks_dir / "c-000020" / "chunk-000.json").write_text(json.dumps({
|
|
"schema_version": 1, "page_path": "wiki/fake/c-000020.md",
|
|
"page_address": "c-000020", "chunk_index": 0,
|
|
"raw_text": "transport detection fallback chain",
|
|
"contextualized_text": "transport detection fallback chain",
|
|
"prefix": "", "prefix_source": "synthetic",
|
|
"char_count": 35,
|
|
"body_hash": "sha256:" + hashlib.sha256(b"transport detection fallback chain").hexdigest(),
|
|
"page_body_hash": "sha256:0",
|
|
"created_at": "2026-05-17T00:00:00Z",
|
|
}))
|
|
subprocess.run([sys.executable, str(sandbox / "scripts" / "bm25-index.py"), "build"],
|
|
capture_output=True, timeout=10, check=True)
|
|
result = subprocess.run(
|
|
[sys.executable, str(sandbox / "scripts" / "retrieve.py"),
|
|
"transport", "--top", "1", "--no-rerank"],
|
|
capture_output=True, text=True, timeout=10)
|
|
assert_eq("retrieve --no-rerank rc=0", 0, result.returncode)
|
|
out = json.loads(result.stdout)
|
|
assert_eq("--no-rerank sets strategy='bm25-only'", "bm25-only", out.get("strategy"))
|
|
# --no-rerank produces a consistent shape: rerank fields are populated
|
|
# but rerank_source is "skipped" (so callers don't have to special-case).
|
|
candidates = out.get("candidates", [])
|
|
assert_true("--no-rerank still returns candidates", len(candidates) >= 1)
|
|
first = candidates[0]
|
|
assert_eq("--no-rerank candidate rerank_source='skipped'", "skipped",
|
|
first.get("rerank_source"))
|
|
assert_eq("--no-rerank candidate rerank_score equals bm25_score",
|
|
first.get("bm25_score"), first.get("rerank_score"))
|
|
|
|
|
|
def main():
|
|
print("=== test_retrieve.py ===")
|
|
test_import_sibling_resolves_hyphenated_names()
|
|
test_chunk_snippet_short()
|
|
test_chunk_snippet_truncates_with_ellipsis()
|
|
test_cosine_identical()
|
|
test_cosine_orthogonal()
|
|
test_cosine_anti_parallel()
|
|
test_cosine_length_mismatch()
|
|
test_cosine_zero_vector()
|
|
test_rerank_noop_when_ollama_unreachable()
|
|
test_rerank_noop_when_model_missing()
|
|
test_rerank_truncates_to_top_k()
|
|
test_retrieve_exits_10_without_index()
|
|
test_end_to_end_with_synthetic_chunks()
|
|
test_explain_flag_adds_diagnostics_block()
|
|
test_no_rerank_flag_strategy_bm25_only()
|
|
print("\nAll retrieve tests passed.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|