# -*- coding: utf-8 -*-
"""
Created on Wed Jun 26 09:12:59 2019

@author: Jahanzaib Malik
"""


import numpy as np
import itertools
import os
import keras
from IPython.display import clear_output
import matplotlib.pyplot as plt
from collections import Counter
from keras.models import Model # basic class for specifying and training a neural network
import statistics

def accuracy_by_class(T, P):
    """Compute per-class accuracy (per-class recall) of predictions, in percent.

    For every distinct label ``y`` found in ``T`` the score is::

        100 * (# samples with true label y predicted as y) / (# samples with true label y)

    Before the scores, a diagnostic line ``"<c>/<U> are OK...!!"`` is printed,
    where ``U`` is the number of distinct true labels and ``c`` how many of
    them are predicted at least once. Each per-class score is printed too.

    Parameters
    ----------
    T : array-like
        True labels. Anything that ``np.asarray`` flattens to one label per
        sample is accepted: a 1-D list/array, an ``(n, 1)`` array, or a
        single-column ``pandas.DataFrame``.
    P : array-like
        Predicted labels, in the same order as ``T``.

    Returns
    -------
    dict
        ``{label: accuracy_percent}`` for each distinct label in ``T``, in
        ascending label order. Integer labels are returned as Python ints.

    Raises
    ------
    ValueError
        If ``T`` and ``P`` do not contain the same number of labels.
    """
    # Flatten so 1-D sequences, (n, 1) arrays and single-column DataFrames
    # are all handled the same way.
    true_labels = np.asarray(T).ravel()
    pred_labels = np.asarray(P).ravel()
    if true_labels.shape != pred_labels.shape:
        raise ValueError(
            "T and P must contain the same number of labels ({} != {})".format(
                true_labels.size, pred_labels.size))

    # Iterate over the labels that actually occur instead of assuming they
    # are 0..U-1 (which raised KeyError for gapped or non-integer labels).
    classes = np.unique(true_labels).tolist()

    # Diagnostic: how many true classes appear at least once in the predictions.
    predicted = set(np.unique(pred_labels).tolist())
    c = sum(1 for label in classes if label in predicted)
    print("{}/{} are OK...!!".format(c, len(classes)))

    acc = {}
    for label in classes:
        is_class = true_labels == label
        total = int(np.count_nonzero(is_class))  # >= 1: label comes from T
        correct = int(np.count_nonzero(is_class & (pred_labels == label)))
        acc[label] = (correct * 100) / total
        print("Accuracy of class {} is {}".format(label, acc[label]))
    return acc
    
def Flush_results(Traning_Time, Testing_Time, Acc, TPR, TNR, NPV, FNR, FDR, FPR,
                  F1, MCC, FOR, Recall, Precision, TS, BM, MK):
    """Print every collected metric series, one value per line.

    Each metric section is followed by the series mean. The timing sections
    are printed without an average.

    Every argument is a sequence of numeric values; section titles mirror
    the parameter names. The parameter name ``Traning_Time`` (sic) is kept
    unchanged so existing keyword callers keep working.

    An empty series prints ``Average: n/a`` instead of raising
    ``statistics.StatisticsError`` in the middle of the report.
    """
    # (header, values, print_average) in output order. Note that FPR is
    # printed before FDR even though the parameters are ordered FDR, FPR.
    sections = (
        ("Training Time", Traning_Time, False),
        ("\nTesting Time", Testing_Time, False),
        ("\nAccuracy", Acc, True),
        ("\nTPR", TPR, True),
        ("\nTNR", TNR, True),
        ("\nNPV", NPV, True),
        ("\nFNR", FNR, True),
        ("\nFPR", FPR, True),
        ("\nFDR", FDR, True),
        ("\nFOR", FOR, True),
        ("\nRecall", Recall, True),
        ("\nPrecision", Precision, True),
        ("\nMCC", MCC, True),
        ("\nF1", F1, True),
        ("\nTS", TS, True),
        ("\nBM", BM, True),
        ("\nMK", MK, True),
    )
    for header, values, print_average in sections:
        # Materialize once so one-shot iterators survive both the per-value
        # printing and the mean.
        values = list(values)
        print(header)
        for value in values:
            print(value)
        if print_average:
            print("Average:", statistics.mean(values) if values else "n/a")




       

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix' , cmap=plt.cm.Greens , nn=0 , Algo="None",
                          path=os.path.dirname(os.path.realpath('Functions.py'))):
    """Print, plot and save a confusion matrix as a PNG image.

    Parameters
    ----------
    cm : array-like, shape (n_classes, n_classes)
        Confusion matrix. Rows are drawn along the "True label" axis and
        columns along the "Predicted label" axis.
    classes : sequence
        Tick labels, one per row/column of ``cm``.
    normalize : bool
        If True, divide each row by its sum so each row sums to 1.
    title : str
        Plot title.
    cmap : matplotlib colormap
        Colormap passed to ``imshow``.
    nn : int
        Number appended to the output file name.
    Algo : str
        Prefix of the output file name.
    path : str
        Base directory. The image is written to
        ``<path>/Plots/Cnf/<Algo><nn>.png``; missing directories are created.
        NOTE: the default is evaluated once at import time and
        ``realpath('Functions.py')`` resolves against the *current working
        directory* at that moment, not necessarily this file's directory.

    The figure is not closed here, so it remains the current figure after
    the call.
    """
    cm = np.asarray(cm)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.figure(figsize=(5, 5))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # 'd' only formats integers; any float matrix (normalized or not) uses
    # '.2f' instead of raising ValueError.
    if not normalize and np.issubdtype(cm.dtype, np.integer):
        fmt = 'd'
    else:
        fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()

    # os.path.join instead of hard-coded "\\" so the path is valid on every OS.
    out_dir = os.path.join(path, "Plots", "Cnf")
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, str(Algo) + str(nn) + ".png"))
    
class PlotLearning(keras.callbacks.Callback):
    """Keras callback that redraws the training loss/accuracy curves each epoch.

    After every epoch it clears the IPython output area (``clear_output``)
    and shows a two-panel figure: training loss on a log scale (left) and
    training accuracy (right). Validation curves are currently not tracked.
    """

    def on_train_begin(self, logs=None):
        """Reset the recorded history at the start of training."""
        self.i = 0              # epochs seen so far; used as the x value
        self.x = []             # x-axis values (0, 1, 2, ...)
        self.losses = []        # training loss per epoch
        self.val_losses = []    # not filled: validation tracking is disabled
        self.acc = []           # training accuracy per epoch
        self.val_acc = []       # not filled: validation tracking is disabled
        self.fig = plt.figure()  # not used by this class; kept as an attribute
        self.logs = []          # raw per-epoch ``logs`` dicts

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw the plots."""
        # Avoid the shared mutable ``{}`` default.
        logs = {} if logs is None else logs
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        # Older Keras reports the metric as 'acc', newer Keras / tf.keras
        # as 'accuracy'; accept either.
        self.acc.append(logs.get('acc', logs.get('accuracy')))
        self.i += 1

        fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True)

        clear_output(wait=True)

        ax1.set_yscale('log')
        ax1.plot(self.x, self.losses, label="loss", color='red')
        ax1.legend()

        ax2.plot(self.x, self.acc, label="accuracy", color='green')
        ax2.legend()

        plt.show()
        # Release this epoch's figure; otherwise one figure per epoch stays
        # open for the whole training run.
        plt.close(fig)
