/*
 * Copyright (c) 2015, Freescale Semiconductor, Inc.
 * Copyright 2023-2023 NXP
 * All rights reserved.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "fsl_sdmmc_common.h"
#include "fsl_debug_console.h"
#include "rpmb.h"

/* Convert a 16-bit big-endian value to CPU byte order by swapping its two
 * bytes. The swap is unconditional, i.e. a little-endian CPU is assumed. */
static inline uint16_t be16_to_cpu(uint16_t x) {
    const uint16_t high_byte = (uint16_t)(x >> 8);
    const uint16_t low_byte  = (uint16_t)(x & 0xffU);

    return (uint16_t)((uint16_t)(low_byte << 8) | high_byte);
}

/* Convert a 32-bit big-endian value to CPU byte order by reversing its four
 * bytes. The swap is unconditional, i.e. a little-endian CPU is assumed. */
static inline uint32_t be32_to_cpu(uint32_t x) {
    const uint32_t byte0 = x & 0xffU;         /* least significant byte */
    const uint32_t byte1 = (x >> 8) & 0xffU;
    const uint32_t byte2 = (x >> 16) & 0xffU;
    const uint32_t byte3 = x >> 24;           /* most significant byte */

    return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}

/* Convert a 16-bit value from CPU byte order to big endian by swapping its two
 * bytes. The swap is unconditional, i.e. a little-endian CPU is assumed. */
static inline uint16_t cpu_to_be16(uint16_t x) {
    uint16_t swapped = (uint16_t)(x >> 8);

    swapped |= (uint16_t)(x << 8);
    return swapped;
}

/* Convert a 32-bit value from CPU byte order to big endian by reversing its
 * four bytes. The swap is unconditional, i.e. a little-endian CPU is assumed. */
static inline uint32_t cpu_to_be32(uint32_t x) {
    /* First exchange the two 16-bit halves, then the bytes inside each half. */
    const uint32_t halves_swapped = (x << 16) | (x >> 16);

    return ((halves_swapped & 0x00ff00ffU) << 8) |
           ((halves_swapped & 0xff00ff00U) >> 8);
}

/*
 * Send `count` RPMB frame(s) starting at `s` to the card.
 *
 * The function first calls SDMMC_SetBlockCount() with `count` and
 * `is_rel_write`, then issues kSDMMC_WriteMultipleBlock (argument 0, R1
 * response) transferring `count` blocks of MMC_MAX_BLOCK_LEN bytes from `s`.
 *
 * card         - card to talk to; must not be NULL (asserted).
 * s            - first frame to transmit.
 * count        - number of blocks to transmit.
 * is_rel_write - forwarded unchanged to SDMMC_SetBlockCount()
 *                (presumably selects a reliable write; confirm against that API).
 *
 * Returns kStatus_Success when the transfer completed, or immediately (with no
 * bus activity) for a legacy card (spec version 3 with CSD structure 1.2).
 * Otherwise returns the error reported by SDMMC_SetBlockCount() or
 * SDMMCHOST_TransferFunction().
 */
status_t mmc_rpmb_request(mmc_card_t *card, rpmb_frame_t *s, unsigned int count, bool is_rel_write)
{
    assert(card != NULL);

    status_t status;

    /* Legacy cards do not support this command sequence; report success without doing anything. */
    const bool is_legacy_card =
        (card->csd.systemSpecificationVersion == (uint32_t)kMMC_SpecificationVersion3) &&
        (card->csd.csdStructureVersion == (uint32_t)kMMC_CsdStrucureVersion12);
    if (is_legacy_card) {
        return kStatus_Success;
    }

    status = SDMMC_SetBlockCount(card->host, count, is_rel_write);
    if (status != kStatus_Success) {
        PRINTF("mmc set block count fail!!\r\n");
        return status;
    }

    /* Describe the write: command, payload, and the transfer tying them together. */
    sdmmchost_cmd_t cmd = {
        .index        = (uint32_t)kSDMMC_WriteMultipleBlock,
        .argument     = 0U,
        .responseType = kCARD_ResponseTypeR1,
    };
    sdmmchost_data_t payload = {
        .txData     = (const uint32_t *)s,
        .blockCount = count,
        .blockSize  = MMC_MAX_BLOCK_LEN,
    };
    sdmmchost_transfer_t xfer = {
        .command = &cmd,
        .data    = &payload,
    };

    status = SDMMCHOST_TransferFunction(card->host, &xfer);
    if (status != kStatus_Success) {
        PRINTF("mmc write multiple block fail!!\r\n");
    }
    return status;
}

