#!/usr/bin/env python3 """Remove a solid chroma-key background from an image. This helper supports the imagegen skill's built-in-first transparent workflow: generate an image on a flat key color, then convert that key color to alpha. """ from __future__ import annotations import argparse from io import BytesIO from pathlib import Path import re from statistics import median import sys from typing import Tuple Color = Tuple[int, int, int] KEY_DOMINANCE_THRESHOLD = 16.0 ALPHA_NOISE_FLOOR = 8 def _die(message: str, code: int = 1) -> None: print(f"Error: {message}", file=sys.stderr) raise SystemExit(code) def _dependency_hint(package: str) -> str: return ( "Activate the repo-selected environment first, then install it with " f"`uv pip install {package}`. If this repo uses a local virtualenv, start with " "`source .venv/bin/activate`; otherwise use this repo's configured shared fallback " "environment." ) def _load_pillow(): try: from PIL import Image, ImageFilter except ImportError: _die(f"Pillow is required for chroma-key removal. {_dependency_hint('pillow')}") return Image, ImageFilter def _parse_key_color(raw: str) -> Color: value = raw.strip() match = re.fullmatch(r"#?([0-9a-fA-F]{6})", value) if not match: _die("key color must be a hex RGB value like #00ff00.") hex_value = match.group(1) return ( int(hex_value[0:2], 16), int(hex_value[2:4], 16), int(hex_value[4:6], 16), ) def _validate_args(args: argparse.Namespace) -> None: if args.tolerance < 0 or args.tolerance > 255: _die("--tolerance must be between 0 and 255.") if args.transparent_threshold < 0 or args.transparent_threshold > 255: _die("--transparent-threshold must be between 0 and 255.") if args.opaque_threshold < 0 or args.opaque_threshold > 255: _die("--opaque-threshold must be between 0 and 255.") if args.soft_matte and args.transparent_threshold >= args.opaque_threshold: _die("--transparent-threshold must be lower than --opaque-threshold.") if args.edge_feather < 0 or args.edge_feather > 64: _die("--edge-feather must be between 0 and 64.") if args.edge_contract < 0 or args.edge_contract > 16: _die("--edge-contract must be between 0 and 16.") src = Path(args.input) if not src.exists(): _die(f"Input image not found: {src}") out = Path(args.out) if out.exists() and not args.force: _die(f"Output already exists: {out} (use --force to overwrite)") if out.suffix.lower() not in {".png", ".webp"}: _die("--out must end in .png or .webp so the alpha channel is preserved.") def _channel_distance(a: Color, b: Color) -> int: return max(abs(a[0] - b[0]), abs(a[1] - b[1]), abs(a[2] - b[2])) def _clamp_channel(value: float) -> int: return max(0, min(255, int(round(value)))) def _smoothstep(value: float) -> float: value = max(0.0, min(1.0, value)) return value * value * (3.0 - 2.0 * value) def _soft_alpha(distance: int, transparent_threshold: float, opaque_threshold: float) -> int: if distance <= transparent_threshold: return 0 if distance >= opaque_threshold: return 255 ratio = (float(distance) - transparent_threshold) / ( opaque_threshold - transparent_threshold ) return _clamp_channel(255.0 * _smoothstep(ratio)) def _dominance_alpha(rgb: Color, key: Color) -> int: spill_channels = _spill_channels(key) if not spill_channels: return 255 channels = [float(value) for value in rgb] non_spill = [idx for idx in range(3) if idx not in spill_channels] key_strength = ( min(channels[idx] for idx in spill_channels) if len(spill_channels) > 1 else channels[spill_channels[0]] ) non_key_strength = max((channels[idx] for idx in non_spill), default=0.0) dominance = key_strength - non_key_strength if dominance <= 0: return 255 denominator = max(1.0, float(max(key)) - non_key_strength) alpha = 1.0 - min(1.0, dominance / denominator) return _clamp_channel(alpha * 255.0) def _spill_channels(key: Color) -> list[int]: key_max = max(key) if key_max < 128: return [] return [idx for idx, value in enumerate(key) if value >= key_max - 16 and value >= 128] def _key_channel_dominance(rgb: Color, key: Color) -> float: spill_channels = _spill_channels(key) if not spill_channels: return 0.0 channels = [float(value) for value in rgb] non_spill = [idx for idx in range(3) if idx not in spill_channels] key_strength = ( min(channels[idx] for idx in spill_channels) if len(spill_channels) > 1 else channels[spill_channels[0]] ) non_key_strength = max((channels[idx] for idx in non_spill), default=0.0) return key_strength - non_key_strength def _looks_key_colored(rgb: Color, key: Color, distance: int) -> bool: if distance <= 32: return True spill_channels = _spill_channels(key) if not spill_channels: return True return _key_channel_dominance(rgb, key) >= KEY_DOMINANCE_THRESHOLD def _cleanup_spill(rgb: Color, key: Color, alpha: int = 255) -> Color: if alpha >= 252: return rgb spill_channels = _spill_channels(key) if not spill_channels: return rgb channels = [float(value) for value in rgb] non_spill = [idx for idx in range(3) if idx not in spill_channels] if non_spill: anchor = max(channels[idx] for idx in non_spill) cap = max(0.0, anchor - 1.0) for idx in spill_channels: if channels[idx] > cap: channels[idx] = cap return ( _clamp_channel(channels[0]), _clamp_channel(channels[1]), _clamp_channel(channels[2]), ) def _apply_alpha_to_image( image, *, key: Color, tolerance: int, spill_cleanup: bool, soft_matte: bool, transparent_threshold: float, opaque_threshold: float, ) -> int: pixels = image.load() width, height = image.size transparent = 0 for y in range(height): for x in range(width): red, green, blue, alpha = pixels[x, y] rgb = (red, green, blue) distance = _channel_distance(rgb, key) key_like = _looks_key_colored(rgb, key, distance) output_alpha = ( min( _soft_alpha(distance, transparent_threshold, opaque_threshold), _dominance_alpha(rgb, key), ) if soft_matte and key_like else (0 if distance <= tolerance else 255) ) output_alpha = int(round(output_alpha * (alpha / 255.0))) if 0 < output_alpha <= ALPHA_NOISE_FLOOR: output_alpha = 0 if output_alpha == 0: pixels[x, y] = (0, 0, 0, 0) transparent += 1 continue if spill_cleanup and key_like: red, green, blue = _cleanup_spill(rgb, key, output_alpha) pixels[x, y] = (red, green, blue, output_alpha) return transparent def _contract_alpha(image, pixels: int): if pixels == 0: return image _, ImageFilter = _load_pillow() alpha = image.getchannel("A") for _ in range(pixels): alpha = alpha.filter(ImageFilter.MinFilter(3)) image.putalpha(alpha) return image def _apply_edge_feather(image, radius: float): if radius == 0: return image _, ImageFilter = _load_pillow() alpha = image.getchannel("A") alpha = alpha.filter(ImageFilter.GaussianBlur(radius=radius)) image.putalpha(alpha) return image def _encode_image(image, output_format: str) -> bytes: out = BytesIO() image.save(out, format=output_format.upper()) return out.getvalue() def _alpha_counts(image) -> tuple[int, int, int]: pixels = image.load() width, height = image.size total = 0 transparent = 0 partial = 0 for y in range(height): for x in range(width): alpha = pixels[x, y][3] total += 1 if alpha == 0: transparent += 1 elif alpha < 255: partial += 1 return total, transparent, partial def _sample_border_key(image, mode: str) -> Color: width, height = image.size pixels = image.load() samples: list[Color] = [] if mode == "corners": patch = max(1, min(width, height, 12)) boxes = [ (0, 0, patch, patch), (width - patch, 0, width, patch), (0, height - patch, patch, height), (width - patch, height - patch, width, height), ] for left, top, right, bottom in boxes: for y in range(top, bottom): for x in range(left, right): red, green, blue = pixels[x, y][:3] samples.append((red, green, blue)) else: band = max(1, min(width, height, 6)) step = max(1, min(width, height) // 256) for x in range(0, width, step): for y in range(band): red, green, blue = pixels[x, y][:3] samples.append((red, green, blue)) red, green, blue = pixels[x, height - 1 - y][:3] samples.append((red, green, blue)) for y in range(0, height, step): for x in range(band): red, green, blue = pixels[x, y][:3] samples.append((red, green, blue)) red, green, blue = pixels[width - 1 - x, y][:3] samples.append((red, green, blue)) if not samples: _die("Could not sample background key color from image border.") return ( int(round(median(sample[0] for sample in samples))), int(round(median(sample[1] for sample in samples))), int(round(median(sample[2] for sample in samples))), ) def _remove_chroma_key(args: argparse.Namespace) -> None: Image, _ = _load_pillow() src = Path(args.input) out = Path(args.out) with Image.open(src) as image: rgba = image.convert("RGBA") key = ( _sample_border_key(rgba, args.auto_key) if args.auto_key != "none" else _parse_key_color(args.key_color) ) transparent = _apply_alpha_to_image( rgba, key=key, tolerance=args.tolerance, spill_cleanup=args.spill_cleanup, soft_matte=args.soft_matte, transparent_threshold=args.transparent_threshold, opaque_threshold=args.opaque_threshold, ) rgba = _contract_alpha(rgba, args.edge_contract) rgba = _apply_edge_feather(rgba, args.edge_feather) total, transparent_after, partial_after = _alpha_counts(rgba) out.parent.mkdir(parents=True, exist_ok=True) output_format = "PNG" if out.suffix.lower() == ".png" else "WEBP" out.write_bytes(_encode_image(rgba, output_format)) print(f"Wrote {out}") print(f"Key color: #{key[0]:02x}{key[1]:02x}{key[2]:02x}") print(f"Transparent pixels: {transparent_after}/{total}") print(f"Partially transparent pixels: {partial_after}/{total}") if transparent == 0: print("Warning: no pixels matched the key color before feathering.", file=sys.stderr) def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Remove a solid chroma-key background and write an image with alpha." ) parser.add_argument("--input", required=True, help="Input image path.") parser.add_argument("--out", required=True, help="Output .png or .webp path.") parser.add_argument( "--key-color", default="#00ff00", help="Hex RGB key color to remove, for example #00ff00.", ) parser.add_argument( "--tolerance", type=int, default=12, help="Hard-key per-channel tolerance for matching the key color, 0-255.", ) parser.add_argument( "--auto-key", choices=["none", "corners", "border"], default="none", help="Sample the key color from image corners or border instead of --key-color.", ) parser.add_argument( "--soft-matte", action="store_true", help="Use a smooth alpha ramp between transparent and opaque thresholds.", ) parser.add_argument( "--transparent-threshold", type=float, default=12.0, help="Soft-matte distance at or below which pixels become fully transparent.", ) parser.add_argument( "--opaque-threshold", type=float, default=96.0, help="Soft-matte distance at or above which pixels become fully opaque.", ) parser.add_argument( "--edge-feather", type=float, default=0.0, help="Optional alpha blur radius for softened edges, 0-64.", ) parser.add_argument( "--edge-contract", type=int, default=0, help="Shrink the visible alpha matte by this many pixels before feathering.", ) parser.add_argument( "--spill-cleanup", dest="spill_cleanup", action="store_true", help="Reduce obvious key-color spill on opaque pixels.", ) parser.add_argument( "--despill", dest="spill_cleanup", action="store_true", help="Alias for --spill-cleanup; decontaminate key-color edge spill.", ) parser.add_argument("--force", action="store_true", help="Overwrite an existing output file.") return parser def main() -> None: parser = _build_parser() args = parser.parse_args() _validate_args(args) _remove_chroma_key(args) if __name__ == "__main__": main()