#!/usr/bin/env python3
"""Visualize augmentation ranges that simulate viewpoint changes."""
import cv2
import numpy as np
from pathlib import Path

# Load a sample frame
frame_path = "/data/libero/ood_objpos_v3/libero_spatial/task_0/demo_0/frames/000005.png"
img = cv2.imread(frame_path)
if img is None:
    # cv2.imread returns None (it does not raise) when the file is missing or
    # cannot be decoded; fail here with a clear message instead of an opaque
    # AttributeError on `img.shape` below.
    raise FileNotFoundError(f"Could not read frame: {frame_path}")
H, W = img.shape[:2]  # expected 448x448 for these frames

n_cols = 8  # number of samples per augmentation

# ============================================================
# Define augmentation dimensions with min/max ranges
# ============================================================

def apply_crop(img, dx, dy, crop_frac=0.85):
    """Simulate a lateral camera shift with an off-center crop, then resize back.

    Args:
        img: image array whose first two axes are (height, width).
        dx, dy: desired offset, in pixels, of the crop center from the image
            center (converted with ``int``, i.e. truncated toward zero).
        crop_frac: fraction of each dimension kept by the crop; smaller
            values zoom in further.

    Returns:
        The cropped window resized to the original (width, height).

    The crop window is clamped so it always lies inside the image; offsets
    beyond about ``(1 - crop_frac) / 2`` of a dimension saturate at the border.
    """
    height, width = img.shape[:2]
    win_h = int(height * crop_frac)
    win_w = int(width * crop_frac)
    # Top-left corner of the window, first limited on the far side...
    left = min(width // 2 + int(dx) - win_w // 2, width - win_w)
    top = min(height // 2 + int(dy) - win_h // 2, height - win_h)
    # ...then on the near side.
    left = max(left, 0)
    top = max(top, 0)
    window = img[top:top + win_h, left:left + win_w]
    return cv2.resize(window, (width, height))

def apply_rotation(img, angle_deg):
    """Rotate ``img`` in-plane about its center by ``angle_deg`` degrees.

    Uses unit scale and keeps the original output size; regions uncovered by
    the rotation are filled by reflecting the image border.
    """
    rows, cols = img.shape[:2]
    pivot = (cols / 2, rows / 2)
    rot = cv2.getRotationMatrix2D(pivot, angle_deg, 1.0)
    out_size = (cols, rows)
    return cv2.warpAffine(img, rot, out_size,
                          borderMode=cv2.BORDER_REFLECT_101)

def apply_hshear(img, shear_x):
    """Horizontal shear — simulates horizontal oblique viewing.

    Content in row ``y`` is moved along x by ``shear_x * (y - H/2)`` pixels,
    so the middle row stays fixed; exposed areas are filled by border
    reflection. Output size equals input size.
    """
    rows, cols = img.shape[:2]
    offset = -shear_x * rows / 2
    affine = np.array([[1, shear_x, offset],
                       [0, 1, 0]], dtype=np.float32)
    return cv2.warpAffine(img, affine, (cols, rows),
                          borderMode=cv2.BORDER_REFLECT_101)

def apply_vshear(img, shear_y):
    """Vertical shear — simulates vertical oblique viewing (elevation change).

    Content in column ``x`` is moved along y by ``shear_y * (x - W/2)`` pixels
    (positive = down in image coordinates), so the central column stays
    fixed; exposed areas are filled by border reflection.
    """
    rows, cols = img.shape[:2]
    offset = -shear_y * cols / 2
    affine = np.array([[1, 0, 0],
                       [shear_y, 1, offset]], dtype=np.float32)
    return cv2.warpAffine(img, affine, (cols, rows),
                          borderMode=cv2.BORDER_REFLECT_101)

def perspective_corners(W, H, strength, direction='horizontal'):
    """Return the destination corners used by ``apply_perspective``.

    The source quad is the full ``W`` x ``H`` image with corners ordered
    top-left, top-right, bottom-right, bottom-left. The result is a float32
    (4, 2) array of (x, y) points giving where each of those corners lands.

    'vertical': the top and bottom edges move inward by ``strength * H``; for
        strength > 0 the top edge narrows by ``strength * W / 3`` on each side
        and the bottom edge widens by the same amount.
    'horizontal': the exact transpose of 'vertical' — the left and right edges
        move inward by ``strength * W``; for strength > 0 the left edge shrinks
        by ``strength * H / 3`` at each end and the right edge grows by the
        same amount. Top/bottom widths stay equal, so no vertical keystone is
        mixed into the azimuth warp.

    Raises:
        ValueError: if ``direction`` is neither 'horizontal' nor 'vertical'.
    """
    s = strength
    if direction == 'horizontal':
        return np.float32([[0 + s*W, 0 + s*H/3],
                           [W - s*W, 0 - s*H/3],
                           [W - s*W, H + s*H/3],
                           [0 + s*W, H - s*H/3]])
    if direction == 'vertical':
        return np.float32([[0 + s*W/3, 0 + s*H],
                           [W - s*W/3, 0 + s*H],
                           [W + s*W/3, H - s*H],
                           [0 - s*W/3, H - s*H]])
    raise ValueError(
        f"direction must be 'horizontal' or 'vertical', got {direction!r}")


def apply_perspective(img, strength, direction='horizontal'):
    """Perspective warp — more realistic viewpoint simulation than shear.

    strength: how strong the perspective effect is (0 = none, 0.15 = strong).
        Sign convention: for 'horizontal', positive simulates looking from the
        left and negative from the right; for 'vertical', positive simulates
        looking from above and negative from below.
    direction: 'horizontal' (azimuth) or 'vertical' (elevation); any other
        value raises ValueError.

    Returns an image of the same size; areas outside the warped source are
    filled by border reflection.
    """
    H, W = img.shape[:2]
    src = np.float32([[0, 0], [W, 0], [W, H], [0, H]])
    dst = perspective_corners(W, H, strength, direction)
    M = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img, M, (W, H), borderMode=cv2.BORDER_REFLECT_101)


# ============================================================
# Augmentation specs: (name, function, param_values, param_label)
# ============================================================
# Crop fraction shared by both crop rows.
CROP_FRAC = 0.82
# apply_crop clamps its window to the image, so the crop centre can move at
# most (size - crop_size) // 2 px in either direction before it saturates
# (40 px for a 448 px side at 0.82). Cap the intended ±50 px sweep there so
# every column's label reports the shift actually applied, instead of the
# outer columns being silently clamped.
max_dx = min(50, (W - int(W * CROP_FRAC)) // 2)
max_dy = min(50, (H - int(H * CROP_FRAC)) // 2)

# Each entry: (row title, fn(img, value) -> augmented img, values swept, label).
augmentations = [
    ("Horizontal Crop\n(lateral shift)",
     lambda img, v: apply_crop(img, dx=v, dy=0, crop_frac=CROP_FRAC),
     np.linspace(-max_dx, max_dx, n_cols), "dx (px)"),

    ("Vertical Crop\n(elevation shift)",
     lambda img, v: apply_crop(img, dx=0, dy=v, crop_frac=CROP_FRAC),
     np.linspace(-max_dy, max_dy, n_cols), "dy (px)"),

    ("Rotation\n(camera roll)",
     lambda img, v: apply_rotation(img, v),
     np.linspace(-15, 15, n_cols), "angle (deg)"),

    ("Horizontal Shear\n(azimuth sim)",
     lambda img, v: apply_hshear(img, v),
     np.linspace(-0.20, 0.20, n_cols), "shear_x"),

    ("Vertical Shear\n(elevation sim)",
     lambda img, v: apply_vshear(img, v),
     np.linspace(-0.20, 0.20, n_cols), "shear_y"),

    ("Horizontal Perspective\n(azimuth viewpoint)",
     lambda img, v: apply_perspective(img, v, 'horizontal'),
     np.linspace(-0.15, 0.15, n_cols), "strength"),

    ("Vertical Perspective\n(elevation viewpoint)",
     lambda img, v: apply_perspective(img, v, 'vertical'),
     np.linspace(-0.15, 0.15, n_cols), "strength"),
]

# ============================================================
# Build the grid image
# ============================================================
# Grid geometry in pixels: thumbnail side, width of the left label column,
# height of the top title strip, and the gap between thumbnails. Every row
# also reserves 18 px under its thumbnails for the parameter-value labels.
thumb, row_label_w, col_label_h, pad = 160, 180, 30, 3

grid_w = row_label_w + (thumb + pad) * n_cols
grid_h = col_label_h + (thumb + pad + 18) * len(augmentations)

# Dark-grey (30, 30, 30) background.
canvas = np.full((grid_h, grid_w, 3), 30, dtype=np.uint8)

y = col_label_h
for aug_name, aug_fn, params, param_label in augmentations:
    # Row label: each '\n'-separated part of the name on its own text line.
    # (param_label is currently not drawn.)
    for li, line in enumerate(aug_name.split('\n')):
        cv2.putText(canvas, line, (8, y + thumb//2 - 10 + li*18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.42, (200, 200, 200), 1)

    # One thumbnail per parameter value, left to right.
    x = row_label_w
    n_vals = len(params)
    for i, val in enumerate(params):
        aug_img = aug_fn(img.copy(), val)
        aug_thumb = cv2.resize(aug_img, (thumb, thumb))

        # Border color (BGR) encodes the sample's position t in the sweep:
        # pure blue at the first value, fading toward a pale tint at the
        # midpoint, then to pure red at the last value. A single-value sweep
        # is placed at the midpoint (avoids dividing by zero).
        t = i / (n_vals - 1) if n_vals > 1 else 0.5
        if t < 0.5:
            color = (255, int(200*t*2), int(100*t*2))
        else:
            fade = 1 - (t - 0.5)*2
            color = (int(255*fade), int(200*fade), 255)

        # 2-px frame drawn directly into the thumbnail.
        aug_thumb[:2, :] = color
        aug_thumb[-2:, :] = color
        aug_thumb[:, :2] = color
        aug_thumb[:, -2:] = color

        canvas[y:y+thumb, x:x+thumb] = aug_thumb

        # Parameter value label below the thumbnail: 1 decimal for |v| >= 1,
        # 3 decimals for smaller values.
        lbl = f"{val:.1f}" if abs(val) >= 1 else f"{val:.3f}"
        cv2.putText(canvas, lbl, (x + 2, y + thumb + 13),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.33, (150, 150, 150), 1)

        x += thumb + pad

    y += thumb + pad + 18

# Title
cv2.putText(canvas, "Augmentation Range Exploration (viewpoint simulation)", (row_label_w, 20),
            cv2.FONT_HERSHEY_SIMPLEX, 0.55, (100, 200, 255), 1)

out_path = "results/augmentation_grid.png"
# cv2.imwrite does not create missing directories and reports failure only
# through its boolean return value, so ensure the directory exists and check
# the result rather than printing "Saved" unconditionally.
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
if not cv2.imwrite(out_path, canvas):
    raise OSError(f"Failed to write {out_path}")
print(f"Saved: {out_path} ({canvas.shape[1]}x{canvas.shape[0]})")
