Files
AbaqusSubroutineDev/.codex/hooks/tdd-guard.py
T
2026-06-09 12:27:22 +09:00

238 lines
5.9 KiB
Python

import json
import subprocess
import sys
from pathlib import Path
# Extensions of C/C++ and Fortran source files; a touched file with one of
# these suffixes must have a related test (see _guarded_paths).
SOURCE_SUFFIXES = {
    # C / C++ headers
    ".h", ".hh", ".hpp", ".hxx",
    # C / C++ sources
    ".c", ".cc", ".cpp", ".cxx", ".ixx",
    # Fortran
    ".f", ".for", ".f90", ".f95", ".f03", ".f08",
}

# Tests may be written in any source language, or as Python scripts.
TEST_SUFFIXES = SOURCE_SUFFIXES | {".py"}

# Documentation, build configuration, data and Abaqus job files; edits to
# files with these suffixes are always exempt from the guard (see _is_exempt).
CONFIG_SUFFIXES = {
    # docs and structured configuration
    ".json", ".md", ".txt", ".yml", ".yaml", ".cmake",
    # input decks, tabular data and solver output / log files
    ".inp", ".csv", ".dat", ".log", ".msg", ".sta", ".odb",
}
def _repo_root(cwd: Path) -> Path:
    """Return the top-level directory of the git work tree containing *cwd*.

    Falls back to *cwd* itself when the lookup fails: git is not installed,
    *cwd* is not inside a repository, or *cwd* cannot be used as a working
    directory (missing, not a directory, not accessible).
    """
    try:
        root = subprocess.check_output(
            ["git", "rev-parse", "--show-toplevel"],
            cwd=cwd,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (no git binary / missing cwd) and
        # also NotADirectoryError / PermissionError for an unusable cwd. The
        # hook is best-effort, so fall back instead of crashing.
        return cwd
    # An empty answer would otherwise silently become Path(".").
    return Path(root) if root else cwd
def _extract_patch_paths(command: str) -> list[str]:
    """Collect the file paths named in an apply_patch-style patch text.

    Each line is whitespace-stripped and compared against the patch header
    markers (add, update, delete, and the move target). For a line that
    starts with one of them, the text after the first matching marker is
    stripped and kept. Paths are returned in the order they appear, and lines
    without a marker are ignored.
    """
    markers = (
        "*** Add File: ",
        "*** Update File: ",
        "*** Delete File: ",
        "*** Move to: ",
    )
    found: list[str] = []
    for line in map(str.strip, command.splitlines()):
        marker = next((m for m in markers if line.startswith(m)), None)
        if marker is not None:
            found.append(line.removeprefix(marker).strip())
    return found
def _touched_paths(payload: dict) -> list[str]:
    """Return the file paths the pending tool call would modify.

    Reads ``payload["tool_input"]``. A non-empty string ``file_path`` wins and
    is returned alone. Otherwise a string ``command`` is parsed as patch text.
    Anything else, including a missing or non-dict ``tool_input``, yields an
    empty list.
    """
    tool_input = payload.get("tool_input", {})
    if isinstance(tool_input, dict):
        target = tool_input.get("file_path")
        if target and isinstance(target, str):
            return [target]
        patch_text = tool_input.get("command")
        if isinstance(patch_text, str):
            return _extract_patch_paths(patch_text)
    return []
def _normalize(path_text: str) -> str:
    """Return *path_text* lower-cased, with every backslash turned into '/'."""
    forward_slashed = path_text.translate({ord("\\"): "/"})
    return forward_slashed.lower()
def _is_test_path(path_text: str) -> bool:
    """Return True when *path_text* names a test file.

    The file must carry a suffix from TEST_SUFFIXES. In addition, it must
    either live under a ``tests/`` or ``test/`` directory or have a test-style
    name: ``test_*``, ``*_test.<ext>``, ``*.test.*`` or ``*.spec.*``. Matching
    is case-insensitive, and backslash separators are treated like '/'.

    NOTE(review): the directory check inspects every component of the path,
    so an absolute path whose ancestors include a ``test``/``tests``
    directory (e.g. a checkout under ``~/test/``) also matches.
    """
    if Path(path_text).suffix.lower() not in TEST_SUFFIXES:
        return False
    normalized = _normalize(path_text)
    rooted = f"/{normalized}"
    if "/tests/" in rooted or "/test/" in rooted:
        return True
    name = normalized.rsplit("/", 1)[-1]
    # ``*_test.<ext>`` used to be recognised only for .cpp. The other name
    # patterns apply to every test suffix, and GoogleTest-style
    # ``foo_test.cc`` / pytest-style ``foo_test.py`` are common, so treat it
    # the same way.
    return (
        Path(name).stem.endswith("_test")
        or name.startswith("test_")
        or ".test." in name
        or ".spec." in name
    )
def _token(text: str) -> str:
    """Lower-case *text* and keep only its alphanumeric characters."""
    return "".join(filter(str.isalnum, text.lower()))
def _module_token(path: Path) -> str:
    """Return a token for the module directory of *path*, or "" if none.

    Two layouts are recognised, checked in this order (case-insensitively):

    * ``include/fesa/<module>/...`` -> token of ``<module>``
    * ``src/<module>/...``          -> token of ``<module>``

    NOTE(review): only the first ``include`` / ``src`` component is used, so an
    ancestor directory with that name (e.g. a checkout under ``~/src/``) wins
    over the repository's own. The component after the marker may also be the
    file name itself (``src/foo.cpp`` yields ``"foocpp"``).
    """
    lowered = [segment.lower() for segment in path.parts]
    if "include" in lowered:
        pos = lowered.index("include")
        if pos + 2 < len(lowered) and lowered[pos + 1] == "fesa":
            return _token(lowered[pos + 2])
    if "src" in lowered:
        pos = lowered.index("src")
        if pos + 1 < len(lowered):
            return _token(lowered[pos + 1])
    return ""
def _related_tokens(path: Path) -> set[str]:
    """Return the non-empty name tokens a related test is expected to mention.

    These are the token of the file's base name (source suffix removed) and,
    when one is found, the token of its module directory.
    """
    base_token = _token(_base_name(path))
    module_token = _module_token(path)
    return {tok for tok in (base_token, module_token) if tok}
def _candidate_test_paths(paths: list[str], cwd: Path, root: Path) -> list[Path]:
    """Return every path that may count as a test for the pending edit.

    This is the union of:

    * the touched paths that look like tests themselves, so a test added in
      the same patch already satisfies the guard; and
    * every file under ``<root>/tests`` or ``<root>/test`` whose suffix is in
      TEST_SUFFIXES.

    Suffixes are matched case-insensitively, like the rest of the hook, so
    e.g. ``test_umat.F90`` is found on case-sensitive file systems too.
    """
    candidates: list[Path] = []
    for path_text in paths:
        resolved = _resolve_path(path_text, cwd)
        if _is_test_path(str(resolved)):
            candidates.append(resolved)
    for test_root_name in ("tests", "test"):
        test_root = root / test_root_name
        if not test_root.is_dir():
            continue
        # A single traversal with a lower-cased suffix check replaces one
        # case-sensitive rglob per suffix, which walked the tree 16 times and
        # missed upper-case extensions. Directories are not tests.
        candidates.extend(
            entry
            for entry in test_root.rglob("*")
            if entry.suffix.lower() in TEST_SUFFIXES and entry.is_file()
        )
    return candidates
def _has_related_test(path: Path, candidate_tests: list[Path]) -> bool:
    """Return True when some candidate test's name mentions *path*.

    A test is related when the alphanumeric token of its stem contains, as a
    substring, one of the tokens produced by _related_tokens for *path*.
    """
    wanted = _related_tokens(path)
    stems = (_token(candidate.stem) for candidate in candidate_tests)
    return any(
        any(needle and needle in stem for needle in wanted) for stem in stems
    )
def _is_exempt(path_text: str) -> bool:
    """Return True when edits to *path_text* never require a related test.

    Exempt are: anything under a ``references/`` or ``cmake/`` directory,
    ``CMakeLists.txt``, test files themselves, and files whose suffix is in
    CONFIG_SUFFIXES. All comparisons are case-insensitive.
    """
    normalized = _normalize(path_text)
    # Prefix with "/" so that a directory at the very start of a relative path
    # ("cmake/probe.cpp") matches the same way as a nested one. The cmake
    # check previously lacked this prefix, unlike the references/tests checks.
    rooted = f"/{normalized}"
    if "/references/" in rooted or "/cmake/" in rooted:
        return True
    path = Path(path_text)
    if path.name.lower() == "cmakelists.txt":
        return True
    if _is_test_path(path_text):
        return True
    return path.suffix.lower() in CONFIG_SUFFIXES
def _resolve_path(path_text: str, cwd: Path) -> Path:
    """Turn *path_text* into an absolute path.

    Absolute input is returned unchanged (not normalised). Relative input is
    joined onto *cwd* and then resolved.
    """
    candidate = Path(path_text)
    return candidate if candidate.is_absolute() else (cwd / candidate).resolve()
def _base_name(path: Path) -> str:
    """Return the file name of *path* with its source suffix removed.

    A suffix from SOURCE_SUFFIXES is matched case-insensitively and stripped,
    and the remaining name keeps its original casing. Names without such a
    suffix fall back to ``path.stem``.
    """
    file_name = path.name
    lowered = file_name.lower()
    # Try longer extensions before shorter ones.
    by_length = sorted(SOURCE_SUFFIXES, key=len, reverse=True)
    matched = next((sfx for sfx in by_length if lowered.endswith(sfx)), None)
    if matched is None:
        return path.stem
    return file_name[: -len(matched)]
def _guarded_paths(paths: list[str], cwd: Path, root: Path) -> list[str]:
    """Return base names of touched source files that lack a related test.

    Exempt paths (per _is_exempt) and files whose suffix is not in
    SOURCE_SUFFIXES are skipped. A name may appear more than once if several
    touched files share it.
    """
    known_tests = _candidate_test_paths(paths, cwd, root)
    untested: list[str] = []
    for raw_path in paths:
        if _is_exempt(raw_path):
            continue
        resolved = _resolve_path(raw_path, cwd)
        is_source = resolved.suffix.lower() in SOURCE_SUFFIXES
        if is_source and not _has_related_test(resolved, known_tests):
            untested.append(_base_name(resolved))
    return untested
def main() -> int:
    """Run the TDD guard as a PreToolUse hook.

    Reads the hook payload (expected to be a JSON object) from stdin and
    collects the files the tool call would touch. When a guarded source file
    has no related test, a JSON ``deny`` decision is printed to stdout.
    Always returns 0. Malformed or unexpected input produces no output, so the
    tool call is not blocked.
    """
    try:
        payload = json.load(sys.stdin)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError from stdin.
        return 0
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (list, string, null, ...) carries
        # no tool input; ``payload.get`` would raise AttributeError here.
        return 0
    cwd_text = payload.get("cwd")
    # A missing, empty or non-string "cwd" falls back to the process cwd.
    cwd = Path(cwd_text) if isinstance(cwd_text, str) and cwd_text else Path.cwd()
    root = _repo_root(cwd)
    missing_tests = _guarded_paths(_touched_paths(payload), cwd, root)
    if not missing_tests:
        return 0
    names = ", ".join(sorted(set(missing_tests)))
    print(
        json.dumps(
            {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": (
                        "TDD GUARD: missing test file for "
                        f"{names}. Write or add the test first."
                    ),
                }
            }
        )
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())