169 lines
4.7 KiB
Python
169 lines
4.7 KiB
Python
"""Project guardrails for shell commands and apply_patch edits."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
# Lower-case substrings that flag a shell command or an added patch line as
# referring to a remote/API conversion path or an excluded engine.
# Both checkers lower-case their input first and test these with ``in``,
# so every entry must stay lower-case.
REMOTE_ENGINE_PATTERNS = [
    # Flags and mode names.
    "--api-url",
    "router mode",
    "http client mode",
    "http client backend",
    "http-client",
    # Generic remote/API wording.
    "remote api",
    "remote endpoint",
    "openai-compatible",
    "openai compatible",
    # Named engines/services.
    "mathpix",
    "mistral ocr",
    "nanonets",
]
|
|
|
|
# Regexes, applied with ``re.search`` to the normalized shell command
# (lower-cased, backslashes turned into "/"), that detect a direct launch of
# the MinerU server or router executables.
#
# A "/" is accepted in front of the executable name and an optional ".exe"
# behind it, so path-qualified launches such as ".venv/bin/mineru-api" or
# ".venv/scripts/mineru-router.exe" are caught as well as the bare command.
# The name must still end at whitespace or end-of-string, so file names like
# "docs/mineru-api.md" do not match.
DIRECT_SERVER_COMMAND_PATTERNS = [
    r"(^|[\s/])mineru-api(\.exe)?(\s|$)",
    r"(^|[\s/])mineru-router(\.exe)?(\s|$)",
]
|
|
|
|
# Lower-case phrases that mark an added patch line as stating a prohibition
# rather than introducing the behaviour. ``check_patch`` skips any added line
# that contains one of these substrings before looking for blocked terms.
ALLOWED_NEGATION_PATTERNS = [
    # Explicit negations.
    "do not",
    "never",
    "not use",
    "no cloud",
    # Scope / policy wording.
    "exclude",
    "excluded",
    "non-goal",
    "blocked",
    "prohibit",
    "prohibited",
    "forbid",
    "forbidden",
    "reject",
    "rejecting",
]
|
|
|
|
|
|
def read_payload() -> dict:
    """Read the hook payload from standard input.

    Returns:
        The decoded JSON object. An empty dict is returned when stdin is
        empty or whitespace-only, when it is not valid JSON, or when the
        JSON decodes to anything other than an object (list, string,
        number, ...), so callers can always use ``.get`` on the result.
    """
    raw = sys.stdin.read()
    if not raw.strip():
        return {}
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    # Valid JSON is not necessarily an object; anything else is treated as
    # "no payload" instead of being handed to dict-only callers.
    return payload if isinstance(payload, dict) else {}
|
|
|
|
|
|
def deny(reason: str) -> int:
    """Emit a PreToolUse "deny" decision on stdout.

    Args:
        reason: Text placed in ``permissionDecisionReason``.

    Returns:
        Always ``0``; the decision itself is carried by the printed JSON.
    """
    decision = {
        "hookEventName": "PreToolUse",
        "permissionDecision": "deny",
        "permissionDecisionReason": reason,
    }
    envelope = {"hookSpecificOutput": decision}
    # ensure_ascii keeps the emitted line pure ASCII whatever the reason holds.
    serialized = json.dumps(envelope, ensure_ascii=True)
    print(serialized)
    return 0
|
|
|
|
|
|
def find_repo_root(cwd: str | None) -> Path:
    """Return the git top-level directory that contains *cwd*.

    Args:
        cwd: Directory to start from; the process working directory is used
            when this is ``None`` or empty.

    Returns:
        The resolved path printed by ``git rev-parse --show-toplevel``, or
        the resolved starting directory when that command cannot be run or
        fails for any reason.
    """
    base = Path(cwd) if cwd else Path.cwd()
    start = base.resolve()
    rev_parse = ["git", "rev-parse", "--show-toplevel"]
    try:
        completed = subprocess.run(
            rev_parse,
            cwd=start,
            capture_output=True,
            text=True,
            check=True,
        )
        toplevel = completed.stdout.strip()
        return Path(toplevel).resolve()
    except Exception:
        # Best effort: any failure falls back to the starting directory.
        return start
|
|
|
|
|
|
def samples_are_untracked(root: Path) -> bool:
    """Report whether git lists untracked entries under ``samples``.

    Runs ``git status --porcelain -- samples`` in *root* and looks for lines
    starting with the ``"?? "`` prefix.

    Returns:
        ``True`` when at least one such line is present; ``False`` otherwise,
        including when the git command cannot be run or exits with an error.
    """
    status_cmd = ["git", "status", "--porcelain", "--", "samples"]
    try:
        completed = subprocess.run(
            status_cmd,
            cwd=root,
            capture_output=True,
            text=True,
            check=True,
        )
    except Exception:
        return False

    for entry in completed.stdout.splitlines():
        if entry.startswith("?? "):
            return True
    return False
|
|
|
|
|
|
# Boundary before a command-line flag: start of string, whitespace or a quote.
_FLAG_LEAD = r"(?:^|[\s'\"])"
# Boundary after a flag: end of string, whitespace, a quote or a shell separator.
_TOKEN_END = r"(?:$|[\s'\";&|])"
# "samples" as a whole path component. Quotes and shell separators count as
# boundaries so that "samples", 'samples/' and `samples;` are recognised.
_SAMPLES_TOKEN = r"(?:^|[\s/'\"])samples(?:$|[\s/'\";&|])"


def _seen_later(*fragments: str) -> str:
    """Build one lookahead per fragment so they may appear in any order."""
    return "".join(f"(?=.*{fragment})" for fragment in fragments)


_STAGE_SAMPLES_RE = re.compile(r"\bgit\s+add\b.*" + _SAMPLES_TOKEN)
_GIT_ADD_RE = re.compile(r"\bgit\s+add\b")
_STAGE_ALL_RE = re.compile(r"(\s\.($|\s)|\s-a($|\s)|\s--all($|\s))")
_DESTRUCTIVE_RES = tuple(
    re.compile(pattern)
    for pattern in (
        # git clean with a force flag (-f, -fd, -xdf, --force) and samples.
        r"\bgit\s+clean\b"
        + _seen_later(_FLAG_LEAD + r"(?:-[a-z]*f[a-z]*|--force)" + _TOKEN_END, _SAMPLES_TOKEN),
        # rm with any single-dash flag cluster containing "r" (-r, -rf, -fr,
        # -rv, PowerShell -recurse) or --recursive, and samples.
        r"\brm(?=\s)"
        + _seen_later(_FLAG_LEAD + r"(?:-[a-z]*r[a-z]*|--recursive)" + _TOKEN_END, _SAMPLES_TOKEN),
        # PowerShell Remove-Item with -Recurse before or after the path.
        r"\bremove-item\b" + _seen_later(r"-recurse\b", _SAMPLES_TOKEN),
        # Discards all local changes regardless of path.
        r"\bgit\s+reset\s+--hard\b",
    )
)


def check_shell_command(command: str, root: Path) -> str | None:
    """Check a shell command against the project guardrails.

    The command is lower-cased and its backslashes are turned into "/" before
    matching, so Windows-style paths and upper-case flags are handled.

    Checks, in order:
      1. ``git add`` naming ``samples`` explicitly.
      2. ``git add`` with ``.``, ``-A``/``-a`` or ``--all`` while git reports
         untracked entries under ``samples``.
      3. Destructive commands: ``git clean`` with a force flag, ``rm`` with a
         recursive flag, or ``Remove-Item -Recurse`` that mention ``samples``
         (flag and path in either order), and any ``git reset --hard``.
      4. Direct ``mineru-api`` / ``mineru-router`` launches.
      5. Any ``REMOTE_ENGINE_PATTERNS`` substring.

    Args:
        command: Raw shell command text.
        root: Directory in which ``git status`` is run for check 2.

    Returns:
        The denial reason for the first rule that matches, or ``None`` when
        the command is allowed.
    """
    normalized = command.replace("\\", "/").lower()

    if _STAGE_SAMPLES_RE.search(normalized):
        return "Do not stage samples/ unless the user explicitly requests it."

    stages_everything = bool(
        _GIT_ADD_RE.search(normalized) and _STAGE_ALL_RE.search(normalized)
    )
    # Cheap regex result first: the git subprocess is only spawned for
    # commands that actually try to stage everything.
    if stages_everything and samples_are_untracked(root):
        return "Use path-specific git add commands; samples/ is untracked local fixture data."

    if any(pattern.search(normalized) for pattern in _DESTRUCTIVE_RES):
        return "Destructive workspace or samples/ command blocked by project policy."

    if any(re.search(pattern, normalized) for pattern in DIRECT_SERVER_COMMAND_PATTERNS):
        return "Direct MinerU server/router commands are blocked; use the mineru CLI. CLI-internal temporary local mineru-api is allowed."

    if any(pattern in normalized for pattern in REMOTE_ENGINE_PATTERNS):
        return "Remote/API conversion paths are blocked; v1 must run MinerU 3.1.0 through the local CLI only."

    return None
|
|
|
|
|
|
def _mentions_any(text, needles):
    """Return True when *text* contains at least one of *needles*."""
    for needle in needles:
        if needle in text:
            return True
    return False


def check_patch(command: str) -> str | None:
    """Scan the added lines of a patch for out-of-scope content.

    Only lines starting with a single ``+`` are inspected (``+++`` header
    lines are skipped). Each is stripped and lower-cased; a line containing
    any ``ALLOWED_NEGATION_PATTERNS`` phrase is ignored.

    Args:
        command: Patch text.

    Returns:
        A denial reason for the first offending added line, or ``None`` when
        no added line is objectionable.
    """
    for raw_line in command.splitlines():
        is_added = raw_line.startswith("+") and not raw_line.startswith("+++")
        if not is_added:
            continue

        text = raw_line[1:].strip().lower()

        # Lines that state a prohibition are allowed to name blocked terms.
        if _mentions_any(text, ALLOWED_NEGATION_PATTERNS):
            continue

        if _mentions_any(text, REMOTE_ENGINE_PATTERNS):
            return "Patch appears to add remote/API conversion behavior or excluded engine references."

        if "runtime engine" in text:
            if "selection" in text or "switch" in text:
                return "Runtime engine selection is out of scope for v1."

    return None
|
|
|
|
|
|
def main() -> int:
    """Run the guardrail for one hook invocation.

    Reads the payload from stdin, picks the checker from ``tool_name``
    (``Bash`` uses ``check_shell_command``; ``apply_patch``, ``Edit`` and
    ``Write`` use ``check_patch``) and prints a deny decision when a checker
    returns a reason.

    Malformed payloads (non-object payload, non-object ``tool_input``,
    non-string ``tool_name`` or ``cwd``) are treated as missing fields rather
    than raising.

    Returns:
        ``0`` in every case; a denial is signalled through the JSON printed
        by ``deny``.
    """
    payload = read_payload()
    if not isinstance(payload, dict):
        payload = {}

    tool_name = payload.get("tool_name", "")
    if not isinstance(tool_name, str):
        tool_name = ""

    tool_input = payload.get("tool_input")
    if not isinstance(tool_input, dict):
        tool_input = {}

    # NOTE(review): only the "command" and "patch" keys are read. If the host
    # sends Edit/Write content under other keys, those tools are effectively
    # unchecked here; confirm against the host's hook payload schema.
    command = str(tool_input.get("command") or tool_input.get("patch") or "")

    reason = None
    if tool_name == "Bash":
        cwd = payload.get("cwd")
        if not isinstance(cwd, str):
            cwd = None
        # The repo root is only needed for shell checks, so the git lookup is
        # skipped for every other tool.
        reason = check_shell_command(command, find_repo_root(cwd))
    elif tool_name in {"apply_patch", "Edit", "Write"}:
        reason = check_patch(command)

    if reason:
        return deny(reason)
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|