#include "PhyPacket.h"
#include "FunctionLib.h"
#include <assert.h>

/*
 * Layout of the Security Control byte (first byte of the Auxiliary Security
 * Header). Each mask is expressed relative to its shift.
 */
#define PHY_PACKET_SEC_LEVEL_SHIFT              (0)
#define PHY_PACKET_SEC_LEVEL_MASK               (0x07 << PHY_PACKET_SEC_LEVEL_SHIFT)            /* bits 0-2 */
#define PHY_PACKET_KEY_ID_MODE_SHIFT            (3)
#define PHY_PACKET_KEY_ID_MODE_MASK             (0x03 << PHY_PACKET_KEY_ID_MODE_SHIFT)          /* bits 3-4 */
#define PHY_PACKET_FRAME_CNT_SUPRESSION_SHIFT   (5)
#define PHY_PACKET_FRAME_CNT_SUPRESSION_MASK    (0x01 << PHY_PACKET_FRAME_CNT_SUPRESSION_SHIFT) /* bit 5 */

/*
 * Size in bytes of an address field for a given Addressing Mode value:
 * 0 -> field absent, 3 -> 8-byte address, any other value -> 2-byte address.
 */
#define PHY_PACKET_ADDR_LENGTH(addrMode)        (((addrMode) == 3) ? 8 : (((addrMode) != 0) ? 2 : 0))

/*
 * Returns the combined size, in bytes, of the Destination and Source PAN ID
 * fields present in the MAC header described by `fcf` (0, 2 or 4).
 *
 * Frame versions 0b00 and 0b01 (7.2.1.5 of the 2015 standard): the
 * Destination PAN ID is present whenever a destination address is present;
 * the Source PAN ID is present only when a source address is present and
 * PAN ID Compression is 0. Frames with no addresses (e.g. Imm-Acks) or with
 * a single address (e.g. beacons) therefore carry fewer than two PAN IDs.
 *
 * Frame version 0b10: presence follows table 7-2 of the 2015 standard,
 * encoded in panidSizeLut below.
 */
static uint8_t PhyPacket_GetPanIdLength(const phyFcf_t* fcf)
{
    /*
     * Sizes of the PanID (Dst + Src) fields for frame version 0b10, per
     * table 7-2 of the 2015 standard. The index is
     *     dstRow * 6 + srcRow * 2 + panIdCompression
     * where addressing modes 0, 2 and 3 map to rows 0, 1 and 2.
     */
    static const uint8_t panidSizeLut[] = {
                /* Dst Addr Mode   | Src Addr Mode   | PanId Compression */
        0,      /* 0 (not present) | 0 (not present) |       0           */
        2,      /* 0 (not present) | 0 (not present) |       1           */
        2,      /* 0 (not present) | 2 (short)       |       0           */
        0,      /* 0 (not present) | 2 (short)       |       1           */
        2,      /* 0 (not present) | 3 (extended)    |       0           */
        0,      /* 0 (not present) | 3 (extended)    |       1           */
        2,      /* 2 (short)       | 0 (not present) |       0           */
        0,      /* 2 (short)       | 0 (not present) |       1           */
        4,      /* 2 (short)       | 2 (short)       |       0           */
        2,      /* 2 (short)       | 2 (short)       |       1           */
        4,      /* 2 (short)       | 3 (extended)    |       0           */
        2,      /* 2 (short)       | 3 (extended)    |       1           */
        2,      /* 3 (extended)    | 0 (not present) |       0           */
        0,      /* 3 (extended)    | 0 (not present) |       1           */
        4,      /* 3 (extended)    | 2 (short)       |       0           */
        2,      /* 3 (extended)    | 2 (short)       |       1           */
        2,      /* 3 (extended)    | 3 (extended)    |       0           */
        0,      /* 3 (extended)    | 3 (extended)    |       1           */
    };
    uint8_t dstRow;
    uint8_t srcRow;

    /* According to 7.2.1.5 PanId field specifications of the 2015 standard */
    if ((fcf->frameVersion == 0) || (fcf->frameVersion == 1))
    {
        uint8_t length = 0;

        if (fcf->dstAddressingMode != 0)
        {
            length += 2;    /* Destination PAN ID */
        }

        if ((fcf->srcAddressingMode != 0) && (fcf->panIdCompression == 0))
        {
            length += 2;    /* Source PAN ID, omitted when compressed */
        }

        return length;
    }

    /*
     * Frame version 0b10 (the reserved value 0b11 takes the same path).
     * The reserved addressing mode 1 maps to row 0, like mode 0.
     */
    dstRow = (fcf->dstAddressingMode != 0) ? (uint8_t)(fcf->dstAddressingMode - 1) : 0;
    srcRow = (fcf->srcAddressingMode != 0) ? (uint8_t)(fcf->srcAddressingMode - 1) : 0;

    return panidSizeLut[dstRow * 6 + srcRow * 2 + fcf->panIdCompression];
}

