import tensorflow as tf
import os
import numpy as np
import cv2
import random

import pathlib

from PIL import Image
from PIL import ImageOps

# Eager execution is already the default in TensorFlow 2.x; TensorFlow 1.x
# needs it switched on explicitly, at program startup.
tf.compat.v1.enable_eager_execution()

# Training hyper-parameters.
EPOCHS = 1
LEARNING_RATE = 1e-4
BATCH_SIZE = None  # None defers to the Keras fit() default batch size

# Dataset: DATADIR holds one sub-folder per entry of CATEGORIES, each one
# containing that class's images (the target data for transfer learning).
DATADIR = ""
CATEGORIES = [str(digit) for digit in range(10)]

# Transfer learning: when True, start from the pre-trained Keras .h5 model at
# PRETRAINED_MODEL instead of training the network from scratch.
TRANSFER = False
PRETRAINED_MODEL = ""

# Folder where the trained model files are stored
# (keep a trailing "/" when it is non-empty).
SAVEDIR = ""


def create_model():
  """Build the small convolutional digit classifier trained by this script.

  Layout: a 28x28x1 input, then two identical stages of
  [5x5 Conv2D with 32 filters and ReLU -> 2x2 max-pool, stride 2], then
  Flatten -> Dense(1024, ReLU) -> Dropout(0.4) -> Dense(10, softmax).
  The final three layers form the classification head.

  Returns:
    An uncompiled ``tf.keras.models.Model`` named ``'mnist'``.
  """
  inputs = tf.keras.layers.Input(shape=(28, 28, 1))

  # Feature extractor: the same conv + pool stage applied twice.
  features = inputs
  for _ in range(2):
    features = tf.keras.layers.Conv2D(
        filters=32, kernel_size=5, padding='same', activation='relu')(features)
    features = tf.keras.layers.MaxPooling2D(
        pool_size=(2, 2), strides=(2, 2), padding='same')(features)

  # Classification head.
  features = tf.keras.layers.Flatten()(features)
  features = tf.keras.layers.Dense(1024, activation='relu')(features)
  features = tf.keras.layers.Dropout(0.4)(features)
  outputs = tf.keras.layers.Dense(10, activation='softmax')(features)

  return tf.keras.models.Model(inputs, outputs, name='mnist')
    
def _load_dataset(datadir, categories, image_size=(28, 28)):
  """Load a folder-per-class grayscale image dataset for training.

  The expected layout is ``<datadir>/<category>/<image file>``; every image
  in a category folder is labelled with that category's index in
  ``categories``.

  Args:
    datadir: Root folder of the dataset.
    categories: Sub-folder names, in label order.
    image_size: ``(height, width)`` the model expects; images of any other
      size are resized to it.

  Returns:
    ``(data, labels)`` where ``data`` is a float32 array of shape
    ``(N, height, width, 1)`` scaled to ``[0, 1]`` and ``labels`` is an
    int64 array of shape ``(N,)``. Samples are in random order.

  Raises:
    ValueError: If no readable image was found.
  """
  samples = []
  for class_num, category in enumerate(categories):
    class_dir = os.path.join(datadir, category)
    for file_name in os.listdir(class_dir):
      img = cv2.imread(os.path.join(class_dir, file_name), cv2.IMREAD_GRAYSCALE)
      if img is None:
        # cv2.imread returns None for files it cannot decode (stray
        # non-image files, corrupt images); skip them instead of crashing.
        continue
      if img.shape != tuple(image_size):
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (image_size[1], image_size[0]),
                         interpolation=cv2.INTER_AREA)
      samples.append((img, class_num))

  if not samples:
    raise ValueError("No readable images found under %r" % (datadir,))

  # Shuffle up front: Keras' validation_split takes the *last* samples, so
  # without this the validation set would hold only the last categories.
  random.shuffle(samples)

  images, labels = zip(*samples)
  data = np.asarray(images, dtype=np.float32) / 255.0
  data = data.reshape(-1, image_size[0], image_size[1], 1)
  # A NumPy label array (rather than a Python list) is accepted by
  # model.fit() under both TensorFlow 1.x and 2.x.
  labels = np.asarray(labels, dtype=np.int64)
  return data, labels


