"""
Video-only sampling: start frame -> full video.
No text, no actions; single path used by sample_video_from_start_frame_simple.py.
Self-contained: no imports from unified_video_action.
"""

import sys
from pathlib import Path

# Put the directory one level above this file's directory (presumably the
# repository root -- confirm against the project layout) at the front of
# sys.path so that absolute imports resolve when this file is run or imported
# from elsewhere. The membership check keeps repeated imports from adding
# duplicate entries.
# NOTE(review): nothing visible in this module imports project code; verify
# whether this bootstrap is still required.
_REPO = Path(__file__).resolve().parents[1]
if str(_REPO) not in sys.path:
    sys.path.insert(0, str(_REPO))

import torch


@torch.no_grad()
def sample_video_from_start_frame(policy, cfg, x0: torch.Tensor) -> torch.Tensor:
    """
    Generate a video from a single start frame (no text or action inputs).

    The start frame is encoded with ``policy.vae_model``; its scaled latent is
    repeated along a new time axis to form the conditioning, and
    ``policy.model.sample_tokens`` produces latents that are unscaled and
    decoded back to pixels.

    Args:
        policy: object exposing ``vae_model.encode``, ``vae_model.decode``,
            ``model.n_frames`` and ``model.sample_tokens``.
        cfg: config object; sampling settings are read from
            ``cfg.model.policy.autoregressive_model_params``
            (``num_iter``, ``cfg``, ``cfg_schedule``, ``temperature``).
        x0: (B, 3, 256, 256) start frames, expected in [-1, 1].

    Returns:
        Tensor of shape (B, T, 3, 256, 256) with T = ``policy.model.n_frames``;
        nominally in [-1, 1] (the decoder output is not clamped here).

    Raises:
        ValueError: if ``x0`` is not 4-D or its trailing dims are not
            (3, 256, 256).
    """
    expected_chw = (3, 256, 256)
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a wrong-shaped frame reach the VAE.
    if x0.ndim != 4 or tuple(x0.shape[1:]) != expected_chw:
        raise ValueError(f"expected (B,3,256,256), got {x0.shape}")
    b, c, h, w = x0.shape

    # Factor applied to the VAE latents on the way in and undone before decode.
    latent_scale = 0.2325
    sampling = cfg.model.policy.autoregressive_model_params

    posterior = policy.vae_model.encode(x0.float())
    # Out-of-place multiply so the tensor handed back by sample() is never
    # mutated, whatever the posterior implementation keeps a reference to.
    z0 = posterior.sample() * latent_scale
    n_frames = int(policy.model.n_frames)
    # Condition every time step on the same start-frame latent:
    # insert a time axis at dim 1 and repeat it n_frames times.
    cond = z0.unsqueeze(1).repeat(1, n_frames, 1, 1, 1)

    tokens, _ = policy.model.sample_tokens(
        bsz=b,
        cond=cond,
        num_iter=sampling.num_iter,
        cfg=sampling.cfg,
        cfg_schedule=sampling.cfg_schedule,
        temperature=sampling.temperature,
    )

    # Tokens are expected as (B*T, C, H, W) latents; undo the scaling, decode
    # to pixels, then split the leading axis back into (B, T).
    pred = policy.vae_model.decode(tokens / latent_scale)
    return pred.view(b, n_frames, c, h, w)