/*
 * Returns the length, in bytes, of the fixed/addressing part of the MAC
 * header: Frame Control, the Sequence Number (unless suppressed), the PAN ID
 * fields and the destination/source address fields. The Auxiliary Security
 * Header and Header IEs are not included.
 */
uint8_t PhyPacket_GetHdrLength(const uint8_t *packet)
{
    const phyFcf_t *frameCtrl   = (const phyFcf_t *)packet;
    const uint8_t  seqNumLength = (frameCtrl->snSupression == 0) ? 1 : 0;

    return (uint8_t)(sizeof(phyFcf_t)
                     + seqNumLength
                     + PhyPacket_GetPanIdLength(frameCtrl)
                     + PHY_PACKET_ADDR_LENGTH(frameCtrl->dstAddressingMode)
                     + PHY_PACKET_ADDR_LENGTH(frameCtrl->srcAddressingMode));
}

uint8_t * PhyPacket_GetSecurityHeader(const uint8_t *packet)
{
    assert(((phyFcf_t *)packet)->securityEnabled == 1);

    uint8_t secHeaderIndex = PhyPacket_GetHdrLength(packet);
    uint8_t *secHeader     = (uint8_t *)packet + secHeaderIndex;
    return secHeader;
}

uint8_t PhyPacket_GetSecurityLevel(const uint8_t *packet)
{
    uint8_t *secHeader = PhyPacket_GetSecurityHeader(packet);
    uint8_t secLevel   = (secHeader[0] & PHY_PACKET_SEC_LEVEL_MASK) >> PHY_PACKET_SEC_LEVEL_SHIFT;

    return secLevel;
}

/*
 * Returns the MIC length, in bytes, for an IEEE 802.15.4 security level.
 *
 * The two low bits of the level select the MIC size (0, 4, 8 or 16 bytes).
 * Bit 2 only adds payload encryption, so levels 5, 6 and 7 (ENC-MIC-32/64/128)
 * give the same MIC sizes as levels 1, 2 and 3.
 * Level 4 (encryption without a MIC) is not accepted (asserted).
 */
uint8_t PhyPacket_ComputeMicLength(const uint8_t secLevel)
{
    static const uint8_t micLengthLut[4] = { 0, 4, 8, 16 };

    /* Levels 0..7 exist; 7 (ENC-MIC-128) is valid and maps to a 16-byte MIC. */
    assert((secLevel <= 7) && (secLevel != 4));

    return micLengthLut[secLevel & 0x3];
}

/*
 * Returns the MIC length of `packet` in bytes: 0 when the Security Enabled
 * bit is clear, otherwise the size implied by its security level.
 */
static uint8_t PhyPacket_GetMicLength(const uint8_t *packet)
{
    const phyFcf_t *frameCtrl = (const phyFcf_t *)packet;

    return frameCtrl->securityEnabled
           ? PhyPacket_ComputeMicLength(PhyPacket_GetSecurityLevel(packet))
           : 0;
}

/*
 * Extracts the Key Identifier Mode subfield (bits 3-4) from the Security
 * Control byte at secHeader[0]. Returns a value in 0..3.
 */
