# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path

import yaml


@dataclass
class VQACheck:
    """
    Represents a single VQA check with question, expected answer, and validation keywords.

    Attributes:
        question: The question to ask about the video
        answer: Expected answer (for reference/documentation)
        keywords: List of keywords that the actual answer should contain for validation
        must_pass: Whether this check is a must-pass check (default: False)
    """

    question: str
    answer: str
    keywords: list[str]
    must_pass: bool = False

    def validate(self, actual_answer: str) -> tuple[bool, list[str]]:
        """
        Validate if the actual answer contains the expected answer or at least one of the expected keywords.

        Args:
            actual_answer: The answer generated by the model

        Returns:
            Tuple of (validation_passed, found_keywords)
            - validation_passed: True if expected_answer is found OR at least ONE keyword is found
            - found_keywords: List of keywords that were found in the answer
        """
        actual_answer_lower = actual_answer.lower()
        found_keywords = []
        validation_passed = False

        # Handle case where no expected answer or keywords are provided
        if not self.answer and not self.keywords:
            validation_passed = True
            return validation_passed, found_keywords

        # Check if expected answer is part of the actual answer (short-circuit)
        if self.answer and self.answer.lower() in actual_answer_lower:
            validation_passed = True

        # Only check keywords if expected answer didn't match (or wasn't provided)
        if validation_passed and self.keywords:
            for keyword in self.keywords:
                if keyword.lower() in actual_answer_lower:
                    found_keywords.append(keyword)

            # Pass validation if at least one keyword is found
            validation_passed = len(found_keywords) > 0

        return validation_passed, found_keywords

    @classmethod
    def load_from_yaml(cls, test_config_path: str | Path) -> list["VQACheck"]:
        """
        Parse VQA checks from a YAML file.

        Args:
            test_config_path: Path to the YAML file containing VQA checks

        Returns:
            List of VQACheck objects

        Expected YAML format:
            vqa_checks:
              - question: What type of environment is the video set in?
                answer: A modern, well-lit office
                contains:
                  - "modern"
                  - "office"
                  - "well-lit"
              - question: What is the main focus of the video?
                answer: A robotic interaction at a counter
                contains:
                  - "robotic"
                  - "interaction"
            must_pass_checks:
              - question: Critical check
                answer: Expected answer
                contains:
                  - "keyword"
        """
        test_config_path = Path(test_config_path)

        if not test_config_path.exists():
            raise FileNotFoundError(f"Test config file not found: {test_config_path}")

        with test_config_path.open("r") as f:
            data = yaml.safe_load(f)

        if not data or ("vqa_checks" not in data and "must_pass_checks" not in data):
            raise ValueError("YAML file must contain 'vqa_checks' or 'must_pass_checks' key")

        vqa_checks = []

        # Process regular vqa_checks (must_pass=False)
        for check_data in data.get("vqa_checks", []):
            if "question" not in check_data:
                raise ValueError("Each VQA check must have a 'question' field")

            vqa_check = cls(
                question=check_data["question"],
                answer=check_data.get("answer", ""),
                keywords=check_data.get("contains", []),
                must_pass=False,
            )
            vqa_checks.append(vqa_check)

        # Process must_pass_checks (must_pass=True)
        for check_data in data.get("must_pass_checks", []):
            if "question" not in check_data:
                raise ValueError("Each VQA check must have a 'question' field")

            vqa_check = cls(
                question=check_data["question"],
                answer=check_data.get("answer", ""),
                keywords=check_data.get("contains", []),
                must_pass=True,
            )
            vqa_checks.append(vqa_check)

        return vqa_checks
