#!/usr/bin/env python3
from keras.models import Sequential
from keras.layers import Dense
import numpy
import numpy.random as rand
#from matplotlib import pyplot
from keras.models import model_from_json
import math
import sys


# Dependency install notes:
#python3.6 -m pip install keras
#python3.6 -m pip install theano

# Module-level dataset placeholder; not read by any function defined here.
data_set = None

# These vectors will have their values defined from the model training.
# Each INVk lists the column indices fed as inputs to model invk; make_data_set
# and simulate_all pair INV1/INV2/INV3 with target columns 0/1/2 respectively.
INV1 = [1, 2]
INV2 = [0, 2]
INV3 = [0, 1]

# Per-column bounds used by the min-max normalization in load_normalize and by
# its inverse in rescale (value = normalized * (MAX - MIN) + MIN).
MAXS = [4000, 4000, 4000]
MINS = [0, 0, 0] 

# Confidence-curve parameters read by predict(): an absolute error up to
# betha * MAE gives confidence 1; beyond that the confidence decreases linearly
# and reaches 0 after a further alpha * MAE of error. The __main__ block lets
# the command line override both values.
alpha = 2.5 
betha = 1.5

# Per-model mean absolute error (inv1, inv2, inv3), passed to predict() as MAE.
# NOTE(review): the commented-out line is an alternative set of values,
# presumably from an earlier training run -- confirm which set is current.
#MAES = [0.0040331535722, 0.00447660606592, 0.00339290297573]
MAES = [0.00912140117378,0.0104236871072,0.00361583732514]

def load_normalize(csvName, separator, mins=None, maxs=None):
    """Load a numeric delimited file and min-max normalize each column.

    Column ``j`` is mapped to ``(x - mins[j]) / (maxs[j] - mins[j])``.

    Args:
        csvName: path of the file read with ``numpy.loadtxt``.
        separator: column delimiter.
        mins: per-column lower bounds; defaults to the module-level ``MINS``.
        maxs: per-column upper bounds; defaults to the module-level ``MAXS``.

    Returns:
        A 2-D float array with one row per data line of the file. A file
        with a single line yields shape ``(1, ncols)``.

    Raises:
        ValueError: if the file has more columns than bounds are defined for.
    """
    lows = MINS if mins is None else mins
    highs = MAXS if maxs is None else maxs
    # ndmin=2 keeps a one-line file two-dimensional; without it loadtxt
    # returns a 1-D array and row/column indexing breaks.
    ds = numpy.loadtxt(csvName, delimiter=separator, ndmin=2)
    ncols = ds.shape[1]
    if ncols > min(len(lows), len(highs)):
        raise ValueError(
            "%s has %d columns but normalization bounds are defined for only %d"
            % (csvName, ncols, min(len(lows), len(highs)))
        )
    lo = numpy.asarray(lows[:ncols], dtype=float)
    hi = numpy.asarray(highs[:ncols], dtype=float)
    # Broadcasts the per-column bounds across all rows in one vectorized step.
    return (ds - lo) / (hi - lo)

def make_data(data, cols, res):
    """Split a 2-D array into a feature matrix and a target column.

    ``cols`` selects the input columns and ``res`` the target column. Both
    results are built with ``numpy.array`` and are therefore copies, so
    callers may modify them without touching ``data``.
    """
    features, target = (numpy.array(data[:, selector]) for selector in (cols, res))
    return features, target

def make_data_set(data, dtype):
    """Return the ``(inputs, target)`` split used by one inverse model.

    Args:
        data: 2-D array of normalized samples.
        dtype: one of ``'inv1'``, ``'inv2'``, ``'inv3'``. Model ``invk`` takes
            the columns listed in ``INVk`` as inputs and column ``k-1`` as target.

    Raises:
        ValueError: for any other ``dtype``. The original silently returned
            ``None``, which only failed later when the caller unpacked it.
    """
    layouts = {
        'inv1': (INV1, 0),
        'inv2': (INV2, 1),
        'inv3': (INV3, 2),
    }
    try:
        cols, res = layouts[dtype]
    except KeyError:
        raise ValueError(
            "unknown dtype %r; expected one of %s" % (dtype, sorted(layouts))
        ) from None
    return make_data(data, cols, res)

def load_model(model_file, weights_file):
    """Rebuild a Keras model from a JSON architecture file and load its weights.

    Args:
        model_file: path of the JSON model description passed to ``model_from_json``.
        weights_file: path of the weights file passed to ``model.load_weights``.

    Returns:
        The compiled model.
    """
    # `with` guarantees the description file is closed (the original leaked it).
    with open(model_file, "r") as jf:
        desc = jf.read()
    model = model_from_json(desc)
    model.load_weights(weights_file)
    # NOTE(review): 'accuracy' is a classification metric; with an MAE
    # regression loss it is likely meaningless -- kept for compatibility.
    model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])
    return model

