206 lines
5.6 KiB
Python
206 lines
5.6 KiB
Python
import json
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
# Extensions of C/C++ files (headers, then implementation/module units) whose
# edits require a related test.
SOURCE_SUFFIXES = {".h", ".hpp", ".hh", ".hxx"} | {".c", ".cc", ".cpp", ".cxx", ".ixx"}

# Extensions a file must have to be recognized as a test; same members as
# SOURCE_SUFFIXES, held as an independent set object.
TEST_SUFFIXES = set(SOURCE_SUFFIXES)

# Extensions of configuration/documentation files that never need a test.
CONFIG_SUFFIXES = {".json", ".md", ".yml", ".yaml", ".txt", ".cmake"}
|
|
|
|
|
|
def _repo_root(cwd: Path) -> Path:
    """Return the top-level directory of the git work tree containing *cwd*.

    Falls back to *cwd* itself when git is unavailable, when *cwd* is not inside
    a work tree, when *cwd* cannot be used as a working directory (missing, not
    a directory, not accessible), or when git prints no path.
    """
    try:
        root = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=cwd,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git binary (FileNotFoundError) as well as a
        # cwd that is absent, not a directory, or unreadable; the hook must not
        # crash on any of these.
        return cwd
    # An empty answer would otherwise become Path("."), i.e. the hook process's
    # own working directory rather than the payload's.
    return Path(root) if root else cwd
|
|
|
|
|
|
def _extract_patch_paths(command: str) -> list[str]:
    """Return every file path named in a patch envelope's file headers.

    A line (after trimming surrounding whitespace) counts when it begins with
    one of ``*** Add File: ``, ``*** Update File: ``, ``*** Delete File: `` or
    ``*** Move to: ``; the rest of that line, trimmed, is taken as the path.
    Paths are returned in the order they appear, duplicates included.
    """
    headers = (
        "*** Add File: ",
        "*** Update File: ",
        "*** Delete File: ",
        "*** Move to: ",
    )
    found: list[str] = []
    for line in (raw.strip() for raw in command.splitlines()):
        header = next((h for h in headers if line.startswith(h)), None)
        if header is not None:
            found.append(line[len(header):].strip())
    return found
|
|
|
|
|
|
def _touched_paths(payload: dict) -> list[str]:
    """List the file paths a tool invocation is about to modify.

    Reads ``payload["tool_input"]``. A non-empty string ``file_path`` wins;
    otherwise a string ``command`` is scanned for patch file headers. A missing
    or non-dict ``tool_input``, or no usable field, yields an empty list.
    """
    tool_input = payload.get("tool_input", {})
    if isinstance(tool_input, dict):
        target = tool_input.get("file_path")
        if isinstance(target, str) and target:
            return [target]
        patch_text = tool_input.get("command")
        if isinstance(patch_text, str):
            return _extract_patch_paths(patch_text)
    return []
|
|
|
|
|
|
def _normalize(path_text: str) -> str:
    """Return *path_text* lower-cased, with backslashes turned into "/"."""
    lowered = path_text.lower()
    return lowered.replace("\\", "/")
|
|
|
|
|
|
def _is_test_path(path_text: str) -> bool:
    """Return True if *path_text* looks like a test source file.

    The file's extension must be in TEST_SUFFIXES, and additionally either:
    - some directory component is ``tests`` or ``test``, or
    - the file name follows a test naming convention: ``test_*``,
      ``*_test.<ext>`` (any allowed extension), ``*.test.*`` or ``*.spec.*``.
    Matching is case-insensitive and accepts "/" or "\\" separators.
    """
    if Path(path_text).suffix.lower() not in TEST_SUFFIXES:
        return False
    normalized = _normalize(path_text)
    name = normalized.rsplit("/", 1)[-1]
    # File name without its final extension, so "foo_test.cc" / "foo_test.h"
    # are recognized just like "foo_test.cpp".
    stem = name.rpartition(".")[0]
    # Leading "/" lets a path that starts with "tests/" match the directory check.
    anchored = f"/{normalized}"
    return (
        "/tests/" in anchored
        or "/test/" in anchored
        or stem.endswith("_test")
        or name.startswith("test_")
        or ".test." in name
        or ".spec." in name
    )