/*
 * Read `count` RPMB frame(s) from the card into `s` and validate the frame at `s`.
 *
 * The function calls SDMMC_SetBlockCount() with `count` (third argument false),
 * then issues kSDMMC_ReadMultipleBlock (argument 0, R1 response) receiving
 * `count` blocks of MMC_MAX_BLOCK_LEN bytes into `s`. Afterwards the big-endian
 * `request` and `result` fields of the frame at `s` are checked.
 *
 * card     - card to talk to; must not be NULL (asserted).
 * s        - buffer receiving the frame(s); its first frame is validated.
 * count    - number of blocks to receive.
 * expected - response type the frame must carry; 0 disables this check.
 *
 * Returns:
 *   kStatus_Success - legacy card (spec version 3 with CSD structure 1.2; nothing
 *                     is done), or the frame was read and passed all checks.
 *   kStatus_Fail    - the frame's `request` field differs from a non-zero
 *                     `expected`, or the frame's result carries an error code.
 *   other           - error reported by SDMMC_SetBlockCount() or
 *                     SDMMCHOST_TransferFunction().
 */
status_t mmc_rpmb_response(mmc_card_t *card, rpmb_frame_t *s, unsigned int count, unsigned short expected)
{
    assert(card != NULL);

    sdmmchost_cmd_t command      = {0};
    sdmmchost_transfer_t content = {0};
    sdmmchost_data_t data        = {0};
    status_t error               = kStatus_Success;
    uint16_t result;

    /* Legacy mmc card , do not support the command */
    if ((card->csd.systemSpecificationVersion == (uint32_t)kMMC_SpecificationVersion3) &&
        (card->csd.csdStructureVersion == (uint32_t)kMMC_CsdStrucureVersion12))
    {
        return kStatus_Success;
    }

    /* The block count must be accepted before the read; do not start the
     * transfer if setting it failed. */
    error = SDMMC_SetBlockCount(card->host, count, false);
    if (kStatus_Success != error)
    {
        PRINTF("mmc set block count fail!!\r\n");
        return error;
    }

    command.index        = (uint32_t)kSDMMC_ReadMultipleBlock;
    command.argument     = 0U;
    command.responseType = kCARD_ResponseTypeR1;

    data.rxData     = (uint32_t *)s;
    data.blockCount = count;
    data.blockSize  = MMC_MAX_BLOCK_LEN;

    content.command = &command;
    content.data    = &data;
    error           = SDMMCHOST_TransferFunction(card->host, &content);

    if (kStatus_Success != error)
    {
        PRINTF("mmc read multiple block fail!!\r\n");
        return error;
    }

    /* The response type is only verified when the caller asked for one. */
    if (expected && be16_to_cpu(s->request) != expected)
    {
        PRINTF("ERROR: request command not match!!\r\n");
        return kStatus_Fail;
    }

    /* Check the response and the status */
    result = be16_to_cpu(s->result);
    if (result != 0U)
    {
        PRINTF("%s %s\n", rpmb_err_msg[result & RPMB_ERR_MSK],
               (result & RPMB_ERR_CNT_EXPIRED) ? "Write counter has expired" : "");
    }

    /* Return the status of the command: an error code in the result must not be
     * reported as success. Only the bits selected by RPMB_ERR_MSK decide this, so
     * a frame that merely carries the counter-expired flag is still accepted.
     * NOTE(review): assumes RPMB_ERR_MSK is the operation-result field mask —
     * confirm against rpmb.h. */
    if ((result & RPMB_ERR_MSK) != 0U)
    {
        return kStatus_Fail;
    }

    return kStatus_Success;
}

/*
 * Query the card for the outcome of the previous RPMB operation.
 *
 * Sends a single zeroed frame whose request field is RPMB_REQ_STATUS, then reads
 * one frame back through mmc_rpmb_response(), which checks it against `expected`.
 *
 * card     - card to talk to; must not be NULL (asserted).
 * expected - response type passed on to mmc_rpmb_response().
 *
 * Returns kStatus_Fail when sending the request fails, otherwise whatever
 * mmc_rpmb_response() returns.
 */
status_t mmc_rpmb_status(mmc_card_t *card, unsigned short expected)
{
    assert(card != NULL);

    rpmb_frame_t status_frame;
    status_t send_status;

    /* Everything except the request type stays zero. */
    memset(&status_frame, 0, sizeof(status_frame));
    status_frame.request = cpu_to_be16(RPMB_REQ_STATUS);

    send_status = mmc_rpmb_request(card, &status_frame, 1, false);
    if (send_status == kStatus_Success)
    {
        /* Read the result */
        return mmc_rpmb_response(card, &status_frame, 1, expected);
    }

    PRINTF("ERROR: request command not match!!\r\n");
    return kStatus_Fail;
}

