"""Tests for Spartan converter."""

import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import fsspec
import numpy as np
import pytest
import yaml

from vla_foundry.data.preprocessing.robotics.converters.spartan import SpartanConverter


# Helper functions for testing (non-Ray versions)
def check_episode_validity(episode_path: str) -> bool:
    """Report whether *episode_path* contains a usable ``processed`` sub-directory.

    An episode counts as valid when ``<episode_path>/processed`` exists and holds both
    ``metadata.yaml`` and ``observations.npz``. Any exception raised while probing the
    filesystem is treated as "not valid" rather than propagated.

    Args:
        episode_path: Local path or ``s3://`` URL of a single episode directory.

    Returns:
        True if the directory and both required files exist, otherwise False.
    """
    fs, _ = fsspec.core.url_to_fs(episode_path)
    processed_dir = os.path.join(episode_path, "processed")
    # Probe the directory first, then each required file; all() stops at the first miss.
    paths_to_probe = [processed_dir] + [
        os.path.join(processed_dir, filename) for filename in ("metadata.yaml", "observations.npz")
    ]

    try:
        # fsspec paths are given without the s3:// scheme prefix.
        return all(fs.exists(path.replace("s3://", "")) for path in paths_to_probe)
    except Exception:
        return False


def discover_episodes_in_directory(diffusion_spartan_path: str, max_episodes: int = -1) -> list:
    """Find ``episode_*`` directories under *diffusion_spartan_path* and keep only the valid ones.

    Candidates are recognised by name alone: the basename must start with ``episode_`` and
    must not end in one of a fixed set of data-file extensions. When *max_episodes* is
    positive, collection stops once that many candidates are gathered. The cap applies
    BEFORE validation, so fewer than *max_episodes* episodes may be returned.

    Args:
        diffusion_spartan_path: Local path or ``s3://`` URL of the directory to scan.
        max_episodes: Maximum number of candidates to consider; ``<= 0`` means no limit.

    Returns:
        Paths (joined onto *diffusion_spartan_path*) of candidates that pass
        ``check_episode_validity``, or ``[]`` if the directory cannot be listed.
    """
    fs, _ = fsspec.core.url_to_fs(diffusion_spartan_path)
    listing_path = diffusion_spartan_path.replace("s3://", "")

    try:
        entries = fs.listdir(listing_path)
    except Exception as e:
        print(f"Warning: Cannot list directory {diffusion_spartan_path}: {e}")
        return []

    data_file_suffixes = (".pkl", ".npz", ".txt", ".json", ".yaml", ".tar", ".gz")
    candidates = []
    for entry in entries:
        entry_name = entry["name"] if isinstance(entry, dict) else entry
        basename = os.path.basename(entry_name.rstrip("/"))

        # Skip anything that is not named like an episode directory.
        if not basename.startswith("episode_") or basename.endswith(data_file_suffixes):
            continue

        candidates.append(os.path.join(diffusion_spartan_path, basename))
        if 0 < max_episodes <= len(candidates):
            break

    return [candidate for candidate in candidates if check_episode_validity(candidate)]


@pytest.fixture
def mock_config():
    """Create a MagicMock config for SpartanConverter with concrete values for the fields tests use.

    Any attribute not assigned here resolves to an auto-created (truthy) MagicMock. Tests that
    depend on a specific value for another field (e.g. ``use_depth_data``) must set it themselves.
    """
    cfg = MagicMock()
    cfg.type = "spartan"
    cfg.language_annotations_path = None
    cfg.action_fields_config_path = None
    # Default to None so a test that forgets to set it does not get a truthy MagicMock path.
    cfg.validation_episodes_path = None
    cfg.data_discard_keys = []
    cfg.camera_names = []
    cfg.resize_images_size = [384, 384]
    cfg.image_indices = [-1, 0]
    cfg.jpeg_quality = 90
    cfg.padding_strategy = "copy"  # Use valid padding strategy
    cfg.past_lowdim_steps = 5
    cfg.future_lowdim_steps = 10
    cfg.max_padding_left = 5
    cfg.max_padding_right = 10
    cfg.filter_still_samples = False
    cfg.still_threshold = 0.01
    cfg.stride = 1
    cfg.num_workers = 1
    cfg.output_dir = "/tmp/test_output"
    cfg.min_depth = 0.001
    cfg.max_depth = 3.0
    cfg.fail_on_nan = False
    return cfg


