"""
Monkey-patch for scservo_sdk compatibility issues.

This module patches the scservo_sdk library to fix API mismatches without
modifying the installed package. Import this module before using sms_sts.

Required package: feetech-servo-sdk>=1.0.0, which provides an scservo_sdk whose
protocol_packet_handler.__init__ can be called without arguments (as this patch
does). If, on Python 3.11, you see:
  TypeError: protocol_packet_handler.__init__() missing 2 required positional arguments: 'portHandler' and 'protocol_end'
then an incompatible scservo_sdk build is installed; install the same SDK that is
used on Python 3.13: pip install feetech-servo-sdk
See robot_models/requirements-arm.txt.
"""

from scservo_sdk.sms_sts import sms_sts as _original_sms_sts
from scservo_sdk.group_sync_write import GroupSyncWrite
from scservo_sdk.sms_sts import SMS_STS_ACC

# Older scservo_sdk releases (e.g. the one seen on Python 3.11) may not define
# these helpers in scservo_def; provide equivalent fallbacks in that case.
try:
    from scservo_sdk.scservo_def import (
        SCS_TOHOST, SCS_TOSCS, SCS_MAKEWORD, SCS_MAKEDWORD,
        SCS_LOWORD, SCS_HIWORD, SCS_LOBYTE, SCS_HIBYTE
    )
except ImportError:
    # Standard byte/word helpers for the Feetech protocol, little-endian
    # (low byte / low word first).

    def SCS_LOBYTE(word):
        """Return the low byte (bits 0-7) of a 16-bit word."""
        return word & 0xFF

    def SCS_HIBYTE(word):
        """Return the high byte (bits 8-15) of a 16-bit word."""
        return (word >> 8) & 0xFF

    def SCS_LOWORD(dword):
        """Return the low 16-bit word of a 32-bit value."""
        return dword & 0xFFFF

    def SCS_HIWORD(dword):
        """Return the high 16-bit word of a 32-bit value."""
        return (dword >> 16) & 0xFFFF

    def SCS_MAKEWORD(low, high):
        """Combine two bytes into a 16-bit word, ``low`` in bits 0-7."""
        return (low & 0xFF) | ((high & 0xFF) << 8)

    def SCS_MAKEDWORD(low, high):
        """Combine two 16-bit words into a 32-bit value, ``low`` in bits 0-15."""
        return (low & 0xFFFF) | ((high & 0xFFFF) << 16)

    def SCS_TOHOST(value, bits):
        """Decode a sign-magnitude register value into a signed Python int.

        Bit ``bits`` is the sign flag and the bits below it hold the magnitude,
        e.g. ``SCS_TOHOST(0x8064, 15) == -100``. Values without the sign flag
        are returned unchanged.
        """
        # A plain bit mask (as an earlier fallback used) cannot represent
        # negative values, so decode sign-magnitude explicitly.
        sign_bit = 1 << bits
        if value & sign_bit:
            return -(value & ~sign_bit)
        return value

    def SCS_TOSCS(value, bits):
        """Encode a signed Python int as sign-magnitude with the sign at ``bits``.

        Negative numbers set bit ``bits`` and store their magnitude below it,
        e.g. ``SCS_TOSCS(-100, 15) == 0x8064``. Non-negative values are
        returned unchanged.
        """
        if value < 0:
            return (-value) | (1 << bits)
        return value


def _patched_sms_sts_init(self, portHandler):
    """Replacement ``sms_sts.__init__`` that builds GroupSyncWrite correctly.

    Runs the parent-class initializer with no arguments, then stores
    ``portHandler`` on the instance so the patched wrapper methods can pass it
    along. Finally it creates ``groupSyncWrite`` with the argument order
    GroupSyncWrite expects: (port, packet handler, start address, data length).
    """
    parent = super(_original_sms_sts, self)
    parent.__init__()
    # Assigned after the parent initializer runs, so this value takes
    # precedence over anything the parent may have set.
    self.portHandler = portHandler
    # Sync-write window: 7 bytes starting at register SMS_STS_ACC.
    self.groupSyncWrite = GroupSyncWrite(self.portHandler, self, SMS_STS_ACC, 7)


def _patched_read1ByteTxRx(self, scs_id, address):
    """Call the parent-class ``read1ByteTxRx`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.read1ByteTxRx(self.portHandler, scs_id, address)


def _patched_read2ByteTxRx(self, scs_id, address):
    """Call the parent-class ``read2ByteTxRx`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.read2ByteTxRx(self.portHandler, scs_id, address)


def _patched_read4ByteTxRx(self, scs_id, address):
    """Call the parent-class ``read4ByteTxRx`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.read4ByteTxRx(self.portHandler, scs_id, address)