/*
 * Read the RPMB write counter from the card.
 *
 * Sends a single zeroed frame whose request field is RPMB_REQ_WCOUNTER, reads one
 * frame back expecting RPMB_RESP_WCOUNTER, and stores the frame's write_counter
 * field (converted from big endian) in *pcounter.
 *
 * card     - card to talk to; must not be NULL (asserted).
 * pcounter - receives the counter; written only on success.
 *
 * Returns kStatus_Success on success, kStatus_Fail when sending the request
 * fails, or the error returned by mmc_rpmb_response().
 */
status_t mmc_rpmb_get_counter(mmc_card_t *card, unsigned long *pcounter)
{
    rpmb_frame_t counter_frame;
    status_t status;

    assert(card != NULL);

    /* Everything except the request type stays zero. */
    memset(&counter_frame, 0, sizeof(counter_frame));
    counter_frame.request = cpu_to_be16(RPMB_REQ_WCOUNTER);

    if (kStatus_Success != mmc_rpmb_request(card, &counter_frame, 1, false))
    {
        PRINTF("ERROR: request command not match!!\r\n");
        return kStatus_Fail;
    }

    /* Read the result */
    status = mmc_rpmb_response(card, &counter_frame, 1, RPMB_RESP_WCOUNTER);
    if (kStatus_Success == status)
    {
        *pcounter = be32_to_cpu(counter_frame.write_counter);
    }
    else
    {
        PRINTF("ERROR: response fail!!\r\n");
    }
    return status;
}

/*
 * Program the RPMB authentication key into the card.
 *
 * Builds a zeroed frame with request RPMB_REQ_KEY and the first RPMB_SZ_MAC
 * bytes of `key` copied into the frame's mac field, sends it with the
 * is_rel_write argument of mmc_rpmb_request() set to true, then queries the
 * operation status expecting RPMB_RESP_KEY.
 *
 * card - card to talk to; must not be NULL (asserted).
 * key  - key material; RPMB_SZ_MAC bytes are read from it.
 *
 * Returns kStatus_Fail when sending the request fails, otherwise whatever
 * mmc_rpmb_status() returns.
 */
status_t mmc_rpmb_set_key(mmc_card_t *card, const uint8_t *key)
{
    assert(card != NULL);

    rpmb_frame_t key_frame;
    status_t send_status;

    /* Fill the request */
    memset(&key_frame, 0, sizeof(key_frame));
    memcpy(key_frame.mac, key, RPMB_SZ_MAC);
    key_frame.request = cpu_to_be16(RPMB_REQ_KEY);

    send_status = mmc_rpmb_request(card, &key_frame, 1, true);
    if (send_status != kStatus_Success)
    {
        PRINTF("ERROR: request command not match!!\r\n");
        return kStatus_Fail;
    }

    /* read the operation status */
    return mmc_rpmb_status(card, RPMB_RESP_KEY);
}

/*
 * Read `cnt` RPMB blocks, starting at block address `blk`, into `addr`.
 *
 * Blocks are fetched one at a time: a zeroed frame carrying the block address
 * and RPMB_REQ_READ_DATA is sent, one frame is read back expecting
 * RPMB_RESP_READ_DATA, and RPMB_SZ_DATA bytes of its data field are copied to
 * `addr + i * RPMB_SZ_DATA`. When `key` is non-NULL the frame's MAC is verified
 * with HMAC-SHA256 before the data is copied.
 *
 * card - card to talk to; must not be NULL (asserted).
 * addr - destination buffer; cnt * RPMB_SZ_DATA bytes are written on full success.
 * blk  - address of the first block to read.
 * cnt  - number of blocks to read.
 * key  - 32-byte authentication key, or NULL to skip MAC verification.
 *
 * Returns the number of blocks copied before the first failure (request,
 * response, HMAC computation or MAC mismatch); equal to `cnt` on full success.
 * NOTE: the return type is uint8_t while `cnt` is an unsigned short, so the
 * value wraps for counts above 255.
 */
