# DiffLoss and helpers (from unified_video_action diffusion_loss) - self-contained
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from simple_uva.diffusion import create_diffusion


def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    All arguments must be broadcast-compatible (tensors or Python numbers).
    """
    return shift + x * (scale + 1)


class TimestepEmbedder(nn.Module):
    """Embed scalar diffusion timesteps into ``hidden_size``-dim vectors.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is passed
    through a Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return the sinusoidal encoding of ``t`` (expects a 1-D tensor of length N).

        The result has shape (N, dim): ``dim // 2`` cosine columns followed by
        ``dim // 2`` sine columns, over frequencies ``max_period ** (-i / half)``
        for ``i in range(half)``. An odd ``dim`` gets one trailing zero column.
        """
        half = dim // 2
        exponents = (
            -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        )
        freqs = torch.exp(exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        parts = [torch.cos(phases), torch.sin(phases)]
        if dim % 2:
            # Pad to the requested odd width with a single zero column.
            parts.append(torch.zeros_like(phases[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        """Map timesteps ``t`` to (N, hidden_size) embeddings."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class ResBlock(nn.Module):
    """Residual MLP block whose LayerNorm output is modulated by a condition ``y``.

    ``y`` is mapped to per-channel shift, scale, and gate; the block returns
    ``x + gate * mlp(modulate(LN(x), shift, scale))``.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        hidden_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*hidden_layers)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """Apply the gated, condition-modulated residual update to ``x``."""
        modulation = self.adaLN_modulation(y)
        shift, scale, gate = modulation.chunk(3, dim=-1)
        normed = self.in_ln(x)
        update = self.mlp(modulate(normed, shift, scale))
        return x + gate * update


class FinalLayer(nn.Module):
    """Output head: affine-free LayerNorm, condition-dependent shift/scale, then a Linear."""

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6
        )
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        """Project ``x`` to ``out_channels`` after modulating its norm by ``c``."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)


class SimpleMLPAdaLN(nn.Module):
    """MLP denoiser conditioned on a timestep ``t`` and a latent ``c`` via adaLN.

    The input is projected to ``model_channels``, refined by
    ``num_res_blocks`` ``ResBlock``s conditioned on ``time_embed(t) +
    cond_embed(c)``, and mapped to ``out_channels`` by ``FinalLayer``.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then zero the adaLN and output projections.

        Zeroing the modulation layers makes every gate and the final output
        start at zero, so the network initially predicts zeros.
        """

        def _xavier_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_xavier_linear)
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        zeroed = [blk.adaLN_modulation[-1] for blk in self.res_blocks]
        zeroed.append(self.final_layer.adaLN_modulation[-1])
        zeroed.append(self.final_layer.linear)
        for layer in zeroed:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, t, c):
        """Predict ``out_channels`` values per row of ``x`` given timestep ``t`` and condition ``c``."""
        h = self.input_proj(x)
        cond = self.time_embed(t) + self.cond_embed(c)
        use_ckpt = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            h = checkpoint(block, h, cond) if use_ckpt else block(h, cond)
        return self.final_layer(h, cond)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """Classifier-free-guided forward.

        The first half of ``x`` is duplicated and run against all of ``c``;
        the first half of the output is treated as conditional and the second
        as unconditional. The guided epsilon is
        ``uncond + cfg_scale * (cond - uncond)``, and it is written to both
        halves. The remaining channels (past ``in_channels``) pass through
        unchanged.
        """
        n_half = len(x) // 2
        first = x[:n_half]
        out = self.forward(torch.cat([first, first], dim=0), t, c)
        eps = out[:, : self.in_channels]
        rest = out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        both = torch.cat([guided, guided], dim=0)
        return torch.cat([both, rest], dim=1)


