modify coding template
This commit is contained in:
@@ -0,0 +1,189 @@
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SOURCE_SUFFIXES = {".ts", ".tsx", ".js", ".jsx"}
|
||||
TEST_SUFFIXES = ("ts", "tsx", "js", "jsx")
|
||||
CONFIG_SUFFIXES = {".json", ".css", ".scss", ".md", ".yml", ".yaml"}
|
||||
NEXT_SPECIAL_FILES = {
|
||||
"layout.ts",
|
||||
"layout.tsx",
|
||||
"page.ts",
|
||||
"page.tsx",
|
||||
"loading.tsx",
|
||||
"error.tsx",
|
||||
"not-found.tsx",
|
||||
"globals.css",
|
||||
}
|
||||
|
||||
|
||||
def _repo_root(cwd: Path) -> Path:
|
||||
try:
|
||||
root = subprocess.check_output(
|
||||
["git", "rev-parse", "--show-toplevel"],
|
||||
cwd=cwd,
|
||||
text=True,
|
||||
stderr=subprocess.DEVNULL,
|
||||
).strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return cwd
|
||||
return Path(root)
|
||||
|
||||
|
||||
def _extract_patch_paths(command: str) -> list[str]:
|
||||
prefixes = (
|
||||
"*** Add File: ",
|
||||
"*** Update File: ",
|
||||
"*** Delete File: ",
|
||||
"*** Move to: ",
|
||||
)
|
||||
paths: list[str] = []
|
||||
for raw_line in command.splitlines():
|
||||
line = raw_line.strip()
|
||||
for prefix in prefixes:
|
||||
if line.startswith(prefix):
|
||||
paths.append(line[len(prefix) :].strip())
|
||||
break
|
||||
return paths
|
||||
|
||||
|
||||
def _touched_paths(payload: dict) -> list[str]:
|
||||
tool_input = payload.get("tool_input", {})
|
||||
if not isinstance(tool_input, dict):
|
||||
return []
|
||||
|
||||
file_path = tool_input.get("file_path")
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [file_path]
|
||||
|
||||
command = tool_input.get("command")
|
||||
if isinstance(command, str):
|
||||
return _extract_patch_paths(command)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _normalize(path_text: str) -> str:
|
||||
return path_text.replace("\\", "/").lower()
|
||||
|
||||
|
||||
def _is_test_path(path_text: str) -> bool:
|
||||
normalized = _normalize(path_text)
|
||||
name = normalized.rsplit("/", 1)[-1]
|
||||
return (
|
||||
"__tests__/" in normalized
|
||||
or ".test." in name
|
||||
or ".spec." in name
|
||||
or "test" in name
|
||||
or "spec" in name
|
||||
)
|
||||
|
||||
|
||||
def _is_exempt(path_text: str) -> bool:
|
||||
normalized = _normalize(path_text)
|
||||
path = Path(path_text)
|
||||
name = path.name.lower()
|
||||
|
||||
if _is_test_path(path_text):
|
||||
return True
|
||||
if name in NEXT_SPECIAL_FILES:
|
||||
return True
|
||||
if path.suffix.lower() in CONFIG_SUFFIXES:
|
||||
return True
|
||||
if ".env" in name or ".config." in name:
|
||||
return True
|
||||
if any(token in name for token in ("tailwind", "postcss", "next.config", "tsconfig")):
|
||||
return True
|
||||
if "/types/" in normalized or name in {"types.ts", "types.d.ts"}:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_path(path_text: str, cwd: Path) -> Path:
|
||||
path = Path(path_text)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return (cwd / path).resolve()
|
||||
|
||||
|
||||
def _base_name(path: Path) -> str:
|
||||
for suffix in (".tsx", ".ts", ".jsx", ".js"):
|
||||
if path.name.endswith(suffix):
|
||||
return path.name[: -len(suffix)]
|
||||
return path.stem
|
||||
|
||||
|
||||
def _has_existing_test(path: Path, root: Path) -> bool:
|
||||
directory = path.parent
|
||||
parent = directory.parent
|
||||
base = _base_name(path)
|
||||
|
||||
for ext in TEST_SUFFIXES:
|
||||
if (directory / f"{base}.test.{ext}").exists():
|
||||
return True
|
||||
if (directory / f"{base}.spec.{ext}").exists():
|
||||
return True
|
||||
|
||||
for ext in TEST_SUFFIXES:
|
||||
if (parent / "__tests__" / f"{base}.test.{ext}").exists():
|
||||
return True
|
||||
if (directory / "__tests__" / f"{base}.test.{ext}").exists():
|
||||
return True
|
||||
|
||||
for ext in TEST_SUFFIXES:
|
||||
if (root / "src" / "__tests__" / f"{base}.test.{ext}").exists():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _guarded_paths(paths: list[str], cwd: Path, root: Path) -> list[str]:
|
||||
missing_tests: list[str] = []
|
||||
for path_text in paths:
|
||||
if _is_exempt(path_text):
|
||||
continue
|
||||
|
||||
path = _resolve_path(path_text, cwd)
|
||||
if path.suffix.lower() not in SOURCE_SUFFIXES:
|
||||
continue
|
||||
if not _has_existing_test(path, root):
|
||||
missing_tests.append(_base_name(path))
|
||||
|
||||
return missing_tests
|
||||
|
||||
|
||||
def main() -> int:
|
||||
try:
|
||||
payload = json.load(sys.stdin)
|
||||
except json.JSONDecodeError:
|
||||
return 0
|
||||
|
||||
cwd = Path(payload.get("cwd") or Path.cwd())
|
||||
root = _repo_root(cwd)
|
||||
missing_tests = _guarded_paths(_touched_paths(payload), cwd, root)
|
||||
if not missing_tests:
|
||||
return 0
|
||||
|
||||
names = ", ".join(sorted(set(missing_tests)))
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": (
|
||||
"TDD GUARD: missing test file for "
|
||||
f"{names}. Write or add the test first."
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user