# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark script to measure the performance overhead of retry statistics tracking."""

import gc
import os
import statistics
import subprocess
import sys
import threading
import time
from http.client import IncompleteRead
from unittest.mock import MagicMock

import cosmos_predict2._src.imaginaire.datasets.webdataset.utils.stream as stream_module
from cosmos_predict2._src.imaginaire.datasets.webdataset.utils.stream import RetryingStream

# Configure faster logging interval for tests (10 seconds instead of 5 minutes)
stream_module.RETRY_STATS_LOG_INTERVAL = 10.0


def benchmark_iteration(enable_stats: bool, num_operations: int, network_delay_ms: float = 0) -> float:
    """Run a single benchmark iteration.

    Args:
        enable_stats: Whether to enable statistics tracking
        num_operations: Number of read operations to perform
        network_delay_ms: Simulated network delay in milliseconds (0 = no delay)

    Returns:
        Time taken in seconds
    """
    stream_module.ENABLE_RETRY_STATS = enable_stats

    # Setup mock
    client = MagicMock()
    test_data = b"X" * 1024  # 1KB chunks
    client.head_object.return_value = {"ContentLength": str(len(test_data))}

    mock_body = MagicMock()
    if network_delay_ms > 0:
        # Add simulated network delay to mock read
        def mock_read_with_delay(amt):
            time.sleep(network_delay_ms / 1000.0)  # Convert ms to seconds
            return test_data

        mock_body.read = mock_read_with_delay
    else:
        mock_body.read.return_value = test_data

    client.get_object.return_value = {"Body": mock_body, "ContentLength": len(test_data)}

    stream = RetryingStream(client, "benchmark-bucket", "test.tar", retries=3)

    # Disable GC during timing to reduce noise
    gc.collect()
    gc.disable()

    try:
        start_time = time.perf_counter()  # Use perf_counter for higher precision
        for _ in range(num_operations):
            stream.read(1024)
        end_time = time.perf_counter()
    finally:
        gc.enable()

    return end_time - start_time