def _patched_writeTxRx(self, scs_id, address, length, data):
    """Call the parent-class ``writeTxRx`` with ``self.portHandler`` prepended.

    ``scs_id``, ``address``, ``length`` and ``data`` are forwarded unchanged.
    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.writeTxRx(self.portHandler, scs_id, address, length, data)


def _patched_regWriteTxRx(self, scs_id, address, length, data):
    """Call the parent-class ``regWriteTxRx`` with ``self.portHandler`` prepended.

    ``scs_id``, ``address``, ``length`` and ``data`` are forwarded unchanged.
    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.regWriteTxRx(self.portHandler, scs_id, address, length, data)


def _patched_write1ByteTxRx(self, scs_id, address, data):
    """Call the parent-class ``write1ByteTxRx`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.write1ByteTxRx(self.portHandler, scs_id, address, data)


def _patched_write2ByteTxRx(self, scs_id, address, data):
    """Call the parent-class ``write2ByteTxRx`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.write2ByteTxRx(self.portHandler, scs_id, address, data)


def _patched_action(self, scs_id):
    """Call the parent-class ``action`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.action(self.portHandler, scs_id)


def _patched_ping(self, scs_id):
    """Call the parent-class ``ping`` with ``self.portHandler`` prepended.

    Returns whatever the parent-class implementation returns.
    """
    parent = super(_original_sms_sts, self)
    return parent.ping(self.portHandler, scs_id)


# Add utility methods that should be instance methods
def _scs_tohost(self, value, bits):
    """Instance-method form of the module-level ``SCS_TOHOST``; ``self`` is unused."""
    converted = SCS_TOHOST(value, bits)
    return converted


def _scs_toscs(self, value, bits):
    """Instance-method form of the module-level ``SCS_TOSCS``; ``self`` is unused."""
    converted = SCS_TOSCS(value, bits)
    return converted


def _scs_lobyte(self, word):
    """Instance-method form of the module-level ``SCS_LOBYTE``; ``self`` is unused."""
    low = SCS_LOBYTE(word)
    return low


def _scs_hibyte(self, word):
    """Instance-method form of the module-level ``SCS_HIBYTE``; ``self`` is unused."""
    high = SCS_HIBYTE(word)
    return high


def _scs_loword(self, dword):
    """Instance-method form of the module-level ``SCS_LOWORD``; ``self`` is unused."""
    low = SCS_LOWORD(dword)
    return low


def _scs_hiword(self, dword):
    """Instance-method form of the module-level ``SCS_HIWORD``; ``self`` is unused."""
    high = SCS_HIWORD(dword)
    return high


def _scs_makeword(self, low, high):
    """Instance-method form of the module-level ``SCS_MAKEWORD``; ``self`` is unused."""
    word = SCS_MAKEWORD(low, high)
    return word


def _scs_makedword(self, low, high):
    """Instance-method form of the module-level ``SCS_MAKEDWORD``; ``self`` is unused."""
    dword = SCS_MAKEDWORD(low, high)
    return dword


# Install the replacements on the SDK class object itself. Because the class is
# mutated in place, every reference to scservo_sdk's sms_sts class sees the
# patched methods, not just the `sms_sts` name re-exported below.
for _attr_name, _replacement in (
    # Constructor and port-injecting wrappers.
    ("__init__", _patched_sms_sts_init),
    ("read1ByteTxRx", _patched_read1ByteTxRx),
    ("read2ByteTxRx", _patched_read2ByteTxRx),
    ("read4ByteTxRx", _patched_read4ByteTxRx),
    ("writeTxRx", _patched_writeTxRx),
    ("regWriteTxRx", _patched_regWriteTxRx),
    ("write1ByteTxRx", _patched_write1ByteTxRx),
    ("write2ByteTxRx", _patched_write2ByteTxRx),
    ("action", _patched_action),
    ("ping", _patched_ping),
    # SCS_* helpers exposed as instance methods.
    ("scs_tohost", _scs_tohost),
    ("scs_toscs", _scs_toscs),
    ("scs_lobyte", _scs_lobyte),
    ("scs_hibyte", _scs_hibyte),
    ("scs_loword", _scs_loword),
    ("scs_hiword", _scs_hiword),
    ("scs_makeword", _scs_makeword),
    ("scs_makedword", _scs_makedword),
):
    setattr(_original_sms_sts, _attr_name, _replacement)
# Do not leave the loop variables behind as module attributes.
del _attr_name, _replacement

# Re-export the (now patched) class under its usual name.
sms_sts = _original_sms_sts