uint8_t mmc_rpmb_read(mmc_card_t *card, uint8_t *addr, unsigned short blk, unsigned short cnt, const uint8_t *key)
{
    rpmb_frame_t rpmb_frame;
    assert(card != NULL);
    const mbedtls_md_info_t *md_info = mbedtls_md_info_from_string("SHA256");
    int i;

    for (i = 0; i < cnt; i++) {
        /* Fill the request */
        memset(&rpmb_frame, 0, sizeof(rpmb_frame));
        rpmb_frame.address = cpu_to_be16(blk + i);
        rpmb_frame.request = cpu_to_be16(RPMB_REQ_READ_DATA);
        if (mmc_rpmb_request(card, &rpmb_frame, 1, false) != kStatus_Success){
            PRINTF("ERROR: request command not match!!\r\n");
            break;
        }

        /* Read the result */
        if (mmc_rpmb_response(card, &rpmb_frame, 1, RPMB_RESP_READ_DATA) != kStatus_Success){
            PRINTF("ERROR: response fail!!\r\n");
            break;
        }

        /* Check the HMAC if key is provided */
        if (key) {
            unsigned char ret_hmac[RPMB_SZ_MAC];
            memset(ret_hmac, 0, sizeof(ret_hmac));

            /* The HMAC covers 284 bytes starting at the data field, keyed with
             * 32 bytes of `key`. A missing SHA256 implementation or a failed
             * computation must not be mistaken for a valid digest, so stop
             * instead of comparing the zeroed buffer.
             * NOTE(review): 284 presumably spans every frame field after the
             * MAC — confirm against the rpmb_frame_t layout. */
            if ((md_info == NULL) ||
                (mbedtls_md_hmac(md_info, key, 32, rpmb_frame.data, 284, ret_hmac) != 0)) {
                PRINTF("ERROR: HMAC computation fail!!\r\n");
                break;
            }

            if (memcmp(ret_hmac, rpmb_frame.mac, RPMB_SZ_MAC)) {
                PRINTF("\r\nMAC error on block #%d\n", i);
                break;
            }
        }
        /* Copy data */
        memcpy(addr + i * RPMB_SZ_DATA, rpmb_frame.data, RPMB_SZ_DATA);
    }
    return (uint8_t)i;
}

/*
 * Write `cnt` RPMB blocks, starting at block address `blk`, from `addr`.
 *
 * Blocks are written one at a time. For each block the current write counter is
 * read with mmc_rpmb_get_counter(), a frame is built (data, block address, block
 * count 1, write counter, RPMB_REQ_WRITE_DATA), its MAC is computed with
 * HMAC-SHA256, the frame is sent with the is_rel_write argument of
 * mmc_rpmb_request() set to true, and the outcome is checked with
 * mmc_rpmb_status() expecting RPMB_RESP_WRITE_DATA.
 *
 * card - card to talk to; must not be NULL (asserted).
 * addr - source buffer; cnt * RPMB_SZ_DATA bytes are read on full success.
 * blk  - address of the first block to write.
 * cnt  - number of blocks to write.
 * key  - 32-byte authentication key; required, since every frame needs a MAC.
 *
 * Returns the number of blocks written before the first failure; equal to `cnt`
 * on full success, and 0 when `key` is NULL or SHA256 is unavailable.
 * NOTE: the return type is uint8_t while `cnt` is an unsigned short, so the
 * value wraps for counts above 255.
 */
uint8_t mmc_rpmb_write(mmc_card_t *card, uint8_t *addr, unsigned short blk, unsigned short cnt, const uint8_t *key)
{
    rpmb_frame_t rpmb_frame;
    unsigned long wcount;
    int i;
    assert(card != NULL);
    const mbedtls_md_info_t *md_info = mbedtls_md_info_from_string("SHA256");

    /* Every frame needs a MAC; without a key or a SHA256 implementation no
     * block can be written, so do not touch the card at all. */
    if ((key == NULL) || (md_info == NULL)) {
        PRINTF("ERROR: HMAC computation fail!!\r\n");
        return 0;
    }

    for (i = 0; i < cnt; i++) {
        /* The frame must carry the card's current write counter. */
        if (mmc_rpmb_get_counter(card, &wcount) != kStatus_Success) {
            PRINTF("Cannot read RPMB write counter\n");
            break;
        }

        /* Fill the request */
        memset(&rpmb_frame, 0, sizeof(rpmb_frame));
        memcpy(rpmb_frame.data, addr + i * RPMB_SZ_DATA, RPMB_SZ_DATA);
        rpmb_frame.address = cpu_to_be16(blk + i);
        rpmb_frame.block_count = cpu_to_be16(1);
        rpmb_frame.write_counter = cpu_to_be32(wcount);
        rpmb_frame.request = cpu_to_be16(RPMB_REQ_WRITE_DATA);

        /* Computes HMAC over 284 bytes starting at the data field, keyed with
         * 32 bytes of `key`. Do not send a frame whose MAC could not be
         * computed.
         * NOTE(review): 284 presumably spans every frame field after the MAC —
         * confirm against the rpmb_frame_t layout. */
        if (mbedtls_md_hmac(md_info, key, 32, rpmb_frame.data, 284, rpmb_frame.mac) != 0) {
            PRINTF("ERROR: HMAC computation fail!!\r\n");
            break;
        }

        if (mmc_rpmb_request(card, &rpmb_frame, 1, true) != kStatus_Success){
            PRINTF("ERROR: request command not match!!\r\n");
            break;
        }

        /* Get status */
        if (mmc_rpmb_status(card, RPMB_RESP_WRITE_DATA) != kStatus_Success){
            PRINTF("ERROR: status wrong with RPMB_RESP_WRITE_DATA!!\r\n");
            break;
        }
    }
    return (uint8_t)i;
}
