#!/usr/bin/env python

###############################################################################
#
# Copyright 2017 - 2019 NXP
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#
###############################################################################

# DESCRIPTION:		  Generates a TensorFlow SavedModel that computes each RMS sample's squared distance to three centroids per axis
# INPUT/S:			  RMS and Centroid values of X, Y and Z axis
# OUTPUT/S:			  Saved model

import numpy as np
import tensorflow as tf
import os
import shutil
	
###################################################################
# Main function
###################################################################	

if __name__ == "__main__":

	print ("="*73)
	print ("="*32," Start ","="*32)
	print ("="*73)

	###################################################################
	# Saved model directory
	###################################################################
	# Start from a clean directory: any previous export at this path is deleted.
	export_dir = "../Saved_Model"
	shutil.rmtree(export_dir, ignore_errors=True)
	os.mkdir(export_dir)

	# X-axis RMS Values (200 samples)
	np_array_rms_x = np.array([395 , 255 , 303 , 392 , 316 , 296 , 302 , 330 , 293 , 286 , 329 , 274 , 263 , 346 , 338 , 257 , 334 , 347 , 352 , 331 , 386 , 383 , 295 , 373 , 389 , 253 , 355 , 398 , 300 , 291 , 367 , 282 ,2207 ,1730 ,2133 ,1383 , 610 , 857 , 681 , 395 , 558 , 417 , 362 , 278 , 480 , 367 ,2161 ,2220 ,1321 ,1169 , 988 , 719 , 542 , 439 , 405 , 322 , 353 , 348 , 327 , 329 , 327 , 417 ,2955 ,1612 ,2007 ,1434 , 834 , 754 , 512 , 393 , 271 , 409 , 406 , 321 , 341 , 358 , 408 , 309 ,2711 ,2084 ,2014 ,1085 ,1035 , 610 , 635 , 483 , 406 , 314 , 341 , 420 , 249 , 274 , 391 , 313 , 333 , 304 ,3056 ,1898 ,1544 ,1352 ,1110 ,1070 , 543 , 442 , 446 , 366 , 432 , 361 , 277 , 389 , 310 , 363 , 314 , 379 , 279 , 322 , 389 , 360 , 337 , 383 , 352 , 297 , 333 , 327 , 310 , 283 , 354 , 359 , 289 , 338 , 365 , 278 , 333 , 348 , 352 , 260 , 334 , 407 , 317 , 291 , 365 , 393 , 304 , 346 , 349 , 369 , 353 , 399 , 347 , 284 , 411 , 415 , 300 , 354 , 457 , 315 , 367 , 411 , 299 , 342 , 402 , 430 , 290 , 381 , 395 , 329 , 394 , 362 , 319 , 322 , 402 , 345 , 293 , 339 , 355 , 368 , 275 , 386 , 332 , 305 , 399 , 358 , 283 , 411 , 355 , 294 , 420 , 362 , 353 , 319 , 403 , 399 , 258 , 336 , 407 , 316 , 261 , 418 , 373 , 266])

	# Y-axis RMS Values (200 samples)
	np_array_rms_y = np.array([1979,2002,1986,1973,1984,1947,1986,1947,1979,2013,1962,2017,1971,1990,1967,1952,1960,1984,1960,1998,1998,1961,1948,1966,1980,1972,2008,1944,2003,1969,1967,1996,1952,2014,1971,1987,1977,1973,1988,1953,1966,1977,1982,1976,1993,1943,1974,1934,2001,1962,1957,2005,1947,2010,1952,2010,1958,1982,1971,1976,2002,1984,1949,1995,1959,1976,1980,1979,1997,1947,2006,1960,1970,1975,1988,1974,1997,1963,1991,1942,1997,1976,1967,1997,1991,2008,1981,1968,1959,1981,1955,1991,1991,1977,1997,1961,1979,1966,1956,1996,1997,1988,1978,1973,1954,1990,1975,1970,1985,1968,2004,1927,1995,1953,1968,1997,1978,2001,1972,1952,1997,1968,1992,1968,1962,1993,1989,1958,1952,1968,1982,1991,1969,1988,1958,1962,1970,1982,1992,1975,1973,1979,1958,1970,2002,1973,1973,2005,1957,1969,1949,1990,1968,1978,1980,1983,1978,1979,1960,1954,1973,1969,1983,1936,1988,1959,2002,1970,1987,2011,1990,1963,1978,1946,1975,1988,1943,1981,1973,1997,1983,1956,1964,2002,1973,1985,1994,1943,1980,1968,1953,1970,1988,1981,1974,1959,1950,1962,1962,1997])

	# Z-axis RMS Values (200 samples)
	np_array_rms_z = np.array([212,294,266,182,238,299,243,249,239,258,290,250,199,195,269,215,221,217,179,242,179,245,193,258,196,191,222,201,240,260,163,252,223,234,184,214,242,232,227,190,229,265,187,233,226,208,235,200,238,185,266,247,196,229,221,259,242,173,231,270,255,180,233,263,227,263,236,184,284,248,225,259,234,254,205,214,206,271,258,189,263,266,267,228,229,253,309,255,190,218,285,204,227,223,199,266,255,226,245,290,242,207,254,243,262,248,160,259,281,256,197,245,292,242,238,257,218,294,192,261,241,241,218,207,248,195,257,241,163,249,232,235,211,178,205,237,227,181,217,266,174,246,191,218,285,201,223,257,247,237,184,237,242,255,263,168,247,243,245,208,228,237,276,243,202,270,303,179,261,264,226,250,204,245,221,260,201,211,238,228,259,229,180,258,272,235,207,200,253,253,262,195,257,260,166,222,215,209,247,217])


	# Reshape each axis to (1, 200, 1) to match the placeholder shapes below.
	# Each axis is reshaped from its OWN array (the Y axis previously reused X's data).
	np_array_rms_x = np_array_rms_x.reshape([1,200,1])
	np_array_rms_y = np_array_rms_y.reshape([1,200,1])
	np_array_rms_z = np_array_rms_z.reshape([1,200,1])

	#################################
	# 		  Trained Data			#
	#################################

	#x-axis centroid values
	input_datax = np.array([2007,2001,2011])
	#y-axis centroid values
	input_datay = np.array([84,674,440])
	#z- axis centroid values
	input_dataz = np.array([992,201,845])

	# Reshape the 3 centroids to (3, 1, 1) so they broadcast against the (1, 200, 1) points.
	# dtype is pinned to int32 to match the int32 placeholders; without it tf.constant
	# takes NumPy's platform-dependent default integer (int64 on e.g. 64-bit Linux)
	# and tf.subtract below fails with a dtype mismatch.
	centroids_expanded_x = tf.constant(input_datax, dtype=tf.int32, shape=[3,1,1])
	centroids_expanded_y = tf.constant(input_datay, dtype=tf.int32, shape=[3,1,1])
	centroids_expanded_z = tf.constant(input_dataz, dtype=tf.int32, shape=[3,1,1])

	# Declare place holders to feed the input: one (1, 200, 1) int32 tensor per axis
	points_expanded_x = tf.placeholder(tf.int32, shape=(1, 200, 1))
	points_expanded_y = tf.placeholder(tf.int32, shape=(1, 200, 1))
	points_expanded_z = tf.placeholder(tf.int32, shape=(1, 200, 1))

	# Difference of each point from all three centroids, per axis:
	# (1, 200, 1) - (3, 1, 1) broadcasts to (3, 200, 1)
	sub_x = tf.subtract(points_expanded_x, centroids_expanded_x)
	sub_y = tf.subtract(points_expanded_y, centroids_expanded_y)
	sub_z = tf.subtract(points_expanded_z, centroids_expanded_z)

	square_x = sub_x * sub_x
	square_y = sub_y * sub_y
	square_z = sub_z * sub_z

	# Sum over the trailing size-1 axis -> (3, 200) squared distances
	distances_x = tf.reduce_sum(square_x,2)
	distances_y = tf.reduce_sum(square_y,2)
	distances_z = tf.reduce_sum(square_z,2)

	# Index (0-2) of the nearest centroid for each of the 200 points, per axis
	assignments_x = tf.math.argmin(distances_x,0)
	assignments_y = tf.math.argmin(distances_y,0)
	assignments_z = tf.math.argmin(distances_z,0)

	# argmin yields int64 by default; cast to int32 like the rest of the graph
	assignments_x = tf.dtypes.cast(assignments_x,tf.int32)
	assignments_y = tf.dtypes.cast(assignments_y,tf.int32)
	assignments_z = tf.dtypes.cast(assignments_z,tf.int32)

	# Run the TensorFlow session; the context manager releases it when done
	with tf.Session() as sess:

		# Feed the input points and evaluate the graph once on the sample data
		# (the returned values are not used here)
		feed_dict = {points_expanded_x: np_array_rms_x,points_expanded_y: np_array_rms_y,points_expanded_z:np_array_rms_z}
		sess.run([sub_x,sub_y,sub_z,square_x,square_y,square_z,distances_x,distances_y,distances_z,assignments_x,assignments_y,assignments_z], feed_dict)

		# Save the model in TensorFlow SavedModel Format.
		# Only the per-axis distance tensors are exported as outputs; the
		# assignments computed above are not part of the signature.
		tf.compat.v1.saved_model.simple_save(sess,export_dir,inputs={"x": points_expanded_x,"y": points_expanded_y,"z":points_expanded_z},outputs={"myOutput": distances_x,"m_y":distances_y,"m_z":distances_z})

	print ("="*73)
	print ("="*33," End ","="*33)
	print ("="*73)
	