uint8_t PhyPacket_GetKeyIdMode(const uint8_t *secHeader)
{
    const uint8_t securityControl = secHeader[0];

    return (uint8_t)((securityControl & PHY_PACKET_KEY_ID_MODE_MASK) >> PHY_PACKET_KEY_ID_MODE_SHIFT);
}

/*
 * Extracts the Frame Counter Suppression bit (bit 5) from the Security
 * Control byte at secHeader[0]. Returns 1 when the Frame Counter field is
 * omitted from the header, 0 otherwise.
 */
uint8_t PhyPacket_GetFCSuppression(const uint8_t *secHeader)
{
    const uint8_t securityControl = secHeader[0];

    return (uint8_t)((securityControl & PHY_PACKET_FRAME_CNT_SUPRESSION_MASK)
                     >> PHY_PACKET_FRAME_CNT_SUPRESSION_SHIFT);
}

/*
 * Returns the size, in bytes, of the Key Source field selected by the Key
 * Identifier Mode: 4 for mode 2, 8 for mode 3, and 0 for modes 0 and 1.
 */
static uint8_t PhyPacket_GetKeySourceSize(const uint8_t *secHeader)
{
    uint8_t sourceSize;

    switch (PhyPacket_GetKeyIdMode(secHeader))
    {
    case 2:
        sourceSize = 4;
        break;
    case 3:
        sourceSize = 8;
        break;
    default:    /* modes 0 and 1 carry no Key Source */
        sourceSize = 0;
        break;
    }

    return sourceSize;
}

/*
 * Returns the Key Index byte from the Auxiliary Security Header of `packet`.
 * Header layout: Security Control (1 byte), Frame Counter (4 bytes, absent
 * when suppressed), Key Source (0/4/8 bytes), Key Index (1 byte).
 * Requires a non-zero Key Identifier Mode (asserted).
 */
uint8_t PhyPacket_GetKeyIndex(const uint8_t *packet)
{
    const uint8_t *secHeader = PhyPacket_GetSecurityHeader(packet);
    uint32_t       offset    = 1u;     /* skip Security Control */

    assert(PhyPacket_GetKeyIdMode(secHeader) > 0);

    if (PhyPacket_GetFCSuppression(secHeader) == 0)
    {
        offset += 4u;                  /* skip Frame Counter */
    }

    offset += PhyPacket_GetKeySourceSize(secHeader);

    return secHeader[offset];
}

/*
 * Overwrites the Key Index byte in the Auxiliary Security Header of `packet`
 * with keyIndex. That byte follows the Security Control byte, the Frame
 * Counter (unless suppressed) and the Key Source.
 * Requires a non-zero Key Identifier Mode (asserted).
 */
void PhyPacket_SetKeyIndex(uint8_t *packet, uint8_t keyIndex)
{
    uint8_t *secHeader        = PhyPacket_GetSecurityHeader(packet);
    uint8_t  frameCounterSize = PhyPacket_GetFCSuppression(secHeader) ? 0 : 4;
    uint8_t  keyIndexOffset;

    assert(PhyPacket_GetKeyIdMode(secHeader) > 0);

    keyIndexOffset = (uint8_t)(1 + frameCounterSize + PhyPacket_GetKeySourceSize(secHeader));
    secHeader[keyIndexOffset] = keyIndex;
}

/*
 * Writes frameCounter into the Frame Counter field (bytes 1..4 of the
 * Auxiliary Security Header), least significant byte first. Does nothing
 * when the Frame Counter is suppressed, since the field is then absent.
 */
void PhyPacket_SetFrameCounter(uint8_t *packet, uint32_t frameCounter)
{
    uint8_t *secHeader = PhyPacket_GetSecurityHeader(packet);

    if (PhyPacket_GetFCSuppression(secHeader) != 0)
    {
        return;
    }

    /* Little-Endian Format */
    for (uint8_t byteIdx = 0; byteIdx < 4; byteIdx++)
    {
        secHeader[1 + byteIdx] = (uint8_t)(frameCounter >> (8 * byteIdx));
    }
}