def rescale(data, index, mins=None, maxs=None):
    """Map normalized values of column ``index`` back to the original scale.

    This is the inverse of the per-column min-max normalization in
    ``load_normalize``: ``value * (maxs[index] - mins[index]) + mins[index]``.

    Args:
        data: sequence of normalized values.
        index: column whose bounds are used.
        mins: per-column lower bounds; defaults to the module-level ``MINS``.
        maxs: per-column upper bounds; defaults to the module-level ``MAXS``.

    Returns:
        A 1-D float64 array with the same length as ``data``.
    """
    lows = MINS if mins is None else mins
    highs = MAXS if maxs is None else maxs
    lo = lows[index]
    span = highs[index] - lo
    # Computed element by element so that scalar inputs (e.g. float32 values
    # from model.predict) follow exactly the same arithmetic as before; the
    # results are then stored as float64.
    return numpy.array([(value * span) + lo for value in data], dtype=float)

def predict(model, inputs, expected_value, MAE):
    """Predict one value with ``model`` and rate its agreement with ``expected_value``.

    Returns ``(prediction, confidence)``. ``prediction`` is element ``[0, 0]``
    of ``model.predict(inputs)``. The confidence is 1 when the absolute error
    is at most ``betha * MAE``. Above that it decreases linearly, by
    ``1 / (alpha * MAE)`` per unit of extra error, and is floored at 0.
    ``alpha`` and ``betha`` are the module-level globals.
    """
    predicted = model.predict( inputs )[0, 0]
    tolerance = betha * MAE
    error = math.fabs(expected_value - predicted)
    if error <= tolerance:
        return predicted, 1
    confidence = 1 - (math.fabs(error - tolerance) / (alpha * MAE))
    return predicted, (0 if confidence < 0 else confidence)
 
def _predict_from_state(model, state, cols, expected_value, mae):
    """Feed ``state[cols]`` to ``model`` as a one-sample batch and score it with ``predict``."""
    features = numpy.reshape(numpy.array(state[cols]), (1, len(cols)))
    return predict(model, features, expected_value, mae)

def simulate_all( arquivo, saida ):
    """Replay ``arquivo`` through the three inverse models and write confidences to ``saida``.

    The input is read and normalized with ``load_normalize``. A running state
    vector starts as a copy of the first row. For every row ``i``, model
    ``invk`` predicts column ``k-1`` of row ``i`` from columns ``INVk`` of the
    state, and ``predict`` turns the error into a confidence. After row ``i``
    is scored, the state becomes a copy of row ``i``. In that copy, every
    column whose confidence is <= 0.5 is replaced by its predicted value
    before the next row is processed. Row 0 is therefore predicted from
    itself; later rows are predicted from the previous (possibly corrected)
    row.

    Each output line holds the three confidences for one row as truncated
    integer percentages, comma-separated, in the order inv1, inv2, inv3. The
    output file is only opened after every row has been processed, so a
    failure mid-run leaves an existing file untouched. An empty input
    produces an empty output file.
    """
    ds = load_normalize( arquivo, "," )
    # (model, input columns, MAE) for the models predicting columns 0, 1, 2.
    specs = [
        (load_model('inv1.json', 'inv1.h5'), INV1, MAES[0]),
        (load_model('inv2.json', 'inv2.h5'), INV2, MAES[1]),
        (load_model('inv3.json', 'inv3.h5'), INV3, MAES[2]),
    ]
    # Guarded so an empty input does not raise IndexError on ds[0].
    state = numpy.array(ds[0]) if len(ds) else None
    out_lines = []
    for row in ds:
        results = [
            _predict_from_state(model, state, cols, row[target], mae)
            for target, (model, cols, mae) in enumerate(specs)
        ]
        out_lines.append(",".join(str(int(conf * 100)) for _, conf in results) + "\n")
        # Next state: this row, with low-confidence columns replaced by predictions.
        state = numpy.array(row)
        for target, (pred, conf) in enumerate(results):
            if conf <= 0.5:
                state[target] = pred
    with open(saida, "w") as fout:
        fout.writelines(out_lines)

if __name__ == '__main__':
    # Usage: INPUT_CSV OUTPUT_CSV [ALPHA [BETHA]]
    if len(sys.argv) < 3:
        sys.exit("usage: %s INPUT_CSV OUTPUT_CSV [ALPHA [BETHA]]" % sys.argv[0])
    numpy.random.seed( 147)
    infile = sys.argv[1]
    outfile = sys.argv[2]
    # Optional overrides of the module-level confidence parameters used by predict().
    if len(sys.argv) >= 4:
        alpha = float(sys.argv[3])
    if len(sys.argv) >= 5:
        betha = float(sys.argv[4])
    # simulate_all reads and normalizes infile itself; the former extra
    # load_normalize call here parsed the file a second time for no use.
    simulate_all( infile, outfile )