"""
Extract and visualize UNet intermediate features from fine-tuned SVD at
early, middle, and late denoising timesteps. PCA → RGB visualization.
Also generates the predicted video for reference.
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "9"

import torch
import numpy as np
import imageio
from PIL import Image
from sklearn.decomposition import PCA
from svd.pipelines import StableVideoDiffusionPipeline
from svd.models import UNetSpatioTemporalConditionModel
from diffusers import AutoencoderKLTemporalDecoder
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Fine-tuned UNet checkpoint; all other components come from the frozen base.
CKPT = "output_libero_7f/checkpoint-46000"  # trained on original libero
BASE = "checkpoints/stable-video-diffusion-img2vid-xt-1-1"
IMG_PATH = "dataset/val_full/libero_0000.png"  # conditioning frame
WIDTH, HEIGHT = 576, 320  # generation resolution (W, H)
NUM_FRAMES = 7            # frames per generated clip
NUM_STEPS = 25            # denoising steps
OUT_DIR = "feature_vis"   # all outputs are written here

os.makedirs(OUT_DIR, exist_ok=True)

# ─── Load model components ────────────────────────────────────────────────────
# The UNet comes from the fine-tuned checkpoint; VAE, CLIP image encoder, and
# the CLIP preprocessor come from the frozen base model.
print("Loading model...")
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    CKPT, subfolder="unet", torch_dtype=torch.float16)

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    BASE, subfolder="vae", torch_dtype=torch.float16)

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    BASE, subfolder="image_encoder", torch_dtype=torch.float16)

feature_extractor = CLIPImageProcessor.from_pretrained(
    BASE, subfolder="feature_extractor")

# Also load pipeline for video generation; it shares the fine-tuned unet
# instance passed in here.
# NOTE(review): the pipeline loads its own vae/image_encoder from BASE, so the
# standalone copies above duplicate GPU memory — confirm this is acceptable.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

device = torch.device("cuda")
# pipe.to("cuda") already moved the shared unet; these moves remain necessary
# for the standalone vae / image_encoder handles used below.
unet = unet.to(device)
vae = vae.to(device)
image_encoder = image_encoder.to(device)

print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# ─── Generate predicted video ─────────────────────────────────────────────────
print(f"\nGenerating predicted video from {IMG_PATH}...")
# Force 3-channel RGB: a palette/grayscale/RGBA PNG would otherwise propagate
# a non-3 channel count into the (1, 3, H, W) tensor prep and the comparison
# grid concatenation later in the script.
input_image = Image.open(IMG_PATH).convert("RGB").resize((WIDTH, HEIGHT))
input_image.save(f"{OUT_DIR}/input_frame.png")

# inference_mode: no autograd bookkeeping during the full sampling run
with torch.inference_mode():
    frames = pipe(input_image, height=HEIGHT, width=WIDTH,
                  num_frames=NUM_FRAMES, decode_chunk_size=4,
                  num_inference_steps=NUM_STEPS).frames[0]

# Write the predicted frames out as an mp4 for visual reference
frames_np = [np.array(f) for f in frames]
imageio.mimwrite(f"{OUT_DIR}/predicted_video.mp4", frames_np, fps=4, quality=8)
print(f"  Saved predicted video ({len(frames_np)} frames)")

# ─── Feature extraction via hooks ─────────────────────────────────────────────
print("\nExtracting UNet features at early/mid/late denoising steps...")

# Prepare input: image → float tensor, shape (1, 3, H, W), range [-1, 1]
# NOTE(review): assumes input_image is 3-channel RGB — an RGBA/palette PNG
# would give a different channel count here; confirm or .convert("RGB") at load.
input_tensor = torch.from_numpy(np.array(input_image)).permute(2, 0, 1).float() / 255.0
input_tensor = input_tensor.unsqueeze(0).to(device, dtype=torch.float16)  # (1, 3, H, W)
input_tensor = input_tensor * 2.0 - 1.0  # normalize to [-1, 1]

# Encode conditioning image with CLIP → (1, 1, D) cross-attention context
clip_input = feature_extractor(images=input_image, return_tensors="pt").pixel_values
clip_input = clip_input.to(device, dtype=torch.float16)
image_embeddings = image_encoder(clip_input).image_embeds.unsqueeze(1)  # (1, 1, D)

# Encode image with VAE
with torch.no_grad():
    # VAE encoder processes frames individually (4D input)
    # Encode the single frame, then repeat for temporal dim
    single_latent = vae.encode(input_tensor).latent_dist.sample()  # (1, C, h, w)
    single_latent = single_latent * vae.config.scaling_factor
    # Repeat across temporal dimension — every frame starts from the same latent
    latent = single_latent.unsqueeze(2).repeat(1, 1, NUM_FRAMES, 1, 1)  # (1, C, T, h, w)

print(f"  Latent shape: {latent.shape}")

# Set up hooks to capture features from different UNet blocks.
# Hooks write their module's output into this dict, keyed by block name.
captured_features = {}

def make_hook(name):
    """Build a forward hook that stores the hooked module's output under `name`.

    Tuple outputs (blocks that return (hidden_states, ...)) are unwrapped to
    their first element; tensors are detached before being stored.
    """
    def hook_fn(module, inputs, output):
        tensor = output[0] if isinstance(output, tuple) else output
        captured_features[name] = tensor.detach()
    return hook_fn

# Register hooks on the mid_block and the first three up_blocks
# (up_block resolution increases with index, so these span coarse → fine).
hooks = []
hooks.append(unet.mid_block.register_forward_hook(make_hook("mid_block")))
if len(unet.up_blocks) >= 3:
    hooks.append(unet.up_blocks[0].register_forward_hook(make_hook("up_block_0")))
    hooks.append(unet.up_blocks[1].register_forward_hook(make_hook("up_block_1")))
    hooks.append(unet.up_blocks[2].register_forward_hook(make_hook("up_block_2")))

# Denoising timesteps to extract features at
# SVD uses a specific sigma schedule; we'll use the pipeline's scheduler
from diffusers import EulerDiscreteScheduler
scheduler = EulerDiscreteScheduler.from_pretrained(BASE, subfolder="scheduler")
scheduler.set_timesteps(NUM_STEPS, device=device)
timesteps = scheduler.timesteps  # length NUM_STEPS; index 0 = highest noise

# Step indices into `timesteps` at which features are captured
early_step = 2          # high noise (early in denoising = late timestep value)
mid_step = NUM_STEPS // 2  # medium noise
late_step = NUM_STEPS - 2   # low noise (late in denoising = early timestep value)

extract_steps = {
    "early": early_step,
    "mid": mid_step,
    "late": late_step,
}

print(f"  Extracting at steps: {extract_steps}")
print(f"  Corresponding timesteps: early={timesteps[early_step]:.1f}, mid={timesteps[mid_step]:.1f}, late={timesteps[late_step]:.1f}")

# Prepare added time IDs (fps, motion_bucket_id, noise_aug_strength)
# — SVD's extra micro-conditioning vector, one row per batch element
fps = 7
motion_bucket_id = 127
noise_aug_strength = 0.02
added_time_ids = torch.tensor([[fps, motion_bucket_id, noise_aug_strength]],
                               device=device, dtype=torch.float16)

# Run denoising and capture features at selected steps
# NOTE(review): this does NOT roll out an actual denoising trajectory — each
# step independently re-noises the CLEAN latent at that step's sigma
# (latent + noise * sigma), so captured features reflect noise level only,
# not an evolving sample. Confirm this probe is intended.
all_step_features = {}
noise = torch.randn_like(latent)  # one fixed noise sample reused at every sigma

with torch.no_grad():
    # Encode conditioning image latent (for concat with noisy latent)
    cond_latent = single_latent.unsqueeze(2).repeat(1, 1, NUM_FRAMES, 1, 1)
    cond_latent = cond_latent / vae.config.scaling_factor  # undo scaling for conditioning

    for step_idx, t in enumerate(timesteps):
        # Create noisy latent at this timestep
        # NOTE(review): scheduler.sigmas is typically float32, which promotes
        # the fp16 latent here — verify the fp16 UNet accepts the resulting
        # input dtype (or cast unet_input back to float16).
        sigma = scheduler.sigmas[step_idx]
        noisy_latent = latent + noise * sigma

        # Scale input (Euler/Karras preconditioning: divide by sqrt(sigma^2+1))
        scaled_input = noisy_latent / ((sigma**2 + 1) ** 0.5)

        # Concat noisy latent with conditioning latent along channel dim:
        # permute (1, C, T, h, w) → (1, T, C, h, w), then concat on dim=2
        noisy_5d = scaled_input.permute(0, 2, 1, 3, 4)  # (1, T, C, H, W)
        cond_5d = cond_latent.permute(0, 2, 1, 3, 4)    # (1, T, C, H, W)
        unet_input = torch.cat([noisy_5d, cond_5d], dim=2)  # (1, T, 8, H, W)

        # encoder_hidden_states: (B, 1, D) - single image embedding
        encoder_hidden_states = image_embeddings

        # UNet forward — the registered hooks populate captured_features
        # as a side effect; the prediction itself is discarded.
        captured_features.clear()
        _ = unet(
            unet_input,
            t.unsqueeze(0),
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=added_time_ids,
        ).sample

        # Save (clone) features at the selected early/mid/late steps
        for name, si in extract_steps.items():
            if step_idx == si:
                all_step_features[name] = {k: v.clone() for k, v in captured_features.items()}
                print(f"  Captured features at step {step_idx} ({name})")
                for k, v in captured_features.items():
                    print(f"    {k}: {v.shape}")

# Remove hooks
for h in hooks:
    h.remove()

# ─── PCA visualization ────────────────────────────────────────────────────────
print("\nGenerating PCA visualizations...")

def pca_visualize(features, name, spatial_size=None):
    """Project a feature map onto its top-3 PCA components and render as RGB.

    features: (B, C, T, H, W) or (B, C, H, W) tensor; batch item 0 is used and
    a temporal dim, if present, is mean-pooled away. Each principal component
    is min-max scaled to [0, 255] independently. The result is upsampled with
    nearest-neighbor to `spatial_size` (W, H) or the global (WIDTH, HEIGHT).
    `name` is unused (kept for call-site compatibility).
    """
    fmap = features[0].float()          # drop batch dim
    if fmap.ndim == 4:                  # (C, T, H, W) → average over time
        fmap = fmap.mean(dim=1)

    n_ch, h, w = fmap.shape

    # Flatten spatial grid: one sample per pixel, one column per channel
    flat = fmap.permute(1, 2, 0).reshape(-1, n_ch).cpu().numpy()

    components = PCA(n_components=3).fit_transform(flat)  # (h*w, 3)

    # Min-max scale each component independently into [0, 255]
    for ch in range(3):
        lo, hi = components[:, ch].min(), components[:, ch].max()
        if hi - lo > 0:
            components[:, ch] = (components[:, ch] - lo) / (hi - lo) * 255
        else:
            components[:, ch] = 0       # constant component → black channel

    rgb = components.reshape(h, w, 3).astype(np.uint8)

    # Upsample to a viewable size (NEAREST preserves the feature-cell blocks)
    target = spatial_size if spatial_size else (WIDTH, HEIGHT)
    return np.array(Image.fromarray(rgb).resize(target, Image.NEAREST))


# Visualize each timestep × each UNet block
for step_name, features_dict in all_step_features.items():
    for block_name, feat_tensor in features_dict.items():
        vis = pca_visualize(feat_tensor, f"{step_name}_{block_name}")
        save_path = f"{OUT_DIR}/pca_{step_name}_{block_name}.png"
        Image.fromarray(vis).save(save_path)
        print(f"  Saved {save_path} (from {feat_tensor.shape})")

# Create a comparison grid: rows = blocks, cols = early/mid/late
print("\nCreating comparison grid...")
# Union of block names across all captured steps, sorted for stable row order
block_names = sorted(set().union(*[d.keys() for d in all_step_features.values()]))
step_names = ["early", "mid", "late"]

# Add input image and labels
# NOTE(review): assumes the input image is 3-channel RGB — an RGBA source
# would make this (H, W, 4) and break the concatenation with the (H, W, 3)
# visualizations below; consider .convert("RGB") at load.
input_resized = np.array(input_image.resize((WIDTH, HEIGHT)))

grid_rows = []
for block_name in block_names:
    row_imgs = []
    for step_name in step_names:
        if step_name in all_step_features and block_name in all_step_features[step_name]:
            vis = pca_visualize(all_step_features[step_name][block_name], "")
            row_imgs.append(vis)
        else:
            # Black placeholder when a block has no capture at this step
            row_imgs.append(np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8))
    grid_rows.append(np.concatenate(row_imgs, axis=1))

# Add input image row (input repeated once under each column), then stack rows
input_row = np.concatenate([input_resized] * 3, axis=1)
grid = np.concatenate([input_row] + grid_rows, axis=0)

# Add text labels
from PIL import ImageDraw, ImageFont
grid_img = Image.fromarray(grid)
draw = ImageDraw.Draw(grid_img)
try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
except OSError:
    # Font file missing on this host — fall back to PIL's builtin bitmap font.
    # (Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt.)
    font = ImageFont.load_default()

# Column headers: one per denoising stage, annotated with its timestep value
for i, name in enumerate(step_names):
    t_val = timesteps[extract_steps[name]].item()
    draw.text((i * WIDTH + 5, 5), f"{name} (t={t_val:.0f})", fill=(255, 255, 0), font=font)

# Row headers: the input-image row first, then one per hooked UNet block
row_labels = ["Input"] + block_names
for i, name in enumerate(row_labels):
    draw.text((5, i * HEIGHT + 5), name, fill=(255, 255, 0), font=font)

grid_img.save(f"{OUT_DIR}/feature_grid.png")
print(f"  Saved {OUT_DIR}/feature_grid.png")

print(f"\nAll outputs in {OUT_DIR}/")
print("  - input_frame.png: conditioning image")
print("  - predicted_video.mp4: SVD generated video")
print("  - pca_*.png: individual PCA visualizations")
print("  - feature_grid.png: comparison grid (rows=blocks, cols=timesteps)")
