"""PreToolUse hook that denies C/C++ source edits lacking a related test.

The hook reads a JSON payload from stdin, collects the file paths the tool
call is about to touch, and prints a ``deny`` decision when a touched source
file has no test whose file name mentions the source's base name or module.
In every other situation (including malformed input) it prints nothing and
exits 0 so that the tool call proceeds.
"""

import json
import subprocess
import sys
from pathlib import Path

SOURCE_SUFFIXES = {".h", ".hpp", ".hh", ".hxx", ".c", ".cc", ".cpp", ".cxx", ".ixx"}
TEST_SUFFIXES = {".h", ".hpp", ".hh", ".hxx", ".c", ".cc", ".cpp", ".cxx", ".ixx"}
CONFIG_SUFFIXES = {".json", ".md", ".yml", ".yaml", ".txt", ".cmake"}


def _repo_root(cwd: Path) -> Path:
    """Return the git top-level directory containing *cwd*.

    Falls back to *cwd* itself when git is unavailable, *cwd* is not inside a
    work tree, or *cwd* cannot be used as a working directory.
    """
    try:
        root = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=cwd,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing git binary (FileNotFoundError) as well as a
        # cwd that is a file or is not accessible.
        return cwd
    return Path(root) if root else cwd


def _extract_patch_paths(command: str) -> list[str]:
    """Return the file paths named by ``*** ... File:`` / ``*** Move to:`` lines.

    Each line of *command* is stripped before matching, so indented headers
    are recognised too. Paths are returned in order of appearance.
    """
    prefixes = (
        "*** Add File: ",
        "*** Update File: ",
        "*** Delete File: ",
        "*** Move to: ",
    )
    paths: list[str] = []
    for raw_line in command.splitlines():
        line = raw_line.strip()
        for prefix in prefixes:
            if line.startswith(prefix):
                paths.append(line[len(prefix):].strip())
                break
    return paths


def _touched_paths(payload: dict) -> list[str]:
    """Return the paths the tool call in *payload* is about to modify.

    ``tool_input.file_path`` wins when it is a non-empty string; otherwise the
    paths are parsed out of ``tool_input.command``. Anything that does not
    have the expected shape yields an empty list.
    """
    if not isinstance(payload, dict):
        return []
    tool_input = payload.get("tool_input", {})
    if not isinstance(tool_input, dict):
        return []

    file_path = tool_input.get("file_path")
    if isinstance(file_path, str) and file_path:
        return [file_path]

    command = tool_input.get("command")
    if isinstance(command, str):
        return _extract_patch_paths(command)

    return []


def _normalize(path_text: str) -> str:
    """Return *path_text* lower-cased with forward slashes only."""
    return path_text.replace("\\", "/").lower()


def _is_test_path(path_text: str) -> bool:
    """Return True when *path_text* looks like a C/C++ test file.

    The file must carry one of ``TEST_SUFFIXES`` and either live under a
    ``test``/``tests`` directory or follow a test naming pattern
    (``test_*``, ``*_test.<ext>``, ``*.test.*``, ``*.spec.*``).
    """
    if Path(path_text).suffix.lower() not in TEST_SUFFIXES:
        return False

    normalized = _normalize(path_text)
    # Leading slash lets a relative "tests/x.cpp" match the "/tests/" probe.
    anchored = f"/{normalized}"
    name = normalized.rsplit("/", 1)[-1]
    stem = name.rsplit(".", 1)[0]
    return (
        "/tests/" in anchored
        or "/test/" in anchored
        or stem.endswith("_test")
        or name.startswith("test_")
        or ".test." in name
        or ".spec." in name
    )


def _token(text: str) -> str:
    """Return *text* lower-cased with every non-alphanumeric character removed."""
    return "".join(ch for ch in text.lower() if ch.isalnum())


def _module_token(path: Path) -> str:
    """Return the token of the module directory *path* belongs to, or ``""``.

    Recognised layouts are ``include/fesa/<module>/...`` and
    ``src/<module>/...``; the first occurrence of the marker directory in the
    path is the one that is inspected.
    """
    parts = [part.lower() for part in path.parts]
    for marker in ("include", "src"):
        if marker not in parts:
            continue
        idx = parts.index(marker)
        if marker == "include" and idx + 2 < len(parts) and parts[idx + 1] == "fesa":
            return _token(parts[idx + 2])
        if marker == "src" and idx + 1 < len(parts):
            return _token(parts[idx + 1])
    return ""


def _related_tokens(path: Path) -> set[str]:
    """Return the non-empty tokens (base name, module) a test name may mention."""
    tokens = {_token(_base_name(path))}
    module = _module_token(path)
    if module:
        tokens.add(module)
    return {token for token in tokens if token}


