Files
MultiPhysicsVault/tests/test_retrieve.py
T
김경종 72dad72703
Tests / Hermetic test suite (push) Has been cancelled
Tests / Skill frontmatter validation (push) Has been cancelled
add claude-obsidian
2026-05-28 10:57:16 +09:00

344 lines
15 KiB
Python

#!/usr/bin/env python3
"""test_retrieve.py — hermetic tests for scripts/retrieve.py and scripts/rerank.py.
No network, no ollama, no LLM calls. Tests cover:
- import_sibling resolves hyphenated module names
- chunk_snippet truncation behavior
- rerank.cosine math correctness
- rerank.rerank() no-op behavior when ollama is unreachable
- retrieve.py exit 10 (not provisioned) when chunks/index are missing
- dedupe-by-page logic via integration smoke test on synthetic fixtures
- --explain and --no-rerank CLI flag output (diagnostics block, bm25-only strategy)
Usage:
python3 tests/test_retrieve.py
"""
import importlib.util
import json
import os
import subprocess
import sys
import tempfile
import unittest.mock
from pathlib import Path
# Parent of the directory that contains this test file.
ROOT = Path(__file__).resolve().parent.parent
# Locations of the scripts exercised by this file, all under ROOT/scripts.
RETRIEVE = ROOT / "scripts" / "retrieve.py"
RERANK = ROOT / "scripts" / "rerank.py"
BM25 = ROOT / "scripts" / "bm25-index.py"
def import_script(name, path):
    """Load the Python source file at *path* as a module called *name*.

    File names such as ``bm25-index.py`` are not valid identifiers, so they
    cannot be loaded with a regular ``import`` statement; this goes through
    importlib's file-location machinery instead.

    Args:
        name: Module name to assign to the loaded module.
        path: Filesystem path (str or path-like) of the source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: If importlib cannot build a loadable spec for *path*
            (for example, the file suffix is not a recognised source suffix).
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None when no loader matches the file;
    # fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {name!r} from {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
# Load the scripts as module objects at import time so the tests below can
# call into them directly.
retrieve = import_script("retrieve", RETRIEVE)
rerank = import_script("rerank", RERANK)
# NOTE(review): `bm25` is not referenced by any test visible in this file —
# confirm whether it is loaded only as an import smoke check.
bm25 = import_script("bm25", BM25)
class Fail(SystemExit):
    """Raised by the assert_* helpers when a check does not hold.

    Subclasses SystemExit, so an uncaught failure terminates the script and
    the message passed to the constructor becomes the exit payload.
    """
def assert_eq(label, expected, actual):
    """Print ``OK <label>`` when *expected* equals *actual*, else raise Fail.

    The failure message carries the repr of both values.
    """
    if expected == actual:
        print(f"OK {label}")
        return
    raise Fail(f"FAIL {label}: expected {expected!r}, got {actual!r}")
def assert_true(label, cond, hint=""):
    """Print ``OK <label>`` when *cond* is truthy, else raise Fail.

    A non-empty *hint* is appended to the failure message after ``": "``.
    """
    if cond:
        print(f"OK {label}")
        return
    detail = (": " + hint) if hint else ""
    raise Fail(f"FAIL {label}{detail}")
def assert_close(label, expected, actual, eps=1e-6):
    """Print ``OK <label>`` when *actual* is within *eps* of *expected*.

    Args:
        label: Name of the check, echoed in the OK / FAIL line.
        expected: Reference value.
        actual: Value under test.
        eps: Maximum allowed absolute difference (inclusive).

    Raises:
        Fail: If the absolute difference exceeds *eps*, or if it is NaN.
    """
    delta = abs(expected - actual)
    # Written as `not (delta <= eps)` rather than `delta > eps`: every ordered
    # comparison with NaN is False, so the latter would let a NaN result pass.
    # The equality short-circuit keeps identical values (including equal
    # infinities, whose difference is NaN) passing.
    if expected != actual and not (delta <= eps):
        raise Fail(f"FAIL {label}: expected ~{expected}, got {actual}")
    print(f"OK {label}")
# ─── import_sibling ──────────────────────────────────────────────────────────
def test_import_sibling_resolves_hyphenated_names():
    """retrieve.import_sibling must load bm25-index.py under the name bm25_index."""
    sibling = retrieve.import_sibling("bm25_index", "bm25-index.py")
    assert_true("import_sibling returns module", sibling is not None)
    tokenizer = getattr(sibling, "tokenize", None)
    assert_true("module has tokenize()", callable(tokenizer))
# ─── chunk_snippet ───────────────────────────────────────────────────────────
def test_chunk_snippet_short():
    """Text shorter than max_chars must come back unchanged."""
    chunk = {"raw_text": "short text"}
    snippet = retrieve.chunk_snippet(chunk, max_chars=200)
    assert_eq("chunk_snippet short pass-through", "short text", snippet)
def test_chunk_snippet_truncates_with_ellipsis():
    """A 500-character chunk cut at max_chars=100 must be shortened and end in an ellipsis."""
    long_text = "x" * 500
    out = retrieve.chunk_snippet({"raw_text": long_text}, max_chars=100)
    # The cap allows a few characters of slack beyond max_chars for the marker.
    assert_true("snippet length under cap", len(out) <= 110, hint=f"len={len(out)}")
    # The suffix must be non-empty: str.endswith("") is True for every string
    # and would verify nothing. Which marker chunk_snippet emits is not visible
    # from this file, so both the single-character ellipsis and three ASCII
    # dots are accepted — TODO confirm against scripts/retrieve.py.
    assert_true("snippet ends with ellipsis", out.endswith(("…", "...")),
                hint=f"tail={out[-5:]!r}")
# ─── rerank.cosine() ─────────────────────────────────────────────────────────
def test_cosine_identical():
    """Two equal vectors must have cosine similarity 1.0."""
    vec = [1.0, 2.0, 3.0]
    similarity = rerank.cosine(vec, list(vec))
    assert_close("cosine identical vectors", 1.0, similarity)
def test_cosine_orthogonal():
    """Perpendicular unit vectors must have cosine similarity 0.0."""
    x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
    similarity = rerank.cosine(x_axis, y_axis)
    assert_close("cosine orthogonal", 0.0, similarity)
def test_cosine_anti_parallel():
    """Vectors pointing in opposite directions must have cosine similarity -1.0."""
    forward, backward = [1.0, 0.0], [-1.0, 0.0]
    similarity = rerank.cosine(forward, backward)
    assert_close("cosine anti-parallel", -1.0, similarity)
def test_cosine_length_mismatch():
    """Vectors of different lengths are expected to score 0.0 rather than crash."""
    shorter, longer = [1.0], [1.0, 2.0]
    similarity = rerank.cosine(shorter, longer)
    assert_close("cosine length mismatch", 0.0, similarity)
def test_cosine_zero_vector():
    """An all-zero vector is expected to score 0.0 against any other vector."""
    zeros, other = [0.0, 0.0], [1.0, 2.0]
    similarity = rerank.cosine(zeros, other)
    assert_close("cosine zero vector", 0.0, similarity)
# ─── rerank.rerank() no-op fallback ──────────────────────────────────────────
def test_rerank_noop_when_ollama_unreachable():
    """With ollama reported unreachable, rerank() must pass candidates through.

    rerank.ollama_alive is patched to return (False, []). The output must
    keep the input order, tag every candidate with
    rerank_source='noop-no-ollama', and copy score into rerank_score.
    """
    pool = [
        {"chunk_id": "c-001:0", "score": 7.5, "path": "fake/p1.json"},
        {"chunk_id": "c-002:0", "score": 5.1, "path": "fake/p2.json"},
    ]
    offline = unittest.mock.patch.object(rerank, "ollama_alive", return_value=(False, []))
    with offline:
        ranked = rerank.rerank("query", pool, top_k=5)

    returned_ids = [item["chunk_id"] for item in ranked]
    assert_eq("rerank no-op preserves order", ["c-001:0", "c-002:0"], returned_ids)

    all_tagged = all(item.get("rerank_source") == "noop-no-ollama" for item in ranked)
    assert_true("rerank no-op tags source", all_tagged)

    scores_copied = all(item["rerank_score"] == item["score"] for item in ranked)
    assert_true("rerank no-op copies score to rerank_score", scores_copied)
def test_rerank_noop_when_model_missing():
    """With ollama up but only 'other-model' listed, rerank() must tag 'noop-no-model'."""
    pool = [{"chunk_id": "c-001:0", "score": 5.0, "path": "x"}]
    wrong_model = unittest.mock.patch.object(
        rerank, "ollama_alive", return_value=(True, ["other-model"]))
    with wrong_model:
        ranked = rerank.rerank("query", pool, top_k=5)
    source = ranked[0]["rerank_source"]
    assert_eq("rerank no-op for missing model", "noop-no-model", source)
def test_rerank_truncates_to_top_k():
    """Ten candidates with top_k=3 must yield exactly three results on the no-op path."""
    pool = []
    for i in range(10):
        pool.append({"chunk_id": f"c-{i:03}:0", "score": float(i), "path": "x"})
    offline = unittest.mock.patch.object(rerank, "ollama_alive", return_value=(False, []))
    with offline:
        ranked = rerank.rerank("query", pool, top_k=3)
    assert_eq("rerank truncates to top_k", 3, len(ranked))
# ─── retrieve.py CLI: exit 10 when not provisioned ────────────────────────────
def test_retrieve_exits_10_without_index():
    """CLI check: retrieve.py must exit 10 when the sandbox vault has no BM25 index.

    The scripts are copied into a temporary directory that has an empty
    .vault-meta/ directory, then retrieve.py is run from that copy.
    """
    import shutil

    with tempfile.TemporaryDirectory() as tmpdir:
        sandbox = Path(tmpdir)
        scripts_dir = sandbox / "scripts"
        scripts_dir.mkdir()
        (sandbox / ".vault-meta").mkdir()

        # Copy retrieve.py and the scripts it depends on, marked executable.
        for script_name in ("retrieve.py", "bm25-index.py", "rerank.py"):
            target = scripts_dir / script_name
            shutil.copy(ROOT / "scripts" / script_name, target)
            os.chmod(target, 0o755)

        # No index was built, so the run is expected to stop with code 10.
        command = [sys.executable, str(scripts_dir / "retrieve.py"), "test query"]
        result = subprocess.run(command, capture_output=True, text=True, timeout=10)

        assert_eq("retrieve.py exit 10 when not provisioned", 10, result.returncode)
        mentions_index = "no BM25 index" in result.stderr
        assert_true("retrieve.py prints friendly error", mentions_index,
                    hint=result.stderr[:200])
# ─── Integration smoke test: end-to-end with synthetic data ──────────────────
def test_end_to_end_with_synthetic_chunks():
    """Index two synthetic chunks in a sandbox vault, then query them via the CLI.

    Runs `bm25-index.py build` followed by `retrieve.py ... --no-rerank` and
    checks the exit codes, the reported strategy, and that the chunk
    mentioning "karpathy" and "wiki" is the top hit.
    """
    import hashlib
    import shutil

    fixtures = (
        ("c-000001", "compounding wiki vault pattern by karpathy"),
        ("c-000002", "obsidian cli transport detection"),
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        sandbox = Path(tmpdir)
        scripts_dir = sandbox / "scripts"
        scripts_dir.mkdir()
        chunks_dir = sandbox / ".vault-meta" / "chunks"
        bm25_dir = sandbox / ".vault-meta" / "bm25"
        chunks_dir.mkdir(parents=True)
        bm25_dir.mkdir(parents=True)

        for script_name in ("retrieve.py", "bm25-index.py", "rerank.py"):
            target = scripts_dir / script_name
            shutil.copy(ROOT / "scripts" / script_name, target)
            os.chmod(target, 0o755)

        # One chunk-000.json per page address, each in its own directory.
        for addr, text in fixtures:
            record = {
                "schema_version": 1,
                "page_path": f"wiki/fake/{addr}.md",
                "page_address": addr,
                "chunk_index": 0,
                "raw_text": text,
                "contextualized_text": text,
                "prefix": "",
                "prefix_source": "synthetic",
                "char_count": len(text),
                "body_hash": "sha256:" + hashlib.sha256(text.encode()).hexdigest(),
                "page_body_hash": "sha256:0",
                "created_at": "2026-05-17T00:00:00Z",
            }
            page_dir = chunks_dir / addr
            page_dir.mkdir()
            (page_dir / "chunk-000.json").write_text(json.dumps(record))

        # The scripts are run from their sandbox copies; the intent is that
        # they resolve the vault root relative to their own location, so the
        # sandbox's .vault-meta is the one being indexed and queried.
        build_cmd = [sys.executable, str(scripts_dir / "bm25-index.py"), "build"]
        result = subprocess.run(build_cmd, capture_output=True, text=True, timeout=10)
        assert_eq("bm25 build rc=0", 0, result.returncode)

        query_cmd = [
            sys.executable, str(scripts_dir / "retrieve.py"),
            "karpathy wiki", "--top", "2", "--no-rerank",
        ]
        result = subprocess.run(query_cmd, capture_output=True, text=True, timeout=10)
        assert_eq("retrieve rc=0", 0, result.returncode)

        out = json.loads(result.stdout)
        assert_eq("retrieve.strategy is bm25-only", "bm25-only", out["strategy"])
        hit_count = len(out["candidates"])
        assert_true("retrieve returns at least 1 candidate", hit_count >= 1)

        # c-000001 contains both query terms, so it should outrank c-000002.
        best = out["candidates"][0]
        assert_eq("top hit is c-000001", "c-000001", best["page_address"])
# ─── M8 closure: --explain and --no-rerank flag coverage ─────────────────────
def test_explain_flag_adds_diagnostics_block():
    """v1.7.2 / closes audit M8: --explain must add an 'explain' diagnostics block.

    Indexes one synthetic chunk in a sandbox vault, runs retrieve.py with
    --no-rerank --explain, and checks that the JSON output has an 'explain'
    key that mentions BM25.
    """
    import hashlib
    import shutil

    text = "hybrid retrieval pipeline"
    record = {
        "schema_version": 1,
        "page_path": "wiki/fake/c-000010.md",
        "page_address": "c-000010",
        "chunk_index": 0,
        "raw_text": text,
        "contextualized_text": text,
        "prefix": "",
        "prefix_source": "synthetic",
        "char_count": 25,
        "body_hash": "sha256:" + hashlib.sha256(text.encode()).hexdigest(),
        "page_body_hash": "sha256:0",
        "created_at": "2026-05-17T00:00:00Z",
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        sandbox = Path(tmpdir)
        scripts_dir = sandbox / "scripts"
        scripts_dir.mkdir()
        chunks_dir = sandbox / ".vault-meta" / "chunks"
        bm25_dir = sandbox / ".vault-meta" / "bm25"
        chunks_dir.mkdir(parents=True)
        bm25_dir.mkdir(parents=True)

        for script_name in ("retrieve.py", "bm25-index.py", "rerank.py"):
            target = scripts_dir / script_name
            shutil.copy(ROOT / "scripts" / script_name, target)
            os.chmod(target, 0o755)

        # A single synthetic chunk is enough for this flag check.
        page_dir = chunks_dir / "c-000010"
        page_dir.mkdir()
        (page_dir / "chunk-000.json").write_text(json.dumps(record))

        # check=True: a failed index build raises CalledProcessError here.
        build_cmd = [sys.executable, str(scripts_dir / "bm25-index.py"), "build"]
        subprocess.run(build_cmd, capture_output=True, timeout=10, check=True)

        query_cmd = [
            sys.executable, str(scripts_dir / "retrieve.py"),
            "hybrid", "--top", "1", "--no-rerank", "--explain",
        ]
        result = subprocess.run(query_cmd, capture_output=True, text=True, timeout=10)
        assert_eq("retrieve --explain --no-rerank rc=0", 0, result.returncode)

        out = json.loads(result.stdout)
        assert_true("--explain produces 'explain' key",
                    "explain" in out, hint=f"keys={list(out.keys())}")

        explain = out.get("explain", {})
        # Accept either a dedicated key or any mention of bm25 in the block.
        mentions_bm25 = "bm25_candidates" in explain or "bm25" in str(explain).lower()
        assert_true("--explain reports BM25 candidate count", mentions_bm25,
                    hint=f"explain={explain}")
def test_no_rerank_flag_strategy_bm25_only():
    """v1.7.2 / closes audit M8: --no-rerank must produce strategy='bm25-only'.

    Indexes one synthetic chunk in a sandbox vault, runs retrieve.py with
    --no-rerank, and checks the strategy plus the per-candidate rerank
    fields (rerank_source == 'skipped', rerank_score == bm25_score).
    """
    import hashlib
    import shutil

    text = "transport detection fallback chain"
    record = {
        "schema_version": 1,
        "page_path": "wiki/fake/c-000020.md",
        "page_address": "c-000020",
        "chunk_index": 0,
        "raw_text": text,
        "contextualized_text": text,
        "prefix": "",
        "prefix_source": "synthetic",
        # Derived from the text so the fixture stays self-consistent; a
        # hand-counted literal drifts as soon as the text is edited.
        "char_count": len(text),
        "body_hash": "sha256:" + hashlib.sha256(text.encode()).hexdigest(),
        "page_body_hash": "sha256:0",
        "created_at": "2026-05-17T00:00:00Z",
    }

    with tempfile.TemporaryDirectory() as tmpdir:
        sandbox = Path(tmpdir)
        scripts_dir = sandbox / "scripts"
        scripts_dir.mkdir()
        chunks_dir = sandbox / ".vault-meta" / "chunks"
        bm25_dir = sandbox / ".vault-meta" / "bm25"
        chunks_dir.mkdir(parents=True)
        bm25_dir.mkdir(parents=True)

        for script_name in ("retrieve.py", "bm25-index.py", "rerank.py"):
            target = scripts_dir / script_name
            shutil.copy(ROOT / "scripts" / script_name, target)
            os.chmod(target, 0o755)

        page_dir = chunks_dir / "c-000020"
        page_dir.mkdir()
        (page_dir / "chunk-000.json").write_text(json.dumps(record))

        # check=True: a failed index build raises CalledProcessError here.
        build_cmd = [sys.executable, str(scripts_dir / "bm25-index.py"), "build"]
        subprocess.run(build_cmd, capture_output=True, timeout=10, check=True)

        query_cmd = [
            sys.executable, str(scripts_dir / "retrieve.py"),
            "transport", "--top", "1", "--no-rerank",
        ]
        result = subprocess.run(query_cmd, capture_output=True, text=True, timeout=10)
        assert_eq("retrieve --no-rerank rc=0", 0, result.returncode)

        out = json.loads(result.stdout)
        assert_eq("--no-rerank sets strategy='bm25-only'", "bm25-only", out.get("strategy"))

        # --no-rerank is expected to keep a consistent candidate shape: the
        # rerank fields are still populated, with rerank_source "skipped",
        # so callers do not have to special-case this mode.
        candidates = out.get("candidates", [])
        assert_true("--no-rerank still returns candidates", len(candidates) >= 1)
        first = candidates[0]
        assert_eq("--no-rerank candidate rerank_source='skipped'", "skipped",
                  first.get("rerank_source"))
        assert_eq("--no-rerank candidate rerank_score equals bm25_score",
                  first.get("bm25_score"), first.get("rerank_score"))
def main():
    """Run every test in file order; the first failing check raises Fail and stops the run."""
    print("=== test_retrieve.py ===")
    suite = (
        test_import_sibling_resolves_hyphenated_names,
        test_chunk_snippet_short,
        test_chunk_snippet_truncates_with_ellipsis,
        test_cosine_identical,
        test_cosine_orthogonal,
        test_cosine_anti_parallel,
        test_cosine_length_mismatch,
        test_cosine_zero_vector,
        test_rerank_noop_when_ollama_unreachable,
        test_rerank_noop_when_model_missing,
        test_rerank_truncates_to_top_k,
        test_retrieve_exits_10_without_index,
        test_end_to_end_with_synthetic_chunks,
        test_explain_flag_adds_diagnostics_block,
        test_no_rerank_flag_strategy_bm25_only,
    )
    for case in suite:
        case()
    print("\nAll retrieve tests passed.")


if __name__ == "__main__":
    main()