/*
 * Reads the 32-bit Frame Counter from the Auxiliary Security Header of
 * `packet` into *frameCounter.
 *
 * The field occupies bytes 1..4 of the header in little-endian order. It is
 * decoded byte by byte so the result does not depend on host endianness,
 * mirroring the explicit encoding in PhyPacket_SetFrameCounter.
 * When the Frame Counter is suppressed, *frameCounter is left unchanged.
 */
void PhyPacket_GetFrameCounter(uint8_t *packet, uint32_t *frameCounter)
{
    const uint8_t *secHeader = PhyPacket_GetSecurityHeader(packet);

    if (PhyPacket_GetFCSuppression(secHeader) == 0)
    {
        *frameCounter = ((uint32_t)secHeader[1] <<  0) |
                        ((uint32_t)secHeader[2] <<  8) |
                        ((uint32_t)secHeader[3] << 16) |
                        ((uint32_t)secHeader[4] << 24);
    }
}

/*
 * Returns the length, in bytes, of the Auxiliary Security Header of `packet`,
 * or 0 when the Security Enabled bit is clear. The total is the sum of the
 * Security Control byte, the Frame Counter (4, unless suppressed), the Key
 * Source (0/4/8) and the Key Index (1, only when Key Identifier Mode != 0).
 */
uint8_t PhyPacket_GetSecurityHeaderLength(uint8_t *packet)
{
    const phyFcf_t *frameCtrl = (const phyFcf_t *)packet;
    const uint8_t  *secHeader;
    uint8_t         length;

    if (!frameCtrl->securityEnabled)
    {
        return 0;
    }

    secHeader = PhyPacket_GetSecurityHeader(packet);

    length = (uint8_t)(1                                                  /* Security Control */
                       + (PhyPacket_GetFCSuppression(secHeader) ? 0 : 4)  /* Frame Counter */
                       + PhyPacket_GetKeySourceSize(secHeader)            /* Key Source */
                       + ((PhyPacket_GetKeyIdMode(secHeader) > 0) ? 1 : 0)); /* Key Index */

    return length;
}

/*
 * Returns the length, in bytes, of the complete MAC header of `packet`.
 * The header consists of:
 *   - the addressing fields;
 *   - the Auxiliary Security Header, when security is enabled;
 *   - when the IE Present bit is set, the Header IEs up to and including
 *     the first HT1/HT2 termination IE.
 * For frame versions other than 0b10, one extra byte is counted for MAC
 * command frames (presumably the Command Frame Identifier — confirm).
 *
 * packetLength is expected to include the 2-byte FCS. The FCS and the MIC
 * are excluded from the region scanned for Header IEs.
 */
uint8_t PhyPacket_GetMacHdrLength(uint8_t *packet, uint8_t packetLength)
{
    const phyFcf_t *fcf           = (const phyFcf_t *)packet;
    uint8_t         mhrLength     = PhyPacket_GetHdrLength(packet) + PhyPacket_GetSecurityHeaderLength(packet);
    uint8_t         trailerLength = 2 + PhyPacket_GetMicLength(packet);   /* FCS + MIC */
    /*
     * End of the region that may hold Header IEs. Clamp to 0 for frames too
     * short to contain the trailer. Otherwise the uint8_t subtraction wraps
     * around and the IE walk below runs past the end of the buffer.
     */
    uint8_t         ieRegionEnd   = (packetLength >= trailerLength) ? (uint8_t)(packetLength - trailerLength) : 0;

    if (fcf->iePresent)
    {
        const uint8_t *cursor = packet + mhrLength;

        /* Only decode an IE whose 2-byte descriptor lies entirely inside the region. */
        while ((mhrLength + 2) <= ieRegionEnd)
        {
            const HdrIe_t *ie = (const HdrIe_t *)cursor;

            /* Skip the descriptor plus the IE content. */
            cursor    += (2 + ie->length);
            mhrLength += (2 + ie->length);

            if ((ie->id == HDR_IE_ID_HT1) || (ie->id == HDR_IE_ID_HT2))
            {
                break;
            }
        }
    }

    if ((fcf->frameVersion != 2) && (fcf->frameType == MAC_FRAME_TYPE_CMD))
    {
        mhrLength++;
    }

    return mhrLength;
}