|
|
|
|
|
|
def _token(text: str) -> str:
    """Return *text* lower-cased with every non-alphanumeric character dropped."""
    return "".join(filter(str.isalnum, text.lower()))
|
|
|
|
|
|
def _module_token(path: Path) -> str:
    """Return a normalized module name derived from *path*, or "" if none.

    Recognized layouts are ``.../include/fesa/<module>/...`` (checked first)
    and ``.../src/<module>/...``; the result is ``_token(<module>)``.

    Markers are searched from the right, so directories above the project
    (e.g. a checkout living under ``~/src/``) do not shadow the project's own
    layout. The module component must be a directory: a file sitting directly
    in ``src/`` or ``include/fesa/`` has no module.
    """
    parts = [part.lower() for part in path.parts]
    # Index of the file name; a module directory must come strictly before it.
    file_index = len(parts) - 1
    for marker in ("include", "src"):
        for idx in range(file_index - 1, -1, -1):
            if parts[idx] != marker:
                continue
            if marker == "include":
                if idx + 2 < file_index and parts[idx + 1] == "fesa":
                    return _token(parts[idx + 2])
            elif idx + 1 < file_index:
                return _token(parts[idx + 1])
    return ""
|
|
|
|
|
|
def _related_tokens(path: Path) -> set[str]:
    """Return the non-empty tokens a related test's name may contain.

    These are the token of *path*'s base name and, when one can be derived,
    its module token.
    """
    candidates = (_token(_base_name(path)), _module_token(path))
    return {tok for tok in candidates if tok}
|
|
|
|
|
|
def _candidate_test_paths(paths: list[str], cwd: Path, root: Path) -> list[Path]:
    """Collect files that may count as tests for the paths being edited.

    Args:
        paths: File paths touched by the tool call, as given in the payload.
        cwd: Directory used to resolve relative entries of *paths*.
        root: Repository root whose ``tests``/``test`` directories are scanned.

    Returns:
        Every touched path that itself looks like a test (so a test written in
        the same operation counts), followed by every file under
        ``<root>/tests`` and ``<root>/test`` whose extension, compared
        case-insensitively, is in TEST_SUFFIXES.
    """
    candidates: list[Path] = [
        resolved
        for resolved in (_resolve_path(text, cwd) for text in paths)
        if _is_test_path(str(resolved))
    ]
    for test_root_name in ("tests", "test"):
        test_root = root / test_root_name
        if not test_root.is_dir():
            continue
        # Walk each tree once (rather than once per extension). Compare the
        # lower-cased suffix so "Foo.CPP" is found on case-sensitive file
        # systems, and skip directories whose names merely end in an extension.
        candidates.extend(
            entry
            for entry in test_root.rglob("*")
            if entry.suffix.lower() in TEST_SUFFIXES and entry.is_file()
        )
    return candidates
|
|
|
|
|
|
def _has_related_test(path: Path, candidate_tests: list[Path]) -> bool:
    """Return True if any candidate test's name mentions a token of *path*.

    Each related token of *path* (base name and, when present, module name) is
    checked as a plain substring of each candidate's lower-cased, alphanumeric-only
    file stem. Candidates are examined lazily and the search stops at the first hit.
    """
    wanted = _related_tokens(path)
    test_keys = (_token(candidate.stem) for candidate in candidate_tests)
    return any(any(tok and tok in key for tok in wanted) for key in test_keys)