def run_benchmark_suite(name: str, num_operations: int, num_runs: int, network_delay_ms: float = 0):
    """Run a complete benchmark suite."""
    print(f"\n{'=' * 70}")
    print(f"{name}")
    print(f"{'=' * 70}")
    print(f"Operations: {num_operations:,} per run, {num_runs} runs")
    if network_delay_ms > 0:
        print(f"Network delay: {network_delay_ms}ms per read (simulates S3 latency)")
    else:
        print(f"Network delay: None (synthetic benchmark)")
    print(f"GC disabled during timing for accuracy")
    print("-" * 70)

    # Interleave runs to reduce system variance
    with_stats_times = []
    without_stats_times = []

    for i in range(num_runs):
        # Run both configs in same iteration to reduce variance
        elapsed_with = benchmark_iteration(True, num_operations, network_delay_ms)
        elapsed_without = benchmark_iteration(False, num_operations, network_delay_ms)

        with_stats_times.append(elapsed_with)
        without_stats_times.append(elapsed_without)

        overhead_this_run = ((elapsed_with - elapsed_without) / elapsed_without) * 100
        print(
            f"Run {i + 1:2d}: stats ON={elapsed_with:.4f}s ({num_operations / elapsed_with:,.0f} ops/s) | "
            f"stats OFF={elapsed_without:.4f}s ({num_operations / elapsed_without:,.0f} ops/s) | "
            f"overhead={overhead_this_run:+.1f}%"
        )

    # Calculate statistics (use trimmed mean to reduce outlier impact)
    median_with_stats = statistics.median(with_stats_times)
    median_without_stats = statistics.median(without_stats_times)

    # Also calculate trimmed mean (remove top/bottom 10%)
    sorted_with = sorted(with_stats_times)
    sorted_without = sorted(without_stats_times)
    trim_count = max(1, len(sorted_with) // 10)
    trimmed_with = sorted_with[trim_count:-trim_count] if len(sorted_with) > 2 * trim_count else sorted_with
    trimmed_without = sorted_without[trim_count:-trim_count] if len(sorted_without) > 2 * trim_count else sorted_without

    mean_trimmed_with = statistics.mean(trimmed_with)
    mean_trimmed_without = statistics.mean(trimmed_without)

    stddev_with_stats = statistics.stdev(with_stats_times) if len(with_stats_times) > 1 else 0
    stddev_without_stats = statistics.stdev(without_stats_times) if len(without_stats_times) > 1 else 0

    # Calculate coefficient of variation (CV) to show relative stability
    cv_with = (stddev_with_stats / median_with_stats) * 100 if median_with_stats > 0 else 0
    cv_without = (stddev_without_stats / median_without_stats) * 100 if median_without_stats > 0 else 0

    print("-" * 70)
    print(f"Stats ON:  median={median_with_stats:.4f}s, stddev={stddev_with_stats:.4f}s, CV={cv_with:.1f}%")
    print(f"Stats OFF: median={median_without_stats:.4f}s, stddev={stddev_without_stats:.4f}s, CV={cv_without:.1f}%")

    if max(cv_with, cv_without) > 15:
        print(f"⚠ High variance detected (CV > 15%) - results may be unreliable due to system noise")

    # Use both median and trimmed mean for overhead calculation
    overhead_median = ((median_with_stats - median_without_stats) / median_without_stats) * 100
    overhead_trimmed = ((mean_trimmed_with - mean_trimmed_without) / mean_trimmed_without) * 100

    print(f"\nMedian overhead: {overhead_median:+.2f}%")
    print(f"Trimmed mean overhead: {overhead_trimmed:+.2f}% (outliers removed)")

    # Show per-operation overhead (using trimmed mean for robustness)
    per_op_overhead_ns = ((mean_trimmed_with - mean_trimmed_without) / num_operations) * 1e9
    per_op_overhead_us = per_op_overhead_ns / 1000.0
    print(f"Per-operation overhead: {per_op_overhead_ns:.1f} nanoseconds ({per_op_overhead_us:.3f} microseconds)")

    if network_delay_ms > 0:
        network_delay_us = network_delay_ms * 1000.0
        overhead_vs_network = (per_op_overhead_us / network_delay_us) * 100
        print(f"Overhead vs network delay: {overhead_vs_network:.4f}% of {network_delay_ms}ms")

    # Use trimmed mean for final assessment (more robust)
    if abs(overhead_trimmed) < 1.0:
        print("✓ Negligible overhead (< 1%)")
    elif abs(overhead_trimmed) < 5.0:
        print("✓ Low overhead (< 5%)")
    else:
        print("⚠ Measurable overhead (>= 5%)")

    return overhead_trimmed


def test_multithreaded_stats_correctness():
    """Test that global statistics correctly aggregate across multiple threads and instances."""
    print("\n" + "=" * 70)
    print("CORRECTNESS TEST: Multi-threaded Statistics Aggregation")
    print("=" * 70)

    # Enable stats for testing
    stream_module.ENABLE_RETRY_STATS = True

    # Reset global stats
    with stream_module._global_retry_stats.lock:
        stream_module._global_retry_stats.registered_threads.clear()
        stream_module._global_retry_stats.active_instances.clear()
        stream_module._global_retry_stats.cumulative_operations_started = 0
        stream_module._global_retry_stats.cumulative_failed_operations = 0
        stream_module._global_retry_stats.cumulative_attempts = 0

    # Test configuration
    num_threads = 4
    operations_per_thread = 50
    retry_probability = 0.3  # 30% of operations will require a retry

    # Calculate exact expected retries (deterministic based on modulo)
    total_read_ops = num_threads * operations_per_thread
    retry_every_n = int(1 / retry_probability)  # Every 3rd operation retries
    expected_retries = sum(1 for i in range(total_read_ops) if i % retry_every_n == 0)

    print(f"Configuration: {num_threads} threads, {operations_per_thread} operations per thread")
    print(f"Expected: {total_read_ops} read operations (+ init operations)")
    print(f"Expected retries: {expected_retries} operations with retries (every {retry_every_n}th operation)")
    print("-" * 70)

    # Counter to track which operation should fail
    operation_counter = {"count": 0, "lock": threading.Lock()}

    def thread_worker(thread_id: int):
        """Worker function that creates streams and performs operations."""
        client = MagicMock()
        test_data = b"X" * 1024
        client.head_object.return_value = {"ContentLength": str(len(test_data))}

        # Track operations in this thread
        local_ops = 0
        local_retries = 0

        # Keep all streams alive until thread completes to prevent premature destructor calls
        streams = []

        for i in range(operations_per_thread):
            # Determine if this operation should require a retry
            with operation_counter["lock"]:
                op_num = operation_counter["count"]
                operation_counter["count"] += 1
                should_retry = (op_num % int(1 / retry_probability)) == 0

            # Create mock body that may fail once then succeed
            mock_body = MagicMock()
            if should_retry:
                # First read fails with IncompleteRead, second succeeds
                mock_body.read.side_effect = [
                    IncompleteRead(b"partial"),
                    test_data,
                ]
                local_retries += 1
            else:
                # Always succeeds
                mock_body.read.return_value = test_data

            client.get_object.return_value = {"Body": mock_body, "ContentLength": len(test_data)}

            # Create stream and perform read
            stream = RetryingStream(client, f"bucket-{thread_id}", f"file-{i}.tar", retries=5)
            streams.append(stream)  # Keep alive to prevent premature destructor calls
            try:
                data = stream.read(1024)
                local_ops += 1
            except Exception as e:
                print(f"Thread {thread_id}: Unexpected error: {e}")

        print(f"Thread {thread_id}: Completed {local_ops} operations, {local_retries} with retries")
        return local_ops, local_retries

    # Run threads
    print("Starting threads...")
    threads = []
    results = []

    def thread_wrapper(thread_id):
        result = thread_worker(thread_id)
        results.append(result)

    start_time = time.time()
    for i in range(num_threads):
        t = threading.Thread(target=thread_wrapper, args=(i,))
        threads.append(t)
        t.start()

    # Wait for all threads to complete
    for t in threads:
        t.join()

    elapsed = time.time() - start_time
    print(f"All threads completed in {elapsed:.2f}s")
    print("-" * 70)

    # Give threads a moment to finish cleanup
    time.sleep(0.1)

    # Aggregate local values from thread results (for sanity check)
    local_total_ops = sum(r[0] for r in results)
    local_ops_with_retries = sum(r[1] for r in results)

    # Get actual stats from global tracker (aggregate from per-thread stats)
    with stream_module._global_retry_stats.lock:
        actual_ops_started = 0
        actual_failed_ops = 0
        actual_total_attempts = 0

        for thread_stats in stream_module._global_retry_stats.registered_threads.values():
            actual_ops_started += thread_stats["operations_started"]
            actual_failed_ops += thread_stats["failed_operations"]
            actual_total_attempts += thread_stats["total_attempts"]

    # Note: actual_ops_started includes init operations (get_length, get_stream) too
    # Each RetryingStream.__init__ calls _retry_operation twice
    expected_init_ops = num_threads * operations_per_thread * 2  # get_length + get_stream
    expected_read_ops = local_total_ops  # Should equal num_threads * operations_per_thread
    expected_total_with_init = expected_init_ops + expected_read_ops

    print("RESULTS:")
    print(f"Local tracking (from threads):")
    print(f"  Read operations: {local_total_ops}")
    print(f"  Operations with retries: {local_ops_with_retries}")
    print(f"Global tracking (from stats aggregator):")
    print(f"  Operations started (including init): {actual_ops_started}")
    print(f"    Expected: {expected_total_with_init} ({expected_init_ops} init + {expected_read_ops} read)")
    print(f"  Failed operations: {actual_failed_ops}")
    print(f"    Expected: {expected_retries} (deterministic)")
    print(f"  Total attempts: {actual_total_attempts}")
    print(f"    Expected: {expected_total_with_init + expected_retries} (base + retry attempts)")
    print("-" * 70)

    # Verify correctness
    success = True

    # Sanity check: local tracking should match expected
    if local_ops_with_retries != expected_retries:
        print(f"⚠ WARNING: Local thread tracking mismatch (bug in test itself!)")
        print(f"   Expected retries: {expected_retries}, Local tracked: {local_ops_with_retries}")

    # Check that we tracked the right number of operations
    if actual_ops_started != expected_total_with_init:
        print(f"❌ FAIL: Operations started mismatch!")
        print(f"   Expected: {expected_total_with_init}, Got: {actual_ops_started}")
        success = False
    else:
        print(f"✓ Operations started tracked correctly")

    # Check that failed operations matches exactly (deterministic)
    if actual_failed_ops != expected_retries:
        print(f"❌ FAIL: Failed operations mismatch!")
        print(f"   Expected: {expected_retries}, Got: {actual_failed_ops}")
        success = False
    else:
        print(f"✓ Failed operations tracked correctly")

    # Check that total attempts equals base operations + retry attempts
    expected_total_attempts = expected_total_with_init + expected_retries
    if actual_total_attempts != expected_total_attempts:
        print(f"❌ FAIL: Total attempts mismatch!")
        print(f"   Expected: {expected_total_attempts}, Got: {actual_total_attempts}")
        success = False
    else:
        print(f"✓ Total attempts tracked correctly")

    # Check that we created thread-local stats for each thread
    num_registered_threads = len(stream_module._global_retry_stats.registered_threads)
    if num_registered_threads != num_threads:
        print(f"❌ FAIL: Incorrect number of threads registered!")
        print(f"   Expected: {num_threads}, Got: {num_registered_threads}")
        success = False
    else:
        print(f"✓ All {num_threads} threads registered correctly")

    print("-" * 70)
    if success:
        print("✅ PASS: Multi-threaded statistics aggregation is CORRECT!")
    else:
        print("❌ FAIL: Multi-threaded statistics aggregation has ERRORS!")

    print("\nNote: Final stats will be logged via atexit handler when the program exits.")

    return success


def test_weakref_robustness():
    """Test that WeakSet-based tracking handles failures gracefully.

    Note: Final stats are logged via atexit handler at program exit,
    not via destructors, so we won't see "Final" logs during this test.
    """
    print("\n" + "=" * 70)
    print("ROBUSTNESS TEST: WeakSet Handles Thread Death & Init Failures")
    print("=" * 70)

    # Enable stats for testing
    stream_module.ENABLE_RETRY_STATS = True

    # Reset global stats
    with stream_module._global_retry_stats.lock:
        stream_module._global_retry_stats.registered_threads.clear()
        stream_module._global_retry_stats.active_instances.clear()
        stream_module._global_retry_stats.cumulative_operations_started = 0
        stream_module._global_retry_stats.cumulative_failed_operations = 0
        stream_module._global_retry_stats.cumulative_attempts = 0

    client = MagicMock()
    test_data = b"X" * 1024
    client.head_object.return_value = {"ContentLength": str(len(test_data))}

    # Test 1: Normal construction and destruction
    print("Test 1: Normal construction and destruction")
    mock_body = MagicMock()
    mock_body.read.return_value = test_data
    client.get_object.return_value = {"Body": mock_body, "ContentLength": len(test_data)}

    stream1 = RetryingStream(client, "bucket", "file1.tar", retries=5)
    with stream_module._global_retry_stats.lock:
        count = len(stream_module._global_retry_stats.active_instances)
    print(f"  After creating stream1: {count} active instance(s)")
    assert count == 1, f"Expected 1 instance, got {count}"

    del stream1
    gc.collect()  # Force garbage collection

    with stream_module._global_retry_stats.lock:
        count = len(stream_module._global_retry_stats.active_instances)
    print(f"  After deleting stream1: {count} active instance(s)")
    assert count == 0, f"Expected 0 instances, got {count}"
    print("  ✓ Pass: Normal lifecycle works correctly")

    # Test 2: Exception during init (simulated by creating then raising)
    print("\nTest 2: WeakSet cleans up even if instance only partially constructed")
    stream2 = RetryingStream(client, "bucket", "file2.tar", retries=5)
    with stream_module._global_retry_stats.lock:
        count_before = len(stream_module._global_retry_stats.active_instances)
    print(f"  Created stream2: {count_before} active instance(s)")

    # Simulate early destruction (exception path, thread death, etc.)
    del stream2
    gc.collect()  # Force garbage collection

    with stream_module._global_retry_stats.lock:
        count_after = len(stream_module._global_retry_stats.active_instances)
    print(f"  After destruction: {count_after} active instance(s)")
    assert count_after == 0, f"Expected 0 instances after cleanup, got {count_after}"
    print("  ✓ Pass: WeakSet automatically cleaned up")

    # Test 3: Multiple instances, destroy in random order
    print("\nTest 3: Multiple instances with out-of-order destruction")
    # Create streams with explicit references so we can delete specific ones
    s0 = RetryingStream(client, "bucket", "file0.tar", retries=5)
    s1 = RetryingStream(client, "bucket", "file1.tar", retries=5)
    s2 = RetryingStream(client, "bucket", "file2.tar", retries=5)
    s3 = RetryingStream(client, "bucket", "file3.tar", retries=5)
    s4 = RetryingStream(client, "bucket", "file4.tar", retries=5)

    with stream_module._global_retry_stats.lock:
        count = len(stream_module._global_retry_stats.active_instances)
    print(f"  Created 5 streams: {count} active instance(s)")
    assert count == 5, f"Expected 5 instances, got {count}"

    # Delete specific streams (keep s1 and s3 alive)
    del s0
    del s2
    del s4
    gc.collect()  # Force garbage collection to ensure destructors run

    with stream_module._global_retry_stats.lock:
        count = len(stream_module._global_retry_stats.active_instances)
    print(f"  After deleting 3 streams: {count} active instance(s)")
    assert count == 2, f"Expected 2 instances (s1, s3), got {count}"

    # Clean up remaining (s1 and s3)
    del s1
    del s3
    gc.collect()  # Force garbage collection

    with stream_module._global_retry_stats.lock:
        count = len(stream_module._global_retry_stats.active_instances)
    print(f"  After deleting all: {count} active instance(s)")
    assert count == 0, f"Expected 0 instances, got {count}"
    print("  ✓ Pass: Out-of-order destruction handled correctly")

    print("-" * 70)
    print("✅ PASS: WeakSet-based tracking is ROBUST!")
    print("  - Handles normal lifecycle")
    print("  - Automatically cleans up dead references")
    print("  - Works with arbitrary destruction order")
    print("  - No risk of stuck counters or deadlocks")

    return True


def test_multi_rank_stats_logging():
    """Test stats logging with multiple ranks AND multiple workers per rank (simulating DataLoader).

    This test uses actual distributed launchers:
    1. Tests with torchrun (PyTorch's distributed launcher) if available
    2. Tests with mpirun (OpenMPI/MPICH) if available (requires mpi4py: `uv pip install mpi4py`)
    3. Skips test if neither available

    The worker script (mpi_rank_worker.py) is launched by real launchers.
    Each rank spawns multiple worker processes (via multiprocessing) to simulate
    DataLoader workers with num_workers > 0.

    Multi-level testing:
    - Multiple ranks (distributed training)
    - Multiple workers per rank (DataLoader processes)
    - Different workload per worker (simulates real workload imbalance)
    - Different failure patterns per worker

    This ensures:
    - Each worker process has independent statistics (separate PID, separate _global_retry_stats)
    - Each worker logs its own statistics with correct PID
    - Worker processes can have the same thread ID but different PIDs
    - Statistics are correctly isolated between workers and between ranks

    Note: mpi4py is an optional dependency only needed for mpirun testing.
    """

    print("\n" + "=" * 70)
    print("MULTI-RANK (REAL MPI/TORCHRUN) TEST")
    print("=" * 70)

    world_size = 10

    # Get path to worker script
    worker_script = os.path.join(os.path.dirname(__file__), "mpi_rank_worker.py")

    # Check which launchers are available
    available_launchers = {}

    print("Checking available launchers...")

    # Check torchrun
    try:
        result = subprocess.run(["torchrun", "--help"], capture_output=True, timeout=5)
        if result.returncode == 0:
            available_launchers["torchrun"] = [
                "torchrun",
                "--standalone",
                "--nnodes=1",
                f"--nproc_per_node={world_size}",
                worker_script,
            ]
            print("  ✓ torchrun available")
    except (FileNotFoundError, subprocess.TimeoutExpired):
        print("  ✗ torchrun not found")

    # Check mpirun
    try:
        result = subprocess.run(["mpirun", "--version"], capture_output=True, timeout=5)
        if result.returncode == 0:
            available_launchers["mpirun"] = [
                "mpirun",
                "--oversubscribe",  # Allow more processes than physical cores
                "--tag-output",  # Prefix output with [rank,node]
                "-np",
                str(world_size),
                sys.executable,
                "-u",  # Unbuffered Python output
                worker_script,
            ]
            print("  ✓ mpirun available")
    except (FileNotFoundError, subprocess.TimeoutExpired):
        print("  ✗ mpirun not found")

    if not available_launchers:
        print("\n⚠ SKIP: Neither torchrun nor mpirun available")
        print("  Install PyTorch (for torchrun) or OpenMPI (for mpirun)")
        print("=" * 70)
        return True  # Not a failure, just skip

    print(f"\nTesting with {len(available_launchers)} launcher(s): {', '.join(available_launchers.keys())}")
    print(f"  World size: {world_size} ranks per launcher")
    print(f"  Worker script: {worker_script}")
    print("=" * 70)

    # Test with each available launcher
    all_passed = True
    results = {}

    for launcher_name, launcher_cmd in available_launchers.items():
        print(f"\n{'=' * 70}")
        print(f"TESTING WITH: {launcher_name}")
        print(f"{'=' * 70}")
        print(f"Command: {' '.join(launcher_cmd)}")
        print("-" * 70)

        try:
            result = subprocess.run(
                launcher_cmd,
                capture_output=True,  # Capture output for verification
                text=True,
                timeout=120,  # 2 minute timeout
            )
            returncode = result.returncode
            output = result.stdout + result.stderr

            # Print output to terminal
            print(output)

            # Note: If a rank fails, torchrun/mpirun will show a stack trace indicating
            # which rank failed. This is NORMAL and EXPECTED behavior - it's how
            # distributed launchers report failures, not a crash or bug.

        except subprocess.TimeoutExpired:
            print(f"\n❌ TIMEOUT: {launcher_name} did not complete within 2 minutes")
            results[launcher_name] = "TIMEOUT"
            all_passed = False
            continue
        except Exception as e:
            print(f"\n❌ ERROR launching {launcher_name}: {e}")
            results[launcher_name] = f"ERROR: {e}"
            all_passed = False
            continue

        if returncode != 0:
            print(f"\n❌ {launcher_name} FAILED: exit code {returncode}")
            results[launcher_name] = f"FAILED (exit {returncode})"
            all_passed = False
        else:
            # Verify that each rank reported success
            # Each rank now has multiple workers, so we check for rank-level success messages
            # Look for the pattern "✅ Rank X: All N workers verified"
            # Use regex to count occurrences (not lines) because stdout buffering can concatenate outputs
            import re

            rank_summaries = len(re.findall(r"✅ Rank \d+: All \d+ workers verified", output))

            if rank_summaries == world_size:
                print(f"\n✅ {launcher_name} PASSED - All {world_size} ranks with their workers verified")
                results[launcher_name] = "PASSED"
            else:
                print(f"\n❌ {launcher_name} PARTIAL SUCCESS - Only {rank_summaries}/{world_size} ranks verified")
                results[launcher_name] = f"PARTIAL ({rank_summaries}/{world_size} ranks OK)"
                all_passed = False

    # Print summary
    print("\n" + "=" * 70)
    print("MULTI-RANK TEST SUMMARY")
    print("=" * 70)

    for launcher_name, result in results.items():
        status_icon = "✅" if result == "PASSED" else "❌"
        print(f"  {status_icon} {launcher_name}: {result}")

    print("-" * 70)

    if all_passed:
        print(f"✅ OVERALL: PASS")
        print(f"\nAll {world_size} ranks completed successfully with all launchers!")
        print(f"\nVerified:")
        print(f"  ✓ Each rank ran as independent process")
        print(f"  ✓ Each rank had separate Python interpreter and statistics")
        print(f"  ✓ Each rank used different failure patterns (rank-specific)")
        print(f"  ✓ Each rank's statistics matched expected values exactly")
        print(f"  ✓ Each rank logged its own statistics independently")
        print(f"  ✓ No shared memory or stat contamination between ranks")
        print(f"  ✓ Works with both torchrun and mpirun")
        print(f"\nThis confirms statistics work correctly in distributed training!")
    else:
        print(f"❌ OVERALL: FAIL")
        print(f"\nSome launchers failed - check output above for details")

    print("=" * 70)
    return all_passed


def test_rank_failure_robustness():
    """Test that statistics logging doesn't cause deadlocks when ranks fail unexpectedly.

    This is a fault tolerance test that simulates real-world distributed training failures.

    What we test:
    - Random ranks are killed mid-execution (os._exit() simulates crash)
    - Tests complete within timeout (NO DEADLOCK)
    - Statistics logging doesn't hold locks that cause hangs
    - Atexit handlers don't deadlock when ranks die

    What we expect:
    - Test completes within 60s (key requirement - no hang/deadlock)
    - torchrun: Kills all ranks (elastic fail-fast behavior - CORRECT)
    - mpirun: May kill all ranks or allow survivors (both OK)
    - Non-zero exit code (some ranks died - EXPECTED)

    What we DON'T test:
    - Whether survivors complete their work (launcher-dependent)
    - Recovery or re-launching (not our responsibility)

    This test is critical because in production, ranks can fail due to:
    - Hardware failures (GPU crashes, node failures)
    - OOM errors
    - Network issues
    - Data corruption

    The statistics logging must NOT make the system more brittle by adding deadlock risks.
    """
    print("\n" + "=" * 70)
    print("RANK FAILURE ROBUSTNESS TEST")
    print("=" * 70)
    print("Testing that statistics logging doesn't cause deadlocks when ranks fail")
    print("=" * 70)

    world_size = 10
    num_failed_ranks = 3  # Kill 30% of ranks

    # Get path to worker script
    worker_script = os.path.join(os.path.dirname(__file__), "mpi_rank_worker.py")

    # Randomly select ranks to kill
    import random

    random.seed(42)  # Deterministic for reproducibility
    failed_ranks = random.sample(range(world_size), num_failed_ranks)
    failed_ranks_str = ",".join(map(str, failed_ranks))

    print(f"\n  World size: {world_size} ranks")
    print(f"  Simulating failures: {num_failed_ranks} ranks will be killed mid-execution")
    print(f"  Failed ranks: {failed_ranks}")
    print(f"\n  SUCCESS CRITERIA:")
    print(f"    ✓ Test completes within 60s timeout (NO DEADLOCK)")
    print(f"    ✓ Job exits with non-zero code (ranks died as expected)")
    print(f"\n  NOTE: Launchers use fail-fast behavior (kill all ranks when one fails)")
    print(f"        This is CORRECT and EXPECTED behavior!")
    print("=" * 70)

    # Check which launchers are available
    available_launchers = {}

    # Check torchrun
    try:
        result = subprocess.run(["torchrun", "--help"], capture_output=True, timeout=5)
        if result.returncode == 0:
            available_launchers["torchrun"] = [
                "torchrun",
                f"--nproc_per_node={world_size}",
                worker_script,
            ]
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    # Check mpirun
    try:
        result = subprocess.run(["mpirun", "--version"], capture_output=True, timeout=5)
        if result.returncode == 0:
            available_launchers["mpirun"] = [
                "mpirun",
                "-np",
                str(world_size),
                "--oversubscribe",
                "--tag-output",
                "python",
                worker_script,
            ]
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    if not available_launchers:
        print("\n⚠ SKIPPING: Neither torchrun nor mpirun available")
        print("This test requires a distributed launcher.")
        return True  # Skip test, don't fail

    print(f"\nAvailable launchers: {', '.join(available_launchers.keys())}")

    # Test with each available launcher
    all_passed = True
    results = {}

    for launcher_name, launcher_cmd in available_launchers.items():
        print(f"\n{'=' * 70}")
        print(f"TESTING WITH: {launcher_name}")
        print(f"{'=' * 70}")

        # Set environment variables for failure simulation
        env = os.environ.copy()
        env["SIMULATE_FAILURE_RANKS"] = failed_ranks_str
        env["SKIP_BARRIER"] = "1"  # Skip barrier to avoid deadlock

        try:
            result = subprocess.run(
                launcher_cmd,
                capture_output=True,
                text=True,
                timeout=60,  # Shorter timeout - should fail fast
                env=env,
            )
            returncode = result.returncode
            output = result.stdout + result.stderr

            # Print output
            print(output)

        except subprocess.TimeoutExpired:
            print(f"\n❌ TIMEOUT: Test hung for 60 seconds - likely a DEADLOCK!")
            print(f"This means the statistics logging caused a deadlock when ranks failed.")
            results[launcher_name] = "DEADLOCK"
            all_passed = False
            continue
        except Exception as e:
            print(f"\n❌ ERROR launching {launcher_name}: {e}")
            results[launcher_name] = f"ERROR: {e}"
            all_passed = False
            continue

        # Expected behavior: Job should fail (some ranks died) but NOT hang/deadlock
        # Key metric: Did it complete within timeout? If yes, NO DEADLOCK!

        # Count ranks that completed their work
        success_count = output.count("All statistics verified correctly!")
        killed_ranks_confirmed = output.count("SIMULATING RANK FAILURE")
        expected_survivors = world_size - num_failed_ranks

        if returncode == 0:
            print(f"\n⚠️  WARNING: {launcher_name} succeeded even though ranks were killed")
            print(f"This is unexpected - check if failure simulation worked")
            results[launcher_name] = "UNEXPECTED_SUCCESS"
            all_passed = False
        else:
            # Job failed as expected (some ranks died)
            # Check for deadlock: If we got here without timeout, NO DEADLOCK!
            print(f"\n✅ {launcher_name} PASSED - No deadlock detected!")
            print(f"   Job completed within timeout (no hang/deadlock)")
            print(f"   Simulated failures: {killed_ranks_confirmed}/{num_failed_ranks} ranks")
            print(f"   Completed successfully: {success_count} ranks")

            # torchrun kills ALL ranks when ANY rank fails (elastic behavior)
            # mpirun may allow survivors to continue
            if launcher_name == "torchrun":
                if success_count == 0:
                    print(f"   ℹ️  torchrun killed all ranks (expected elastic fail-fast behavior)")
                    results[launcher_name] = "PASSED"
                else:
                    print(f"   ⚠️  Some ranks survived despite torchrun elastic mode")
                    results[launcher_name] = "PASSED"
            else:  # mpirun or other
                if success_count >= expected_survivors - 1:  # Allow 1 off due to timing
                    print(f"   ℹ️  {success_count}/{expected_survivors} expected survivors completed")
                    results[launcher_name] = "PASSED"
                elif success_count > 0:
                    print(f"   ⚠️  Only {success_count}/{expected_survivors} survivors completed")
                    print(f"   Some ranks may have been killed by launcher")
                    results[launcher_name] = "PASSED"  # Still no deadlock
                else:
                    print(f"   ℹ️  Launcher killed all ranks (fail-fast behavior)")
                    results[launcher_name] = "PASSED"

    # Print summary
    print("\n" + "=" * 70)
    print("RANK FAILURE ROBUSTNESS TEST SUMMARY")
    print("=" * 70)

    for launcher_name, result in results.items():
        status_icon = "✅" if result == "PASSED" else "⚠️" if "UNEXPECTED" in result else "❌"
        print(f"  {status_icon} {launcher_name}: {result}")

    print("-" * 70)

    if all_passed and results:
        print(f"✅ OVERALL: PASS - No deadlocks detected!")
        print(f"\nAll launchers handled rank failures without deadlock!")
        print(f"\nVerified:")
        print(f"  ✓ {num_failed_ranks} ranks were killed mid-execution (simulated failures)")
        print(f"  ✓ No deadlocks or hangs detected (completed within timeout)")
        print(f"  ✓ Statistics logging didn't hold locks that cause deadlocks")
        print(f"  ✓ Launchers handled failures with fail-fast behavior")
        print(f"\nKey finding:")
        print(f"  • torchrun: Uses elastic fail-fast (kills all ranks when one fails)")
        print(f"  • mpirun: May allow survivors or use fail-fast depending on config")
        print(f"  • Both behaviors are CORRECT - no deadlock is the critical requirement")
        print(f"\nThis confirms the statistics infrastructure is fault-tolerant!")
    elif not results:
        print(f"⚠️  OVERALL: SKIP - No launchers available for testing")
    else:
        print(f"❌ OVERALL: FAIL")
        print(f"\nSome tests failed - check for DEADLOCK or TIMEOUT issues above")
        print(f"\nIf tests TIMEOUT, that indicates a DEADLOCK problem.")
        print(f"If tests complete quickly with 0 survivors, that's fail-fast (OK).")

    print("=" * 70)
    return all_passed


def test_dataloader_worker_process_isolation():
    """Test to prove DataLoader workers are separate processes with independent statistics.

    This test demonstrates that:
    1. Each DataLoader worker is a separate process (different PID)
    2. Each worker process's main thread can have the same thread ID
    3. Each worker has its own _global_retry_stats instance with independent counters
    4. This explains why production logs show the same thread ID with different operation counts
    """
    import multiprocessing
    import queue

    from torch.utils.data import DataLoader, Dataset

    print("\n" + "=" * 70)
    print("TEST: DataLoader Worker Process Isolation")
    print("=" * 70)

    # Queue to collect results from worker processes
    result_queue = multiprocessing.Manager().Queue()

    class DataLoaderTestDataset(Dataset):
        """Dataset that reports process/thread info from workers."""

        def __init__(self, num_items: int, result_queue: queue.Queue):
            self.num_items = num_items
            self.result_queue = result_queue

        def __len__(self) -> int:
            return self.num_items

        def __getitem__(self, idx: int) -> dict:
            """Each worker performs S3 operations and reports its PID/thread ID/stats."""
            import os
            import threading
            from unittest.mock import MagicMock

            import cosmos_predict2._src.imaginaire.datasets.webdataset.utils.stream as stream_module
            from cosmos_predict2._src.imaginaire.datasets.webdataset.utils.stream import RetryingStream
            from cosmos_predict2._src.imaginaire.utils import log

            # Initialize logging in worker process (each worker subprocess needs this)
            # Only initialize once per worker
            if not hasattr(self, "_log_initialized"):
                log.init_loguru_stdout()
                self._log_initialized = True

            # Get process and thread info
            pid = os.getpid()
            thread_id = threading.get_ident()

            # Create mock S3 client
            client = MagicMock()
            test_data = b"X" * 1024
            client.head_object.return_value = {"ContentLength": str(len(test_data))}

            # Vary number of operations per item so each worker gets unique totals
            # This makes it clear in logs that each worker has independent statistics
            num_ops = 5 + (idx % 4) * 5  # 5, 10, 15, 20 ops depending on idx

            # Perform S3 operations to increment stats
            for i in range(num_ops):
                mock_body = MagicMock()
                mock_body.read.return_value = test_data
                client.get_object.return_value = {"Body": mock_body, "ContentLength": len(test_data)}

                stream = RetryingStream(client, f"bucket-worker{pid}", f"file-{idx}-{i}.tar", retries=5)
                _ = stream.read(1024)

            # Force a log every few items to demonstrate the logging behavior
            # This simulates what would happen in production when 60 seconds elapse
            if idx % 3 == 0:  # Log every 3rd item
                # Force logging by resetting the timer
                with stream_module._global_retry_stats.lock:
                    stream_module._global_retry_stats.last_log_time = 0

                # Trigger periodic log (will show cumulative stats for this worker)
                stream_module._maybe_log_retry_stats()

            # Get cumulative stats from this worker's _global_retry_stats
            with stream_module._global_retry_stats.lock:
                ops_started = stream_module._global_retry_stats.cumulative_operations_started
                failed_ops = stream_module._global_retry_stats.cumulative_failed_operations
                total_attempts = stream_module._global_retry_stats.cumulative_attempts

            # Report to main process
            result = {
                "idx": idx,
                "pid": pid,
                "thread_id": thread_id,
                "ops_started": ops_started,
                "failed_ops": failed_ops,
                "total_attempts": total_attempts,
            }

            self.result_queue.put(result)
            return result

    # Create DataLoader with multiple workers
    num_workers = 4
    items_per_worker = 5
    dataset = DataLoaderTestDataset(num_items=num_workers * items_per_worker, result_queue=result_queue)

    print(f"\nCreating DataLoader with {num_workers} workers...")
    print(f"Each worker will process ~{items_per_worker} items")
    print(f"Number of S3 operations varies per item: 5, 10, 15, or 20 read ops")
    print(f"Each item: (2 * num_read_ops) init operations + num_read_ops read operations")
    print(f"This creates DIFFERENT operation counts per worker (proves isolation)")
    print(f"\nWorkers will log statistics every 3 items (simulating periodic logs)")
    print(f"Watch for logs showing the SAME thread ID but DIFFERENT operation counts!\n")

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=num_workers,
        shuffle=False,
    )

    # Consume the dataloader (this triggers worker processes)
    for batch in dataloader:
        pass  # Workers send results via queue

    # Collect all results from workers
    results = []
    while not result_queue.empty():
        try:
            results.append(result_queue.get_nowait())
        except queue.Empty:
            break

    # Analyze results
    print(f"Collected {len(results)} results from workers\n")

    # Group by PID
    by_pid = {}
    for result in results:
        pid = result["pid"]
        if pid not in by_pid:
            by_pid[pid] = []
        by_pid[pid].append(result)

    print(f"Found {len(by_pid)} unique worker PIDs:")

    thread_ids_by_pid = {}
    for pid, items in sorted(by_pid.items()):
        thread_ids = set(item["thread_id"] for item in items)
        thread_ids_by_pid[pid] = thread_ids

        # Get final stats for this worker (last item has cumulative total)
        final_stats = max(items, key=lambda x: x["ops_started"])

        print(f"  PID {pid}:")
        print(f"    Thread ID(s): {thread_ids}")
        print(f"    Processed {len(items)} items")
        print(
            f"    Cumulative stats: {final_stats['ops_started']} ops, "
            f"{final_stats['failed_ops']} failed, {final_stats['total_attempts']} attempts"
        )

    # Check for thread ID collisions across processes
    all_thread_ids = []
    for thread_ids in thread_ids_by_pid.values():
        all_thread_ids.extend(thread_ids)

    thread_id_counts = {}
    for tid in all_thread_ids:
        thread_id_counts[tid] = thread_id_counts.get(tid, 0) + 1

    duplicated_thread_ids = {tid: count for tid, count in thread_id_counts.items() if count > 1}

    print("\n" + "-" * 70)
    print("ANALYSIS:")
    print("-" * 70)

    if duplicated_thread_ids:
        print(f"✅ Found thread ID collision(s): {len(duplicated_thread_ids)} thread IDs shared across processes")
        for tid, count in duplicated_thread_ids.items():
            print(f"   Thread ID {tid} appears in {count} different worker processes")
        print("\nThis proves that the same thread ID can exist in multiple processes!")
    else:
        print("⚠️  No thread ID collisions found (less common, but still valid)")
        print("   Each worker happened to get a unique thread ID")

    print(f"\nEach worker process has INDEPENDENT _global_retry_stats (different counts):")
    operation_counts = []
    for pid, items in sorted(by_pid.items()):
        final_stats = max(items, key=lambda x: x["ops_started"])
        operation_counts.append(final_stats["ops_started"])
        print(f"   PID {pid}: {final_stats['ops_started']} operations (independent counter)")

    # Verify that operation counts are different (proving independence)
    if len(set(operation_counts)) > 1:
        print(f"\n✅ Workers have DIFFERENT operation counts: {operation_counts}")
        print("   This proves each process has its own independent statistics!")
    else:
        print(f"\n⚠️  Workers have same counts (less typical, but still independent processes)")

    print("\n" + "=" * 70)
    print("CONCLUSION:")
    print("=" * 70)
    print("During the test, you should have seen WARNING logs like:")
    print("  [RetryingStream Stats] RANK-LOCAL: X ops, Y failed...")
    print("  Thread NNNN: X ops...")
    print("\nThese logs likely showed:")
    print("  - The SAME thread ID appearing multiple times")
    print("  - DIFFERENT operation counts with that same thread ID")
    print("  - This EXACTLY matches what you see in production!")
    print("\nWhy production logs show:")
    print("  - Same thread ID (e.g., 23456244278592) across ALL log entries")
    print("  - DIFFERENT/DECREASING operation counts with the same thread ID")
    print("\nThe explanation:")
    print("  - Each DataLoader worker is a SEPARATE PROCESS with its own memory")
    print("  - Each process has its own independent _global_retry_stats instance")
    print("  - Thread IDs are process-local (NOT globally unique)")
    print("  - The main thread in each process often gets the same thread ID")
    print("  - When different workers log at different times → interleaved stats")
    print("  - Different workers have different workloads → different operation counts")
    print("\nThe solution:")
    print("  1. Use cumulative_* counters (already maintained, protected by lock)")
    print("  2. Add PID to log messages to distinguish which worker is logging")
    print("  3. Understand that 'decreasing' counts are actually different processes")
    print("=" * 70)