def _candidate_test_paths(paths: list[str], cwd: Path, root: Path) -> list[Path]:
    """Return every test file that could cover the touched sources.

    Candidates are the touched paths that are themselves tests (so a test
    added in the same patch counts) plus all files with a test suffix found
    under ``<root>/tests`` and ``<root>/test``.
    """
    candidates: list[Path] = []
    for path_text in paths:
        if _is_test_path(_repo_relative_text(path_text, cwd, root)):
            candidates.append(_resolve_path(path_text, cwd))

    for test_root_name in ("tests", "test"):
        test_root = root / test_root_name
        if not test_root.is_dir():
            continue
        # One walk of the tree instead of one walk per suffix.
        candidates.extend(
            found
            for found in test_root.rglob("*")
            if found.suffix.lower() in TEST_SUFFIXES
        )
    return candidates


def _has_related_test(path: Path, candidate_tests: list[Path]) -> bool:
    """Return True when any candidate's stem contains one of *path*'s tokens."""
    tokens = _related_tokens(path)
    for test_path in candidate_tests:
        test_token = _token(test_path.stem)
        if any(token and token in test_token for token in tokens):
            return True
    return False


def _is_exempt(path_text: str) -> bool:
    """Return True when *path_text* never requires a test.

    Exempt are ``CMakeLists.txt``, test files, files with a suffix in
    ``CONFIG_SUFFIXES``, and anything under a ``cmake`` directory.
    """
    normalized = _normalize(path_text)
    path = Path(path_text)
    name = path.name.lower()

    if name == "cmakelists.txt":
        return True
    if _is_test_path(path_text):
        return True
    if path.suffix.lower() in CONFIG_SUFFIXES:
        return True
    if "/cmake/" in normalized:
        return True
    return False


def _resolve_path(path_text: str, cwd: Path) -> Path:
    """Return *path_text* as an absolute path, resolving relative ones against *cwd*."""
    path = Path(path_text)
    if path.is_absolute():
        return path
    return (cwd / path).resolve()


def _repo_relative_text(path_text: str, cwd: Path, root: Path) -> str:
    """Return *path_text* expressed relative to *root* with forward slashes.

    Classifying the repo-relative form keeps directories *above* the
    repository (for example a checkout under ``/home/me/test/``) from being
    mistaken for test or module directories. When the path does not lie under
    *root*, *path_text* is returned unchanged.
    """
    resolved = _resolve_path(path_text, cwd)
    # root may be spelled through a symlink while resolved is not (or vice
    # versa), so try both spellings of the root.
    for base in (root, root.resolve()):
        try:
            return resolved.relative_to(base).as_posix()
        except ValueError:
            continue
    return path_text


def _base_name(path: Path) -> str:
    """Return the file name of *path* without its source suffix."""
    for suffix in sorted(SOURCE_SUFFIXES, key=len, reverse=True):
        if path.name.lower().endswith(suffix):
            return path.name[: -len(suffix)]
    return path.stem


def _guarded_paths(paths: list[str], cwd: Path, root: Path) -> list[str]:
    """Return the base names of touched source files that have no related test.

    Exempt paths and paths without a suffix in ``SOURCE_SUFFIXES`` are
    skipped. The result keeps the order of *paths* and may contain duplicates.
    """
    missing_tests: list[str] = []
    candidate_tests = _candidate_test_paths(paths, cwd, root)

    for path_text in paths:
        relative_text = _repo_relative_text(path_text, cwd, root)
        if _is_exempt(relative_text):
            continue

        relative_path = Path(relative_text)
        if relative_path.suffix.lower() not in SOURCE_SUFFIXES:
            continue
        if not _has_related_test(relative_path, candidate_tests):
            missing_tests.append(_base_name(relative_path))

    return missing_tests


def main() -> int:
    """Run the hook: read the payload from stdin and print a deny decision if needed.

    Always returns 0. Unparseable input or a payload that is not a JSON object
    is treated as "nothing to guard" so the hook never blocks by accident.
    """
    try:
        payload = json.load(sys.stdin)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return 0
    if not isinstance(payload, dict):
        # Valid JSON, but not the object shape the hook protocol uses.
        return 0

    raw_cwd = payload.get("cwd")
    cwd = Path(raw_cwd) if isinstance(raw_cwd, str) and raw_cwd else Path.cwd()
    root = _repo_root(cwd)
    missing_tests = _guarded_paths(_touched_paths(payload), cwd, root)
    if not missing_tests:
        return 0

    names = ", ".join(sorted(set(missing_tests)))
    print(
        json.dumps(
            {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": (
                        "TDD GUARD: missing test file for "
                        f"{names}. Write or add the test first."
                    ),
                }
            }
        )
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())