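"""Downsample base64-embedded raster images inside an SVG to camera-ready sizes.

For each <image> element whose href is a base64 data URI, a per-axis pixel
target is computed:
  target_px = image_size_mm / 25.4 * DPI * SAFETY
where image_size_mm is the element's width (or height) attribute, read as
millimetres (i.e. the document's user unit is assumed to be 1 mm).
If the native pixel size exceeds the target on either axis, the image is
resampled (LANCZOS) to fit within the target, preserving its aspect ratio, and
re-encoded. Images without alpha → JPEG (quality 85 by default, see --quality).
Anything with alpha → PNG. The re-encoded data replaces the original only if it
is smaller.

Usage:
    python3 downsample_svg_rasters.py <input.svg> <output.svg> [--dpi 200] [--safety 1.5] [--quality 85]
"""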
import argparse
import base64
import io
import math
from pathlib import Path

from lxml import etree
from PIL import Image

# Prefix -> namespace map for lxml path queries such as ".//svg:image".
NS = {"svg": "http://www.w3.org/2000/svg", "xlink": "http://www.w3.org/1999/xlink"}
# Attribute names that may hold an <image>'s URL, in lookup priority order.
# SVG 2: when an element carries both a plain `href` and the legacy
# `xlink:href`, the plain `href` is the one renderers use, so check it first.
HREF_KEYS = ["href", "{http://www.w3.org/1999/xlink}href"]


def get_href(el):
    """Look up the URL attribute of *el*.

    Candidates are tried in ``HREF_KEYS`` order. The first one that is set
    wins, even when its value is an empty string.

    Returns:
        ``(attribute_name, value)`` for the winning attribute, or
        ``(None, None)`` when none of the candidates is present.
    """
    candidates = ((name, el.get(name)) for name in HREF_KEYS)
    return next(
        ((name, value) for name, value in candidates if value is not None),
        (None, None),
    )


def decode_b64(b64: str) -> bytes:
    """Decode a base64 payload, tolerating embedded whitespace and missing padding.

    Data-URI payloads in SVG files are often line-wrapped (newlines or spaces
    inside the base64 text) and are sometimes unpadded. Whitespace is removed
    *before* the padding is computed. Counting it would produce the wrong
    number of "=" characters, and decoding would fail.

    Raises:
        binascii.Error: if the payload is not valid base64 even after cleanup
            (e.g. a data-character count of 1 mod 4).
    """
    compact = "".join(b64.split())
    return base64.b64decode(compact + "=" * (-len(compact) % 4))


def _target_size(el, dpi: float, safety: float):
    """Return the ``(width_px, height_px)`` budget for an <image>, or None.

    The element's ``width``/``height`` attributes are read as millimetres
    (unitless user units, as assumed throughout this script). None is
    returned when either attribute is missing, non-numeric (e.g. "50%" or
    "10mm"), non-finite, or not positive. Without a real display size there
    is no sensible target, and treating a missing size as 0 would yield a
    1x1 px budget that destroys the image.
    """
    try:
        wmm = float(el.get("width", ""))
        hmm = float(el.get("height", ""))
    except ValueError:
        return None
    if not all(math.isfinite(v) and v > 0 for v in (wmm, hmm)):
        return None
    return (max(1, int(wmm / 25.4 * dpi * safety)),
            max(1, int(hmm / 25.4 * dpi * safety)))


def _reencode(img, target_w: int, target_h: int, jpeg_quality: int):
    """Resize *img* to fit within target_w x target_h and re-encode it.

    The aspect ratio is preserved. Images with any kind of transparency are
    written as PNG; everything else is written as JPEG at *jpeg_quality*.

    Returns:
        ``(encoded_bytes, mime_type)``.

    Raises:
        OSError / ValueError: propagated from Pillow when the image's mode
            cannot be converted or encoded.
    """
    has_alpha = (img.mode in ("RGBA", "LA", "PA", "RGBa", "La")
                 or "transparency" in img.info)
    # Move to a continuous-tone mode *before* resizing. Pillow silently
    # falls back to NEAREST for "P" and "1" images. A tRNS colour key on
    # "L"/"RGB" would not survive interpolation, so it becomes real alpha.
    if has_alpha:
        if img.mode not in ("RGBA", "LA"):
            img = img.convert("RGBA")
    elif img.mode not in ("RGB", "L"):
        img = img.convert("RGB")

    nw, nh = img.size
    scale = min(target_w / nw, target_h / nh)
    new_size = (max(1, int(nw * scale)), max(1, int(nh * scale)))
    small = img.resize(new_size, Image.LANCZOS)

    buf = io.BytesIO()
    if has_alpha:
        small.save(buf, format="PNG", optimize=True)
        return buf.getvalue(), "image/png"
    small.save(buf, format="JPEG", quality=jpeg_quality,
               optimize=True, progressive=True)
    return buf.getvalue(), "image/jpeg"


def downsample_svg(src: Path, dst: Path, dpi: float = 200, safety: float = 1.5,
                   jpeg_quality: int = 85) -> dict:
    """Shrink oversized base64-embedded rasters in the SVG *src*; write *dst*.

    Every ``<image>`` whose href is a base64 ``data:image/...`` URI is decoded.
    If its pixel size exceeds the budget from ``_target_size`` on either axis,
    it is resampled to fit that budget (aspect ratio preserved) and
    re-encoded by ``_reencode``. The new payload replaces the old one only if
    it is strictly smaller. Images that cannot be decoded, sized, or
    re-encoded are left untouched; one bad image never aborts the run.

    Args:
        src: input SVG path.
        dst: output SVG path (written with an XML declaration, UTF-8).
        dpi: target print resolution in dots per inch.
        safety: oversampling factor applied on top of *dpi*.
        jpeg_quality: JPEG quality for re-encoded images without alpha.

    Returns:
        Counters:
            ``images``: ``data:image/`` hrefs found.
            ``downsampled``: payloads replaced.
            ``before``: total decoded raster bytes before processing.
            ``after``: total decoded raster bytes after processing.

    Parse and write errors from lxml propagate to the caller.
    """
    tree = etree.parse(str(src), etree.XMLParser(huge_tree=True))
    stats = {"images": 0, "downsampled": 0, "before": 0, "after": 0}

    for el in tree.getroot().findall(".//svg:image", NS):
        href_key, href = get_href(el)
        if not href or not href.startswith("data:image/"):
            continue
        stats["images"] += 1

        _, is_base64, payload = href.partition("base64,")
        if not is_base64:
            continue  # percent-encoded (non-base64) data URI: not handled
        try:
            raw = decode_b64(payload)
        except ValueError:  # binascii.Error subclasses ValueError
            continue  # malformed payload: leave it alone, don't abort the run
        # Count the payload as kept; corrected below if it gets replaced.
        stats["before"] += len(raw)
        stats["after"] += len(raw)

        target = _target_size(el, dpi, safety)
        if target is None:
            continue
        target_w, target_h = target

        try:
            img = Image.open(io.BytesIO(raw))
            img.load()
        except Exception:
            # Best-effort: anything Pillow cannot decode (SVG payloads,
            # corrupt data, decompression-bomb guard, ...) stays as-is.
            continue
        with img:
            nw, nh = img.size
            if nw <= target_w and nh <= target_h:
                continue
            try:
                new_raw, mime = _reencode(img, target_w, target_h, jpeg_quality)
            except (OSError, ValueError):
                continue  # mode Pillow cannot convert/encode: keep original

        # Keep whichever encoding is smaller.
        if len(new_raw) >= len(raw):
            continue
        new_b64 = base64.b64encode(new_raw).decode("ascii")
        el.set(href_key, f"data:{mime};base64,{new_b64}")
        stats["downsampled"] += 1
        stats["after"] += len(new_raw) - len(raw)

    tree.write(str(dst), xml_declaration=True, encoding="utf-8")
    return stats


if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description="Downsample base64-embedded raster images inside an SVG.")
    p.add_argument("input", help="source SVG file")
    p.add_argument("output", help="destination SVG file")
    p.add_argument("--dpi", type=float, default=200,
                   help="target print resolution in dots per inch (default: 200)")
    p.add_argument("--safety", type=float, default=1.5,
                   help="oversampling factor applied on top of --dpi (default: 1.5)")
    p.add_argument("--quality", type=int, default=85,
                   help="JPEG quality for re-encoded images without alpha (default: 85)")
    args = p.parse_args()
    # A non-positive (or NaN/inf) dpi or safety makes every pixel target
    # collapse to 1 px (or fail inside int()), so reject it up front.
    if not (math.isfinite(args.dpi) and args.dpi > 0):
        p.error("--dpi must be a positive number")
    if not (math.isfinite(args.safety) and args.safety > 0):
        p.error("--safety must be a positive number")
    s = downsample_svg(Path(args.input), Path(args.output),
                       dpi=args.dpi, safety=args.safety, jpeg_quality=args.quality)
    print(f"images: {s['images']}, downsampled: {s['downsampled']}")
    print(f"raster bytes: {s['before']/1024/1024:.1f}MB → {s['after']/1024/1024:.1f}MB"
          f" ({100*s['after']/max(1,s['before']):.0f}%)")