if __name__ == "__main__":
    # First, run robustness test
    test_passed = test_weakref_robustness()

    if not test_passed:
        print("\n⚠ WARNING: Robustness test failed!")
        time.sleep(2)

    # Second, run correctness test
    test_passed = test_multithreaded_stats_correctness()

    if not test_passed:
        print("\n⚠ WARNING: Correctness test failed! Proceeding with benchmarks anyway...")
        time.sleep(2)

    # Third, run multi-rank test
    test_passed = test_multi_rank_stats_logging()

    if not test_passed:
        print("\n⚠ WARNING: Multi-rank test failed! Check errors above.")
        time.sleep(2)

    # Fourth, run rank failure robustness test
    print("\n\nWaiting 3 seconds before fault tolerance test...")
    time.sleep(3)

    test_passed = test_rank_failure_robustness()

    if not test_passed:
        print("\n⚠ WARNING: Rank failure robustness test failed!")
        print("This means the statistics logging may cause deadlocks when ranks fail.")
        time.sleep(2)

    # Fifth, run DataLoader worker isolation test
    print("\n\nWaiting 3 seconds before DataLoader worker isolation test...")
    time.sleep(3)

    test_dataloader_worker_process_isolation()

    # Warmup (longer to stabilize JIT and caches)
    print("\n" + "=" * 70)
    print("PERFORMANCE BENCHMARKS")
    print("=" * 70)
    print("\nWarming up...")
    for _ in range(10):
        benchmark_iteration(True, 1000, 0)
        benchmark_iteration(False, 1000, 0)

    # Benchmark 1: Synthetic (no network delay) - shows maximum overhead
    run_benchmark_suite(
        name="BENCHMARK 1: Synthetic (No Network Delay)", num_operations=100000, num_runs=10, network_delay_ms=0
    )

    # Benchmark 2: Aggressive in-region latency (1ms per read)
    # This represents same-region VM-to-S3/GCS with high-bandwidth network
    print("\n\nWaiting 2 seconds before realistic benchmark...")
    time.sleep(2)

    run_benchmark_suite(
        name="BENCHMARK 2: In-Region VM to S3/GCS (1ms latency)",
        num_operations=1000,  # 1 second total with 1ms each
        num_runs=10,
        network_delay_ms=1.0,
    )

    # Benchmark 3: Typical cross-region or slower network (10ms per read)
    print("\n\nWaiting 2 seconds before final benchmark...")
    time.sleep(2)

    run_benchmark_suite(
        name="BENCHMARK 3: Cross-Region or Slower Network (10ms latency)",
        num_operations=200,  # 2 seconds total with 10ms each
        num_runs=10,
        network_delay_ms=10.0,
    )

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print("The synthetic benchmark (no delay) shows maximum theoretical overhead.")
    print("The realistic benchmarks show actual production impact:")
    print("  - 1ms:  Aggressive same-region VM to S3/GCS")
    print("  - 10ms: Typical cross-region or shared network")
    print("\nIn real-world S3/GCS usage, the overhead is completely negligible")
    print("because network I/O dominates (1-100ms vs ~200ns overhead).")

    print("\n" + "-" * 70)
    print("NOTE: Synthetic benchmark variance (5-20%) is normal and caused by:")
    print("  - CPU frequency scaling (1.2GHz → 4.5GHz dynamically)")
    print("  - Thermal throttling as CPU heats up")
    print("  - Background OS processes")
    print("  - Cache effects (cold vs warm)")
    print("\nThis variance does NOT exist in production with real network I/O!")
    print("The ~200ns overhead is consistent; only the baseline varies.")

    # Re-enable stats for normal operation
    stream_module.ENABLE_RETRY_STATS = True