|
|
|
|
|
|
def _is_exempt(path_text: str) -> bool:
    """Return True if edits to *path_text* never require a matching test.

    Exempt are ``CMakeLists.txt`` (any case), test sources themselves, files
    whose extension is in CONFIG_SUFFIXES, and anything inside a directory
    named ``cmake``.
    """
    path = Path(path_text)
    if path.name.lower() == "cmakelists.txt":
        return True
    if _is_test_path(path_text):
        return True
    if path.suffix.lower() in CONFIG_SUFFIXES:
        return True
    # Anchor with a leading "/" so a relative "cmake/..." path is caught the
    # same way a nested ".../cmake/..." path is.
    return "/cmake/" in f"/{_normalize(path_text)}"
|
|
|
|
|
|
def _resolve_path(path_text: str, cwd: Path) -> Path:
    """Turn *path_text* into an absolute Path.

    Absolute inputs are returned exactly as given (no normalization). Relative
    inputs are joined onto *cwd* and passed through ``Path.resolve()``.
    """
    candidate = Path(path_text)
    return candidate if candidate.is_absolute() else (cwd / candidate).resolve()
|
|
|
|
|
|
def _base_name(path: Path) -> str:
    """Return *path*'s file name without its source-file extension.

    The extension is matched case-insensitively against SOURCE_SUFFIXES
    (longest first), and the name's original casing is kept. When no source
    extension matches, ``path.stem`` is returned instead.
    """
    lowered_name = path.name.lower()
    longest_first = sorted(SOURCE_SUFFIXES, key=len, reverse=True)
    matched = next((ext for ext in longest_first if lowered_name.endswith(ext)), None)
    if matched is None:
        return path.stem
    return path.name[: -len(matched)]
|
|
|
|
|
|
def _guarded_paths(paths: list[str], cwd: Path, root: Path) -> list[str]:
    """Return base names of touched source files that lack a related test.

    Args:
        paths: File paths touched by the tool call, as given in the payload.
        cwd: Directory used to resolve relative paths.
        root: Repository root used to locate ``tests``/``test`` directories.

    Returns:
        One base name per non-exempt touched file that has an extension in
        SOURCE_SUFFIXES and no related test, in input order. Duplicates are
        possible.
    """
    known_tests = _candidate_test_paths(paths, cwd, root)
    uncovered: list[str] = []
    for raw_path in paths:
        # NOTE(review): raw_path may be absolute, in which case the directory
        # checks in _is_exempt also see ancestors of the repository root
        # (e.g. a checkout under ".../tests/"). Confirm this is acceptable.
        if _is_exempt(raw_path):
            continue
        resolved = _resolve_path(raw_path, cwd)
        needs_test = resolved.suffix.lower() in SOURCE_SUFFIXES
        if needs_test and not _has_related_test(resolved, known_tests):
            uncovered.append(_base_name(resolved))
    return uncovered
|
|
|
|
|
|
def main() -> int:
    """Run the TDD guard as a PreToolUse hook.

    Reads the hook payload as JSON from stdin. If any touched source file has
    no related test, prints a ``hookSpecificOutput`` object with
    ``permissionDecision: "deny"`` and a reason naming the files. Unreadable or
    non-object payloads are ignored: nothing is printed.

    Returns:
        Always 0 (used as the process exit status).
    """
    try:
        payload = json.load(sys.stdin)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses); an unreadable payload must not crash the hook.
        return 0
    if not isinstance(payload, dict):
        # A JSON array/string/number has no fields to inspect.
        return 0

    cwd_value = payload.get("cwd")
    # Only a non-empty string is usable as a directory; otherwise use ours.
    cwd = Path(cwd_value) if isinstance(cwd_value, str) and cwd_value else Path.cwd()
    root = _repo_root(cwd)
    missing_tests = _guarded_paths(_touched_paths(payload), cwd, root)
    if not missing_tests:
        return 0

    names = ", ".join(sorted(set(missing_tests)))
    print(
        json.dumps(
            {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": (
                        "TDD GUARD: missing test file for "
                        f"{names}. Write or add the test first."
                    ),
                }
            }
        )
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    sys.exit(main())
|