def _build_transfer_model(pretrained_path):
  """Build a transfer-learning model from a saved Keras model file.

  The pretrained model's last three layers (assumed to be its
  classification head, as in create_model()) are dropped. Every remaining
  layer is frozen, and a freshly initialised
  Dense(1024, ReLU) -> Dropout(0.4) -> Dense(10, softmax) head is attached.

  Args:
    pretrained_path: Path to the pretrained Keras model file.

  Returns:
    An uncompiled ``tf.keras.models.Model`` named ``'mnist_new'``.

  Raises:
    ValueError: If the pretrained model has too few layers to strip a
      three-layer head from.
  """
  base = tf.keras.models.load_model(filepath=pretrained_path, compile=False)
  if len(base.layers) < 4:
    raise ValueError(
        "Pretrained model %r has only %d layers; expected a feature "
        "extractor followed by a 3-layer head."
        % (pretrained_path, len(base.layers)))

  # Freeze the feature extractor (everything except the old head).
  for layer in base.layers[:-3]:
    layer.trainable = False

  # Attach the new, trainable head directly to the existing functional
  # graph. This avoids re-calling the pretrained layers on a second input,
  # which is what wrapping them in a new Sequential model would do.
  features = base.layers[-4].output
  head = tf.keras.layers.Dense(1024, activation='relu')(features)
  head = tf.keras.layers.Dropout(0.4)(head)
  new_output = tf.keras.layers.Dense(10, activation='softmax')(head)

  return tf.keras.models.Model(base.input, new_output, name='mnist_new')


if __name__ == '__main__':
  data, labels = _load_dataset(DATADIR, CATEGORIES)

  if TRANSFER:
    model = _build_transfer_model(PRETRAINED_MODEL)
  else:
    model = create_model()

  # Compile after any layer.trainable changes so that freezing takes effect.
  model.compile(
      optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'])

  model.summary()

  ckpt_full_path = os.path.join('/tmp/mnist_model', 'model.ckpt-{epoch:04d}')
  callbacks = [
      tf.keras.callbacks.ModelCheckpoint(
          ckpt_full_path, save_weights_only=True),
      tf.keras.callbacks.TensorBoard(log_dir='/tmp/mnist_model'),
  ]

  model.fit(
      data,
      labels,
      epochs=EPOCHS,
      batch_size=BATCH_SIZE,
      callbacks=callbacks,
      validation_split=0.1)

  tflite_models_dir = pathlib.Path(SAVEDIR)
  # Create the output folder before the first file (the .h5) is written
  # into it.
  tflite_models_dir.mkdir(exist_ok=True, parents=True)

  keras_model_file = tflite_models_dir / "mnist.h5"
  model.save(str(keras_model_file))

  # TensorFlow 1.x and 2.x expose different converter constructors.
  if hasattr(tf.lite.TFLiteConverter, "from_keras_model_file"):
    # TensorFlow 1.x
    converter = tf.lite.TFLiteConverter.from_keras_model_file(
        str(keras_model_file))
  else:
    # TensorFlow 2.x
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

  # Plain float model plus two post-training-quantized variants.
  tflite_variants = (
      ("mnist.tflite", []),
      ("mnist_quant_default.tflite", [tf.lite.Optimize.DEFAULT]),
      ("mnist_quant_size.tflite", [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]),
  )
  for file_name, optimizations in tflite_variants:
    converter.optimizations = optimizations
    (tflite_models_dir / file_name).write_bytes(converter.convert())

  # ############################################################################################## #
  # The following commented lines contain code for full integer post-training quantization.        #
  # However, this type of quantization was not supported in NXP SDK 2.7.0. If you are using        #
  # SDK 2.8.0 or newer, you can try enabling these options as well to achieve better performance.  #
  # NOTE: the "test/" subfolder used below must exist before these files are written.              #
  # ############################################################################################## #

  # mnist_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
  # def representative_data_gen():
    # for input_value in mnist_ds.take(100):
      # yield [input_value]

  # converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # converter.representative_dataset = representative_data_gen
  # tflite_model_quant_default_full_int = converter.convert()
  # tflite_model_quant_default_full_int_file = tflite_models_dir/"test/mnist_model_quant_default_full_int.tflite"
  # tflite_model_quant_default_full_int_file.write_bytes(tflite_model_quant_default_full_int)

  # converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
  # converter.representative_dataset = representative_data_gen
  # tflite_model_quant_size_full_int = converter.convert()
  # tflite_model_quant_size_full_int_file = tflite_models_dir/"test/mnist_model_quant_size_full_int.tflite"
  # tflite_model_quant_size_full_int_file.write_bytes(tflite_model_quant_size_full_int)