class DiffLoss(nn.Module):
    """Per-token diffusion loss and sampler conditioned on a latent ``z``.

    A ``SimpleMLPAdaLN`` denoiser predicts ``2 * target_channels`` values per
    token. Training uses a cosine schedule with no respacing; sampling uses
    the same schedule respaced to ``num_sampling_steps``.
    """

    def __init__(
        self,
        target_channels,
        z_channels,
        depth,
        width,
        num_sampling_steps,
        grad_checkpointing=False,
        **kwargs
    ):
        super().__init__()
        # Kept as an attribute only; not read anywhere in this class. Other
        # keyword arguments are accepted and ignored.
        self.n_frames = kwargs.get("n_frames")
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2,
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing,
        )
        self.train_diffusion = create_diffusion(
            timestep_respacing="", noise_schedule="cosine"
        )
        self.gen_diffusion = create_diffusion(
            timestep_respacing=num_sampling_steps, noise_schedule="cosine"
        )

    def forward(self, target, z, mask=None, conf_score=None, text_latents=None):
        """Return the diffusion training loss, averaged over (masked) tokens.

        Args:
            target: rank-3 tensor (bsz, seq_len, ...) of clean tokens; it is
                flattened to (bsz * seq_len, -1).
            z: conditioning, reshaped to (bsz * seq_len, -1).
            mask: optional per-token weights with bsz * seq_len elements. When
                given, the loss is sum(loss * mask) / sum(mask). When None,
                it is the plain mean.
            conf_score, text_latents: accepted for interface compatibility;
                unused.

        Returns:
            Scalar loss tensor.
        """
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1)
        z = z.reshape(bsz * seq_len, -1)
        t = torch.randint(
            0,
            self.train_diffusion.num_timesteps,
            (target.shape[0],),
            device=target.device,
        )
        model_kwargs = dict(c=z)
        loss_dict = self.train_diffusion.training_losses(
            self.net, target, t, model_kwargs
        )
        loss = loss_dict["loss"]
        if mask is not None:
            # Reshape only when a mask is actually supplied, so the
            # mask=None default works.
            mask = mask.reshape(bsz * seq_len)
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean()

    def sample(self, z, temperature=1.0, cfg=1.0, text_latents=None):
        """Sample one token of ``in_channels`` values per row of ``z``.

        When ``cfg != 1.0``, the first half of ``z`` is the conditional batch
        and the second half the unconditional one (see
        ``SimpleMLPAdaLN.forward_with_cfg``), and both halves share the same
        starting noise. ``text_latents`` is accepted for interface
        compatibility and unused.
        """
        if not cfg == 1.0:
            # Draw from the default (CPU) generator, then move to z's device.
            # This keeps seeded results unchanged while not assuming CUDA.
            noise = torch.randn(z.shape[0] // 2, self.in_channels).to(z.device)
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward
        sampled_token_latent = self.gen_diffusion.p_sample_loop(
            sample_fn,
            noise.shape,
            noise,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=False,
            temperature=temperature,
        )
        return sampled_token_latent


# --- MAR video-only: first-frame conditioning, no text/proprio/action/wrist ---

from functools import partial
import numpy as np
from tqdm import tqdm
import scipy.stats as stats
from einops import rearrange
from timm.models.vision_transformer import Block


def mask_by_order(mask_len, order, bsz, seq_len, device):
    """Build a (bsz, seq_len) boolean mask that is True at the first ``mask_len`` positions of each row's ``order``.

    ``mask_len`` is a single-element tensor; ``order`` is a (bsz, seq_len)
    long tensor of position indices.
    """
    chosen = order[:, : mask_len.long()]
    base = torch.zeros(bsz, seq_len).to(device)
    ones = torch.ones(bsz, seq_len).to(device)
    return base.scatter(-1, chosen, ones).bool()


class MARVideoOnly(nn.Module):
    """MAR for video-from-first-frame only: no text, proprio, action, or wrist.

    Each video is a grid of tokens (T frames x S = seq_h * seq_w positions x
    token_embed_dim channels). For every position, the encoder input is built
    from three parts, concatenated and projected back to
    ``encoder_embed_dim``:

    * the (possibly masked) frame latent;
    * the conditioning latent;
    * a learned placeholder standing in for the absent action stream.

    A block of ``buffer_size_text`` learned placeholder tokens stands in for
    text. Masking is spatial: the same positions are masked in every frame.
    Masked tokens are generated by the per-token ``DiffLoss`` head.

    ``label_drop_prob`` is accepted but not used by this class. ``kwargs``
    may carry ``predict_para``, ``para_n_bins``, and ``para_out_size``.
    """

    def __init__(
        self,
        img_size=256,
        vae_stride=16,
        patch_size=1,
        encoder_embed_dim=1024,
        encoder_depth=16,
        encoder_num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=16,
        decoder_num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        vae_embed_dim=16,
        mask_ratio_min=0.7,
        label_drop_prob=0.1,
        attn_dropout=0.1,
        proj_dropout=0.1,
        diffloss_d=3,
        diffloss_w=1024,
        num_sampling_steps="100",
        grad_checkpointing=False,
        **kwargs
    ):
        super().__init__()
        # Fixed clip length; all temporal positional embeddings are sized by it.
        self.n_frames = 4
        self.buffer_size_text = 64
        # Kept for compatibility; the action placeholder is now sized from T*S.
        self.buffer_size_action = 64

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.vae_embed_dim = vae_embed_dim
        self.grad_checkpointing = grad_checkpointing
        # Truncated normal on [mask_ratio_min, 1.0] centred at 1.0.
        self.mask_ratio_generator = stats.truncnorm(
            (mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25
        )

        self.z_proj_cond = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.fake_latent_x = nn.Parameter(torch.zeros(1, encoder_embed_dim))
        self.fake_action_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
        # Projects [frame latent | cond latent | action placeholder] back down.
        self.proj_cond_x_layer = nn.Linear(
            3 * encoder_embed_dim, encoder_embed_dim, bias=True
        )

        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
        self.text_pos_embed = nn.Parameter(
            torch.zeros(1, self.buffer_size_text, encoder_embed_dim)
        )

        self.temporal_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_frames, encoder_embed_dim)
        )
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.seq_len, encoder_embed_dim)
        )
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)

        self.encoder_blocks = nn.ModuleList(
            [
                Block(
                    encoder_embed_dim,
                    encoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    proj_drop=proj_dropout,
                    attn_drop=attn_dropout,
                )
                for _ in range(encoder_depth)
            ]
        )
        self.encoder_norm = norm_layer(encoder_embed_dim)

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_temporal_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_frames, decoder_embed_dim)
        )
        self.decoder_spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.seq_len, decoder_embed_dim)
        )
        self.decoder_text_pos_embed = nn.Parameter(
            torch.zeros(1, self.buffer_size_text, decoder_embed_dim)
        )
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    proj_drop=proj_dropout,
                    attn_drop=attn_dropout,
                )
                for _ in range(decoder_depth)
            ]
        )
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_temporal_embed = nn.Parameter(
            torch.zeros(1, self.n_frames, decoder_embed_dim)
        )
        self.diffusion_spatial_embed = nn.Parameter(
            torch.zeros(1, self.seq_len, decoder_embed_dim)
        )

        # Runs before DiffLoss / ParaHead are created, so those keep their own init.
        self.initialize_weights()

        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=grad_checkpointing,
            n_frames=self.n_frames,
            language_emb_model="clip",
            language_emb_model_type=1,
        )

        self.predict_para = kwargs.get("predict_para", False)
        self.para_head = None
        if self.predict_para:
            from simple_uva.para_head import ParaHead
            self.para_head = ParaHead(
                decoder_embed_dim=decoder_embed_dim,
                n_bins=kwargs.get("para_n_bins", 32),
                in_grid_size=self.seq_h,
                out_size=kwargs.get("para_out_size", 64),
            )

    def initialize_weights(self):
        """Init learned tokens and positional embeddings (N(0, 0.02)), then Linear/LayerNorm modules."""
        torch.nn.init.normal_(self.fake_latent_x, std=0.02)
        torch.nn.init.normal_(self.fake_action_latent, std=0.02)
        torch.nn.init.normal_(self.fake_latent, std=0.02)
        torch.nn.init.normal_(self.temporal_pos_embed, std=0.02)
        torch.nn.init.normal_(self.spatial_pos_embed, std=0.02)
        torch.nn.init.normal_(self.decoder_temporal_pos_embed, std=0.02)
        torch.nn.init.normal_(self.decoder_spatial_pos_embed, std=0.02)
        torch.nn.init.normal_(self.diffusion_temporal_embed, std=0.02)
        torch.nn.init.normal_(self.diffusion_spatial_embed, std=0.02)
        torch.nn.init.normal_(self.text_pos_embed, std=0.02)
        torch.nn.init.normal_(self.decoder_text_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        """(bsz, c, h, w) -> (bsz, (h/p)*(w/p), c*p*p) with p = patch_size."""
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p
        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum("nchpwq->nhwcpq", x)
        x = x.reshape(bsz, h_ * w_, c * p**2)
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify`` -> (bsz, vae_embed_dim, seq_h*p, seq_w*p)."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w
        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum("nhwcpq->nchpwq", x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x

    def sample_orders(self, bsz):
        """Return a (bsz, seq_len) long tensor of independent random permutations of the spatial positions.

        Uses ``self.device`` when a forward/sampling method has set it;
        otherwise falls back to the device of the model parameters.
        """
        device = getattr(self, "device", None)
        if device is None:
            device = self.z_proj.weight.device
        # np.random.permutation(n) shuffles np.arange(n) with the global RNG,
        # matching the previous arange + shuffle implementation.
        orders = np.array([np.random.permutation(self.seq_len) for _ in range(bsz)])
        return torch.as_tensor(orders, dtype=torch.long, device=device)

    def random_masking(self, x, orders):
        """Mask a random fraction of spatial positions, shared across all frames.

        One mask ratio is drawn per call from ``mask_ratio_generator``.
        Returns a float (bsz, t, seq_len) mask with 1 for masked tokens.
        """
        bsz, t, seq_len, embed_dim = x.shape
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        spatial_mask = torch.zeros(bsz, seq_len, device=x.device)
        spatial_mask = torch.scatter(
            spatial_mask,
            dim=-1,
            index=orders[:, :num_masked_tokens],
            src=torch.ones(bsz, seq_len, device=x.device),
        )
        mask = spatial_mask.unsqueeze(1).expand(-1, t, -1)
        return mask

    def forward_mae_encoder(self, x, mask, cond):
        """Encode tokens.

        Args:
            x: (B, T, S, C) frame tokens. Positions where ``mask == 1`` are
                replaced by a learned placeholder after projection.
            mask: (B, T, S).
            cond: (B, T, S, C) conditioning tokens.

        Returns:
            (B, buffer_size_text + T*S, encoder_embed_dim).
        """
        B, T, S, _ = x.size()
        mask = rearrange(mask, "b t s -> b (t s)")
        cond = self.z_proj_cond(cond)
        cond = rearrange(cond, "b t s c -> b (t s) c")
        x = self.z_proj(x)
        x = rearrange(x, "b t s c -> b (t s) c")
        fake_latent_expanded = self.fake_latent_x.unsqueeze(1).expand(B, x.size(1), -1)
        x[mask == 1] = fake_latent_expanded[mask == 1].to(x.dtype)

        # No action stream in this model: every token position receives the
        # same learned placeholder. Sized from T*S so any clip/grid shape
        # works (a fixed 16 * buffer_size_action only matched T*S == 1024).
        action_latents_expand = self.fake_action_latent.unsqueeze(0).expand(
            B, T * S, -1
        )
        x = torch.cat([x, cond, action_latents_expand], dim=-1)
        x = self.proj_cond_x_layer(x)
        embed_dim = x.size(2)

        temporal_pos_embed_expanded = self.temporal_pos_embed.unsqueeze(2).expand(
            -1, -1, S, -1
        )
        spatial_pos_embed_expanded = self.spatial_pos_embed.unsqueeze(1).expand(
            -1, T, -1, -1
        )
        combined_pos_embed = (
            temporal_pos_embed_expanded + spatial_pos_embed_expanded
        ).reshape(-1, T * S, embed_dim)
        x = x + combined_pos_embed

        # Learned placeholder tokens occupy the text slots.
        text_latents = (
            self.fake_latent.unsqueeze(1)
            .repeat(1, self.buffer_size_text, 1)
            .expand(B, -1, -1)
        )
        text_latents = text_latents + self.text_pos_embed
        x = torch.cat([text_latents, x], dim=1)
        x = self.z_proj_ln(x)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.encoder_blocks:
                x = block(x)
        x = self.encoder_norm(x)
        return x

    def forward_mae_decoder(self, x, mask):
        """Decode encoder output.

        Returns (B, T*S, decoder_embed_dim) with the text slots dropped and
        diffusion positional embeddings added. ``mask`` (B, T, S) is used
        only for its shape.
        """
        B, T, S = mask.size()
        mask = rearrange(mask, "b t s -> b (t s)")
        x = self.decoder_embed(x)
        _, _, embed_dim = x.shape
        decoder_temporal_pos_embed_expanded = self.decoder_temporal_pos_embed.unsqueeze(
            2
        ).expand(-1, -1, S, -1)
        decoder_spatial_pos_embed_expanded = self.decoder_spatial_pos_embed.unsqueeze(
            1
        ).expand(-1, T, -1, -1)
        decoder_combined_pos_embed = (
            decoder_temporal_pos_embed_expanded + decoder_spatial_pos_embed_expanded
        ).reshape(1, T * S, embed_dim)
        combined_pos_embed = torch.cat(
            [self.decoder_text_pos_embed, decoder_combined_pos_embed], dim=1
        )
        x = x + combined_pos_embed

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for block in self.decoder_blocks:
                x = block(x)
        x = self.decoder_norm(x)
        x = x[:, self.buffer_size_text :]

        diffusion_temporal_pos_embed_expanded = self.diffusion_temporal_embed.unsqueeze(
            2
        ).expand(-1, -1, S, -1)
        diffusion_spatial_pos_embed_expanded = self.diffusion_spatial_embed.unsqueeze(
            1
        ).expand(-1, T, -1, -1)
        diffusion_combined_pos_embed = (
            diffusion_temporal_pos_embed_expanded + diffusion_spatial_pos_embed_expanded
        ).reshape(1, T * S, embed_dim)
        x = x + diffusion_combined_pos_embed
        return x

    def sample_tokens(
        self,
        bsz,
        cond,
        num_iter=64,
        cfg=1.0,
        cfg_schedule="linear",
        temperature=1.0,
        progress=False,
        **kwargs
    ):
        """Generate all tokens by iterative unmasking, conditioned on ``cond``.

        Args:
            bsz: number of samples; must equal the batch size of ``cond``
                (the encoder concatenates tokens and cond per sample).
            cond: conditioning frames laid out as (B, T, C, H, W).
            num_iter: number of unmasking steps (cosine schedule).
            cfg: guidance scale. Only 1.0 is supported: this model builds no
                unconditional branch.
            cfg_schedule: accepted for API compatibility; no effect because
                guidance is disabled.
            temperature: forwarded to ``DiffLoss.sample``.
            progress: wrap the iterations in a tqdm progress bar.

        Returns:
            ``(frames, None)``; frames has shape
            (bsz * n_frames, vae_embed_dim, seq_h * patch_size, seq_w * patch_size).

        Raises:
            ValueError: if ``cfg != 1.0``.
        """
        if cfg != 1.0:
            # Previously this doubled mask_to_pred without doubling z, which
            # always failed with an out-of-range index.
            raise ValueError(
                "MARVideoOnly.sample_tokens does not support classifier-free "
                f"guidance (got cfg={cfg}); use cfg=1.0."
            )
        self.device = cond.device
        B, T, C, H, W = cond.size()
        cond = rearrange(cond, "b t c h w -> (b t) c h w")
        cond = self.patchify(cond)
        cond = rearrange(cond, "(b t) seq_len c -> b t seq_len c", b=B)

        tokens = torch.zeros(
            bsz, self.n_frames, self.seq_len, self.token_embed_dim, device=self.device
        )
        mask = torch.ones(bsz, self.n_frames, self.seq_len, device=self.device)
        orders = self.sample_orders(bsz)
        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        for step in indices:
            cur_tokens = tokens.clone()
            x = self.forward_mae_encoder(tokens, mask, cond)
            z = self.forward_mae_decoder(x, mask)

            # Cosine schedule: spatial positions that stay masked after this step.
            mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).to(
                self.device
            )
            # Masks are identical across frames, so frame 0 is representative.
            mask_ = mask[:, 0]
            # Keep at least 1 and at most (currently masked - 1) positions masked.
            mask_len = torch.maximum(
                torch.Tensor([1]).to(self.device),
                torch.minimum(
                    torch.sum(mask_, dim=-1, keepdims=True) - 1, mask_len
                ),
            )
            mask_next = mask_by_order(
                mask_len[0], orders, bsz, self.seq_len, self.device
            )
            mask_next = mask_next.unsqueeze(1).expand(-1, T, -1)
            mask_next = rearrange(mask_next, "b t s -> b (t s)")
            mask_flat = rearrange(mask, "b t s -> b (t s)")
            if step >= num_iter - 1:
                # Final step: predict everything still masked.
                mask_to_pred = mask_flat[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(
                    mask_flat[:bsz].bool(), mask_next.bool()
                )
            mask = rearrange(mask_next, "b (t s) -> b t s", t=self.n_frames)

            z_pred = z[mask_to_pred.nonzero(as_tuple=True)]
            sampled_token_latent = self.diffloss.sample(
                z_pred, temperature, cfg, text_latents=None
            )

            cur_tokens_flat = rearrange(cur_tokens, "b t s c -> b (t s) c")
            cur_tokens_flat[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = rearrange(
                cur_tokens_flat, "b (t s) c -> b t s c", t=self.n_frames
            ).clone()

        tokens = rearrange(tokens, "b t s c -> (b t) s c")
        tokens = self.unpatchify(tokens)
        return tokens, None

    def forward_decode_tokens(self, x, cond, mask=None):
        """Run encoder and decoder, return decoder tokens (B, T, S, C). If mask is None, no masking (all visible)."""
        self.device = x.device
        B, T, S, _ = x.size()
        if mask is None:
            mask = torch.zeros(B, T, S, device=x.device, dtype=x.dtype)
        h = self.forward_mae_encoder(x, mask, cond)
        z = self.forward_mae_decoder(h, mask)
        z = rearrange(z, "b (t s) c -> b t s c", t=T, s=S)
        return z

    def forward_para(self, x, cond, mask=None):
        """Return volume_logits (B, T, n_bins, H_out, W_out) from decoder tokens + PARA head.

        Raises:
            RuntimeError: if the model was built without ``predict_para=True``.
        """
        if self.para_head is None:
            raise RuntimeError(
                "forward_para requires MARVideoOnly to be built with predict_para=True"
            )
        z = self.forward_decode_tokens(x, cond, mask=mask)
        return self.para_head(z)

    def forward(self, x, cond):
        """Forward for DataParallel: same as compute_loss. Returns scalar loss."""
        return self.compute_loss(x, cond)

    def compute_loss(self, x, cond):
        """Training loss: x, cond (B, T, S, C) token space. Returns scalar loss."""
        self.device = x.device
        B, T, S, _ = x.size()
        orders = self.sample_orders(B)
        mask = self.random_masking(x, orders)
        h = self.forward_mae_encoder(x, mask, cond)
        z = self.forward_mae_decoder(h, mask)
        gt_flat = rearrange(x, "b t s c -> b (t s) c")
        mask_flat = rearrange(mask, "b t s -> b (t s)")
        # Diffusion loss is averaged over masked tokens only.
        loss = self.diffloss(target=gt_flat, z=z, mask=mask_flat)
        return loss


def mar_base_video_only(**kwargs):
    """Build a base-size ``MARVideoOnly`` (768-wide, 12-layer, 12-head encoder and decoder).

    Extra keyword arguments are forwarded to ``MARVideoOnly``. Passing any of
    the architecture keys fixed here raises ``TypeError`` (duplicate keyword).
    """
    width, depth, heads = 768, 12, 12
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MARVideoOnly(
        encoder_embed_dim=width,
        encoder_depth=depth,
        encoder_num_heads=heads,
        decoder_embed_dim=width,
        decoder_depth=depth,
        decoder_num_heads=heads,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs
    )