@pytest.fixture
def temp_spartan_episodes():
    """Yield the path of a throwaway ``data/diffusion_spartan`` tree with four episodes.

    ``episode_0000`` .. ``episode_0002`` each contain ``processed/metadata.yaml`` and
    ``processed/observations.npz``. ``episode_0003`` contains only ``metadata.yaml`` and is
    therefore invalid. The whole tree is removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir, "data", "diffusion_spartan")
        root.mkdir(parents=True, exist_ok=True)

        def make_processed_dir(index: int) -> Path:
            # Create episode_<index>/processed and write its metadata file.
            processed = root / f"episode_{index:04d}" / "processed"
            processed.mkdir(parents=True, exist_ok=True)
            (processed / "metadata.yaml").write_text(yaml.dump({"episode_id": index}))
            return processed

        for index in range(3):
            np.savez(make_processed_dir(index) / "observations.npz", obs=np.zeros((10, 3)))

        # Invalid episode: intentionally left without observations.npz.
        make_processed_dir(3)

        yield str(root)


def test_check_episode_validity_valid(temp_spartan_episodes):
    """A fully populated episode directory is reported as valid."""
    assert check_episode_validity(os.path.join(temp_spartan_episodes, "episode_0000")) is True


def test_check_episode_validity_invalid(temp_spartan_episodes):
    """An episode whose processed/ dir lacks observations.npz is reported as invalid."""
    assert check_episode_validity(os.path.join(temp_spartan_episodes, "episode_0003")) is False


def test_check_episode_validity_nonexistent(temp_spartan_episodes):
    """A path that does not exist at all is reported as invalid."""
    assert check_episode_validity(os.path.join(temp_spartan_episodes, "episode_9999")) is False


def test_discover_episodes_in_directory(temp_spartan_episodes):
    """Discovery keeps the three complete episodes and drops the one missing observations."""
    discovered = discover_episodes_in_directory(temp_spartan_episodes)

    # episode_0003 lacks observations.npz, so only 0000-0002 should survive validation.
    assert len(discovered) == 3
    assert all("episode_" in path for path in discovered)

    discovered_names = {os.path.basename(path) for path in discovered}
    for expected_name in ("episode_0000", "episode_0001", "episode_0002"):
        assert expected_name in discovered_names
    assert "episode_0003" not in discovered_names


def test_discover_episodes_with_max_limit(temp_spartan_episodes):
    """Test that max_episodes caps how many candidate episodes are considered.

    The cap is applied to candidate directories before validation, and the directory listing
    is not sorted. One of the two candidates may therefore be the invalid ``episode_0003``,
    but at most one can be, since it is the only invalid episode. The result must hold one
    or two distinct episodes, all of them valid.
    """
    result = discover_episodes_in_directory(temp_spartan_episodes, max_episodes=2)

    assert 1 <= len(result) <= 2
    assert len(set(result)) == len(result)

    valid_paths = {os.path.join(temp_spartan_episodes, f"episode_{i:04d}") for i in range(3)}
    assert set(result) <= valid_paths


def test_discover_episodes_integration(temp_spartan_episodes, mock_config):
    """Run discover_episodes end to end with Ray and all file loading mocked out."""
    expected_episodes = [os.path.join(temp_spartan_episodes, f"episode_{i:04d}") for i in range(3)]

    with (
        patch("builtins.open"),
        patch("yaml.safe_load") as mock_yaml,
        patch(
            "vla_foundry.data.robotics.utils.load_action_field_config",
            return_value={"action_key_fields": [], "action_index_fields": [], "pose_groups": []},
        ),
        patch("ray.get") as mock_ray_get,
        patch(
            "vla_foundry.data.preprocessing.robotics.converters.spartan.discover_and_validate_episodes_in_directory"
        ),
    ):
        # Language annotations / action fields come from the patched loaders, not real files.
        mock_yaml.return_value = {"language_dict": {}}
        mock_config.language_annotations_path = "/tmp/fake_annotations.yaml"
        mock_config.action_fields_config_path = "/tmp/fake_action_fields.yaml"
        mock_config.validation_episodes_path = None

        # ray.get is stubbed to return a single batch holding the three valid episodes.
        mock_ray_get.return_value = [expected_episodes]

        episodes = SpartanConverter(mock_config).discover_episodes([temp_spartan_episodes])

        assert len(episodes) == 3
        assert all("episode_" in episode for episode in episodes)


def test_discover_episodes_with_max_episodes_to_process(temp_spartan_episodes, mock_config):
    """max_episodes_to_process trims the discovered episode list to the requested size."""
    all_episodes = [os.path.join(temp_spartan_episodes, f"episode_{i:04d}") for i in range(3)]

    with (
        patch("builtins.open"),
        patch("yaml.safe_load") as mock_yaml,
        patch(
            "vla_foundry.data.robotics.utils.load_action_field_config",
            return_value={"action_key_fields": [], "action_index_fields": [], "pose_groups": []},
        ),
        patch("ray.get") as mock_ray_get,
        patch(
            "vla_foundry.data.preprocessing.robotics.converters.spartan.discover_and_validate_episodes_in_directory"
        ),
    ):
        mock_yaml.return_value = {"language_dict": {}}
        mock_config.language_annotations_path = "/tmp/fake_annotations.yaml"
        mock_config.action_fields_config_path = "/tmp/fake_action_fields.yaml"
        mock_config.validation_episodes_path = None

        # Three valid episodes are available; the limit below should cut this to two.
        mock_ray_get.return_value = [all_episodes]

        converter = SpartanConverter(mock_config)
        limited = converter.discover_episodes([temp_spartan_episodes], max_episodes_to_process=2)

        assert len(limited) == 2

def test_discover_episodes_empty_directory(mock_config):
    """discover_episodes yields nothing when the searched directory contains no episodes."""
    with tempfile.TemporaryDirectory() as tmpdir:
        empty_dir = Path(tmpdir, "empty")
        empty_dir.mkdir(parents=True, exist_ok=True)

        with (
            patch("builtins.open"),
            patch("yaml.safe_load") as mock_yaml,
            patch(
                "vla_foundry.data.robotics.utils.load_action_field_config",
                return_value={"action_key_fields": [], "action_index_fields": [], "pose_groups": []},
            ),
            patch("ray.get") as mock_ray_get,
            patch(
                "vla_foundry.data.preprocessing.robotics.converters.spartan.discover_and_validate_episodes_in_directory"
            ),
        ):
            mock_yaml.return_value = {"language_dict": {}}
            mock_config.language_annotations_path = "/tmp/fake_annotations.yaml"
            mock_config.action_fields_config_path = "/tmp/fake_action_fields.yaml"
            mock_config.validation_episodes_path = None

            # ray.get hands back a single, empty batch of episodes.
            mock_ray_get.return_value = [[]]

            episodes = SpartanConverter(mock_config).discover_episodes([str(empty_dir)])

            assert len(episodes) == 0


class TestPointCloudGeneration:
    """Tests for point cloud generation in SpartanConverter.

    The extract_sample_data tests share a common harness. ``open`` and ``yaml.safe_load``
    are patched so SpartanConverter can be constructed without real annotation or
    action-field files, and a MagicMock stands in for the Ray logger actor.
    """

    @staticmethod
    def _set_fake_config_paths(mock_config):
        """Point the config at placeholder paths whose contents come from the patched loaders."""
        mock_config.language_annotations_path = "/tmp/fake_annotations.yaml"
        mock_config.action_fields_config_path = "/tmp/fake_action_fields.yaml"
        mock_config.validation_episodes_path = None

    @staticmethod
    def _language_annotations():
        """Return a fresh language-annotation dict with one instruction for ``test_task``."""
        return {"language_dict": {"test_task": {"original": ["test instruction"]}}}

    @staticmethod
    def _action_field_config():
        """Return a fresh action-field config with a single ``action`` key field."""
        return {"action_key_fields": ["action"], "action_index_fields": [7], "pose_groups": []}

    @staticmethod
    def _make_logger_actor():
        """Return a MagicMock stand-in for the Ray logger actor passed to extract_sample_data."""
        logger_actor = MagicMock()
        logger_actor.increment_total_potential_samples = MagicMock(return_value=MagicMock())
        logger_actor.increment_padding_samples_filtered = MagicMock(return_value=MagicMock())
        logger_actor.increment_still_samples_filtered = MagicMock(return_value=MagicMock())
        return logger_actor

    @staticmethod
    def _extract_sample(converter, episode_data, camera_data, logger_actor):
        """Call extract_sample_data at anchor timestep 5 of the 10-step mock episode."""
        return converter.extract_sample_data(
            anchor_timestep=5,
            episode_path="/fake/path/tasks/test_task/episode_0000",
            episode_length=10,
            camera_data=camera_data,
            lowdim_data={"action": episode_data["actions"]["actions"]},
            intrinsics_data=episode_data["intrinsics"],
            extrinsics_data=episode_data["extrinsics"],
            metadata_data=episode_data["metadata"],
            statistics_ray_actor=None,
            logger_actor=logger_actor,
        )

    @pytest.fixture
    def mock_episode_data(self):
        """Create a 10-step mock episode with RGB + depth images, calibration, metadata, and actions.

        Random content is drawn from a fixed-seed generator so any failure is reproducible.
        """
        rng = np.random.default_rng(0)
        num_steps = 10

        observations = {
            "camera1": rng.integers(0, 256, (num_steps, 480, 640, 3), dtype=np.uint8),  # RGB
            "camera1_depth": rng.integers(1000, 3000, (num_steps, 480, 640), dtype=np.uint16),  # Depth
            "robot_joint_positions": rng.standard_normal((num_steps, 7)),
        }

        # Intrinsic matrix (fx = fy = 500, principal point at the 640x480 image centre), repeated per step.
        camera_matrix = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
        intrinsics = {"camera1": np.tile(camera_matrix, (num_steps, 1, 1))}  # (T, 3, 3)

        # Identity extrinsics for every step.
        extrinsics = {"camera1": np.tile(np.eye(4), (num_steps, 1, 1))}  # (T, 4, 4)

        metadata = {
            "camera_id_to_semantic_name": {"camera1": "camera1"},
            "episode_length": num_steps,
        }

        actions = {"actions": rng.standard_normal((num_steps, 7))}

        return {
            "observations": observations,
            "intrinsics": intrinsics,
            "extrinsics": extrinsics,
            "metadata": metadata,
            "actions": actions,
        }

    def test_use_depth_data_true_generates_point_clouds(self, mock_config, mock_episode_data):
        """Test that extract_sample_data generates point clouds when use_depth_data=True."""
        mock_config.use_depth_data = True
        mock_config.point_cloud_num_points = 1000

        with patch("builtins.open"), patch("yaml.safe_load") as mock_yaml:
            mock_yaml.return_value = self._language_annotations()
            self._set_fake_config_paths(mock_config)

            with patch(
                "vla_foundry.data.robotics.utils.load_action_field_config",
                return_value=self._action_field_config(),
            ):
                converter = SpartanConverter(mock_config)
                result = self._extract_sample(
                    converter,
                    mock_episode_data,
                    mock_episode_data["observations"],
                    self._make_logger_actor(),
                )

                # Expected 7-tuple: images, lowdim, metadata, language, point clouds, point maps, stats.
                assert len(result) == 7
                _, _, _, _, sample_point_clouds, _, _ = result

                # Verify point clouds are generated
                assert sample_point_clouds is not None
                assert sample_point_clouds.shape == (len(mock_config.image_indices), 1000, 6)
                assert sample_point_clouds.dtype == np.float16

                # Verify point clouds are not all zeros (regression test)
                assert not np.all(sample_point_clouds == 0.0)

    def test_use_depth_data_false_skips_point_clouds(self, mock_config, mock_episode_data):
        """Test that extract_sample_data skips point clouds when use_depth_data=False."""
        mock_config.use_depth_data = False
        mock_config.point_cloud_num_points = 1000

        with patch("builtins.open"), patch("yaml.safe_load") as mock_yaml:
            mock_yaml.return_value = self._language_annotations()
            self._set_fake_config_paths(mock_config)

            with patch(
                "vla_foundry.data.robotics.utils.load_action_field_config",
                return_value=self._action_field_config(),
            ):
                converter = SpartanConverter(mock_config)
                result = self._extract_sample(
                    converter,
                    mock_episode_data,
                    mock_episode_data["observations"],
                    self._make_logger_actor(),
                )

                # Still a 7-tuple, but the point-cloud slot is None.
                assert len(result) == 7
                _, _, _, _, sample_point_clouds, _, _ = result

                # Verify point clouds are NOT generated
                assert sample_point_clouds is None

    def test_use_depth_data_false_skips_depth_images(self, mock_config, mock_episode_data):
        """Test that depth images are not in sample_images when use_depth_data=False."""
        mock_config.use_depth_data = False

        with patch("builtins.open"), patch("yaml.safe_load") as mock_yaml:
            mock_yaml.return_value = self._language_annotations()
            self._set_fake_config_paths(mock_config)

            with patch(
                "vla_foundry.data.robotics.utils.load_action_field_config",
                return_value=self._action_field_config(),
            ):
                converter = SpartanConverter(mock_config)

                # Use the converter's own camera extraction, which filters depth based on use_depth_data.
                camera_data = converter.extract_camera_data(mock_episode_data)

                result = self._extract_sample(
                    converter,
                    mock_episode_data,
                    camera_data,
                    self._make_logger_actor(),
                )
                sample_images, _, _, _, _, _, _ = result

                # Verify depth images are NOT in sample_images
                depth_keys = [k for k in sample_images if "_depth" in k]
                assert len(depth_keys) == 0

                # Verify RGB images ARE in sample_images
                rgb_keys = [k for k in sample_images if "_depth" not in k]
                assert len(rgb_keys) > 0

    def test_intrinsics_extrinsics_lookup_regression(self, mock_config):
        """Regression test: intrinsics/extrinsics should use camera_name as key, not 'intrinsics.camera_name'.

        extract_sample_camera_calibration() stores calibration under the bare camera name. An
        earlier bug looked it up under "intrinsics.<camera_name>" instead.

        NOTE(review): this test only checks key conventions on a locally built dict and never
        calls SpartanConverter, so it cannot catch a regression in the converter itself.
        Consider driving it through extract_sample_data instead.
        """
        rng = np.random.default_rng(0)
        camera_name = "camera_left"

        # Calibration keyed by the bare camera name (NOT prefixed).
        sample_intrinsics = {camera_name: rng.standard_normal((10, 3, 3))}

        # Correct lookup pattern (what the fixed code does).
        assert camera_name in sample_intrinsics
        intrinsics = sample_intrinsics[camera_name]
        assert intrinsics is not None
        assert intrinsics.shape == (10, 3, 3)

        # Incorrect lookup patterns (the original bug) must not resolve.
        assert f"intrinsics.{camera_name}" not in sample_intrinsics
        assert f"original_intrinsics.{camera_name}" not in sample_intrinsics
