"""
Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
A copy of the License is located at

http://aws.amazon.com/apache2.0

or in the "license" file accompanying this file. This file is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License.

"""
import numpy as np
import pdb
class AmazonAED():

    '''
    Change point detector based on the Mann-Whitney / Pettitt rank statistic,
    with a streaming (``streamingMPP``) and a batch (``detectChangePointBatch``)
    mode.

    Attributes:
        change_points (list): change points detected so far by streamingMPP; each
            is the index of the first sample of the new regime
        detection_points (list): sample index T at which each change was detected
        magnitudes (list): median(after) - median(before) around each change point
        p_values (list): p-value appended for every processed sample (streaming)
            or every candidate split (batch)
        p_value_change_points (list): p-value recomputed right after each detection
    '''

    def __init__(self, data=None):
        # `data` is accepted for interface compatibility; it is not used here.
        self.change_points = []
        self.detection_points = []
        self.magnitudes = []
        self.p_values = []
        self.p_value_change_points = []

    def streamingMPP(self, X, startup=5,
        alpha_level = 0.01,
        change_difference_window=5,
        confidence_intervals=None,
        practical_significance_threshold=0):
        '''
        Detect change points in a stream of data samples.

        Samples are processed one at a time. For the current window
        X[prev_change_point..T] the unnormalised statistic U_k_n is updated
        incrementally for every split k (first segment X[prev_change_point..
        prev_change_point + k]). A change is declared when the p-value of the
        largest |U_k_n| falls below ``alpha_level``. The statistic is then
        restarted from the detected change point.

        Args:
            X (numpy.ndarray): the observations
            startup (int): number of samples after the last change point before a
                new detection is allowed
            alpha_level (float): p-value below which a change is declared
            change_difference_window (int): samples taken on each side of a change
                point to estimate its magnitude (difference of medians)
            confidence_intervals (numpy.ndarray, optional): per-sample uncertainty;
                when given, tanh replaces sign in the pairwise comparisons
            practical_significance_threshold (float): absolute differences not
                exceeding this value count as ties (0 disables it)

        Returns:
            numpy.ndarray: per-sample normalised statistic max_k |U_k_n| (0 during
            the startup period). Entries at detected change points hold the
            statistic of the winning split.
        '''
        self.data = X
        n_samples = len(self.data)
        Us = np.zeros(n_samples)
        prev_change_point = 0   # index of the first sample of the current regime
        # U_k_n_values[k] is the running unnormalised statistic for split k;
        # one new split (initialised to 0) is added per processed sample.
        U_k_n_values = []

        T = 0
        while T < n_samples:
            window_size = T + 1 - prev_change_point   # samples in X[prev_change_point..T]
            n_splits = T - prev_change_point
            cur_Us = np.zeros(n_splits)
            prev_D_k_n = 0   # D_k_n accumulates over k and restarts for every T
            for k in range(n_splits):
                # U_k_n takes 1-based positions (it compares X[k-1] with X[n-1]),
                # so convert the 0-based indices prev_change_point + k and T.
                U_k_n_values[k], prev_D_k_n = U_k_n(
                    self.data, prev_change_point + k + 1, T + 1,
                    U_k_n_values[k], prev_D_k_n,
                    practical_significance_threshold, confidence_intervals)
                cur_Us[k] = normalise_U(U_k_n_values[k], k + 1, window_size)

            U_k_n_values.append(0)
            if T <= prev_change_point + startup:
                Us[T] = 0
            else:
                Us[T] = np.max(np.abs(cur_Us))

            cur_p_value = p_value(np.max(np.abs(U_k_n_values)), window_size)
            self.p_values.append(cur_p_value)

            detected = (T >= startup + prev_change_point
                        and n_splits > 0
                        and cur_p_value < alpha_level)

            if detected:
                self.detection_points.append(T)
                best_k = np.argmax(np.abs(cur_Us))
                # The best split ends the old regime at prev_change_point + best_k,
                # so the new regime starts one sample later. The +1 also
                # guarantees progress: a change point never repeats.
                prev_change_point = prev_change_point + best_k + 1
                Us[prev_change_point] = np.abs(cur_Us[best_k])
                self.change_points.append(prev_change_point)
                # NOTE(review): n is measured from the new change point rather than
                # from the window that triggered the detection -- confirm intent.
                p_value_change_point = p_value(np.max(np.abs(U_k_n_values)),
                                               T + 1 - prev_change_point)
                self.p_value_change_points.append(p_value_change_point)

                # Clamp at 0: a negative slice start would wrap around the array
                # and give an empty "before" window.
                before_start = max(0, prev_change_point - change_difference_window)
                before_median = np.median(X[before_start:prev_change_point])
                after_median = np.median(
                    X[prev_change_point:prev_change_point + change_difference_window])
                self.magnitudes.append(after_median - before_median)

                # restart the statistic on the new regime
                T = prev_change_point + 1
                U_k_n_values = [0]
            else:
                T += 1

        return Us

    def get_change_points(self):
        '''Return the list of change points detected so far.'''
        return self.change_points

    def changeDetected(self):
        """
            Checks if the model has detected a change point.
            Returns:
                (bool): True if a change point has been detected and False otherwise
        """
        return bool(self.change_points)

    def detectChangePointBatch(self, X, threshold,
        use_p_value_threshold=True, practical_significance_threshold = 0,
        confidence_intervals=None):
        '''
        Detect a single change point in a batch of observations (Pettitt, 1979).

        For every split i (first segment X[0..i]) the raw statistic
        U_{i+1,N} = sum over s <= i of V_{s,N} is accumulated and normalised
        into ``Us``. The p-value of each raw statistic is appended to
        ``self.p_values``.

        Args:
            X (numpy.ndarray): the observations
            threshold (float): detection threshold. It is compared against
                max(Us) when ``use_p_value_threshold`` is True, and against the
                p-value of the largest |raw statistic| otherwise.
            use_p_value_threshold (bool): selects the comparison described above
            practical_significance_threshold (float): currently unused
            confidence_intervals (numpy.ndarray, optional): per-sample uncertainty;
                when given, tanh replaces sign in the comparisons

        Returns:
            (int, numpy.ndarray): index of the split with the largest normalised
            statistic (0 if no change is detected or X has fewer than two
            samples) and the array of normalised statistics.
        '''
        n = len(X) - 1
        if n < 1:
            # fewer than two samples: there is no split to test
            return 0, np.zeros(0)

        N = n + 1
        Us = np.zeros(n)
        raw_Us = []
        U_raw = 0
        for i in range(n):
            # Pettitt's recursion U_{i,N} = U_{i-1,N} + V_{i,N}, accumulated in
            # one pass instead of recomputed from scratch for every split.
            if confidence_intervals is None:
                V = V_t_T(X, i, N)
            else:
                V = V_t_T_probabilistic(X, i, N, confidence_intervals)
            U_raw = V if i == 0 else U_raw + V
            raw_Us.append(U_raw)
            Us[i] = np.abs(normalise_U(U_raw, i + 1, N))
            # p_value expects the unnormalised statistic
            self.p_values.append(p_value(U_raw, N))

        if use_p_value_threshold:
            # NOTE(review): despite the flag name, this branch thresholds the
            # normalised statistic rather than a p-value; kept for existing callers.
            if np.max(Us) > threshold:
                return np.argmax(Us), Us
        elif p_value(np.max(np.abs(raw_Us)), N) < threshold:
            return np.argmax(Us), Us

        return 0, Us

##################################################################### Utility functions ################################################
def V_t_T(X, t, T):
    '''
    Pettitt (1979) term V_{t,T}: the sum of sign(X[t] - X[i]) for i in 0..T-1.

    Args:
        X (numpy.ndarray): array of observations
        t (int): index of the observation being compared
        T (int): number of observations compared against (X[0..T-1])
    Returns:
        (int) V_t_T; 0 when T is 0
    '''
    if T == 0:
        return 0
    pivot = X[t]
    return sum(np.sign(pivot - X[i]) for i in range(T))

def V_t_T_probabilistic(X, t, T, sigmas):
    '''
    Probabilistic V_{t,T} (Pettitt, 1979) that takes per-sample uncertainty into account.

    Each comparison uses tanh((X[t] - X[i]) / sqrt(sigma_i**2 + sigma_t**2))
    instead of sign(X[t] - X[i]). If both sigmas of a pair are 0, the tanh
    form is undefined (0/0 for the self comparison i == t). In that case it
    falls back to sign, which is the limit of tanh as the scale goes to 0.

    Args:
        X (numpy.ndarray): array of observations
        t (int): index of the observation being compared
        T (int): number of observations compared against (X[0..T-1])
        sigmas (numpy.ndarray): per-observation confidence intervals
    Returns:
        (float): V_t_T for observation t; 0 when T is 0
    '''
    if T == 0:
        return 0
    x_t = X[t]
    sigma_t = sigmas[t]
    sum_V = 0
    for i in range(T):
        difference = x_t - X[i]
        scale = np.sqrt(sigmas[i]**2 + sigma_t**2)
        if scale == 0:
            # no uncertainty for this pair: avoid 0/0 -> NaN poisoning the sum
            sum_V += np.sign(difference)
        else:
            sum_V += np.tanh(difference / scale)
    return sum_V

def U_t_T(X, t, T):
    '''
    Mann-Whitney statistic U_{t,T} for the split after observation t (Pettitt, 1979).

    Evaluates the recursion U_{0,T} = V_{0,T} and U_{t,T} = U_{t-1,T} + V_{t,T}
    with a loop rather than recursive calls. Large t therefore cannot exceed
    Python's recursion limit. Terms are added in the same order as the
    recursive form.

    Args:
        X (numpy.ndarray): array of observations
        t (int): index of the last observation of the first segment (t >= 0)
        T (int): total number of observations
    Returns:
        (float): the Mann-Whitney statistic for observation t and T observations
    '''
    curr_U = V_t_T(X, 0, T)
    for s in range(1, t + 1):
        curr_U = curr_U + V_t_T(X, s, T)
    return curr_U

def U_t_T_probabilistic(X, t, T, sigmas):
    '''
    Probabilistic Mann-Whitney statistic U_{t,T} with per-sample uncertainty (Pettitt, 1979).

    Same recursion as U_t_T, U_{t,T} = U_{t-1,T} + V_{t,T}, but built from
    V_t_T_probabilistic. It is evaluated with a loop so that large t cannot
    exceed Python's recursion limit. Terms are summed in the recursive order.

    Args:
        X (numpy.ndarray): array of observations
        t (int): index of the last observation of the first segment (t >= 0)
        T (int): total number of observations
        sigmas (numpy.ndarray): per-observation confidence intervals
    Returns:
        (float): the statistic with uncertainty for observation t and T observations
    '''
    curr_U = V_t_T_probabilistic(X, 0, T, sigmas)
    for s in range(1, t + 1):
        curr_U = curr_U + V_t_T_probabilistic(X, s, T, sigmas)
    return curr_U

def D_i_j(X, i, j, sigma_i=0, sigma_j=0):
    '''Return the sign of the difference X[i] - X[j].

    If either confidence interval is non-zero, the result is
    tanh((X[i] - X[j]) / sqrt(sigma_i**2 + sigma_j**2)) instead of sign.

    Args:
        X (numpy.ndarray): observations, either a 1D array or a column vector (n, 1)
        i (int): current index
        j (int): final observation in processing window index
        sigma_i (float): confidence interval for sample i
        sigma_j (float): confidence interval for sample j
    Returns:
        (float): sign(X[i]-X[j]) or the tanh form above, always as a scalar
    '''
    # X[i] - X[j] is a scalar for 1D X and a length-1 array for a column vector.
    # Reduce it to a scalar so both layouts work; indexing a scalar with [0]
    # raises IndexError.
    difference = np.ravel(X[i] - X[j])[0]
    if not (sigma_i == 0 and sigma_j == 0):
        return np.tanh(difference / np.sqrt(sigma_i**2 + sigma_j**2))
    return np.sign(difference)

def D_i_j_threshold(X, i, j, threshold=0, sigma_i=0, sigma_j=0):
    '''Sign of X[i] - X[j] with a practical-significance dead zone.

    Differences within [-|threshold|, |threshold|] count as ties and give 0.
    Outside the dead zone the result is +1/-1. If either confidence interval is
    non-zero, the result is instead tanh of the difference shifted towards 0 by
    |threshold| and scaled by sqrt(sigma_i**2 + sigma_j**2).

    Args:
        X (numpy.ndarray): 1D array of observations
        i (int): current index
        j (int): final observation in processing window index
        threshold (float): threshold for practical significance (sign is ignored)
        sigma_i (float): confidence interval for sample i
        sigma_j (float): confidence interval for sample j
    Returns:
        (float): -1, 0, 1, or the tanh value described above
    '''
    margin = np.abs(threshold)
    delta = X[i] - X[j]

    if delta > margin:
        direction, shifted = 1, delta - margin
    elif delta < -margin:
        direction, shifted = -1, delta + margin
    else:
        # inside the dead zone (or incomparable): treat as a tie
        return 0

    if sigma_i == 0 and sigma_j == 0:
        return direction
    return np.tanh(shifted / np.sqrt(sigma_i**2 + sigma_j**2))


def U_k_n(X, k, n, prev_u_k, prev_D_k_n,
    practical_significance_threshold = 0,
    confidence_intervals = None):
    ''' Incrementally update the Mann-Whitney statistic for split k when observation n arrives.

    k and n are 1-based positions: the samples compared are X[k-1] and X[n-1].
    The update is
        D_k_n = prev_D_k_n + D(X[k-1], X[n-1])
        U_k_n = prev_u_k + D_k_n
    where D is D_i_j_threshold when practical_significance_threshold is truthy
    and D_i_j otherwise.

    Args:
        X (numpy.ndarray): an array holding data samples
        k (int): 1-based position of the last sample of the first segment
        n (int): 1-based position of the newest observation
        prev_u_k (float): value of the statistic for k and n-1 observations
        prev_D_k_n (float): value of the accumulated sign sum for k-1 and n observations
        practical_significance_threshold (float): threshold for a practically significant change (0 disables it)
        confidence_intervals (numpy.ndarray): optional per-sample confidence intervals
    Returns:
        (float, float): the updated U_k_n and D_k_n
    '''
    i, j = k - 1, n - 1
    if confidence_intervals is None:
        sigma_k = sigma_n = 0
    else:
        sigma_k, sigma_n = confidence_intervals[i], confidence_intervals[j]

    if practical_significance_threshold:
        step = D_i_j_threshold(X, i, j, practical_significance_threshold, sigma_k, sigma_n)
    else:
        step = D_i_j(X, i, j, sigma_k, sigma_n)

    D_k_n = prev_D_k_n + step
    return (prev_u_k + D_k_n, D_k_n)

def normalise_U(U, X_k, n):
    '''
    Normalise U_k_n following Hawkins (2010).

    Args:
        U (float): the value of U_k_n
        X_k (float): split position k (size of the first segment)
        n (float): total number of observations in the window
    Return:
        (float): U / sqrt(X_k * (n - X_k) * (n + 1) / 3)
    '''
    variance = X_k * (n - X_k) * (n + 1) / 3
    return U / np.sqrt(variance)

def normalise_U_2(U_k_n, n):
    '''
    Normalisation according to Pettitt, 1978: U_k_n * sqrt(3 / (n + 1)) / n.

    Args:
        U_k_n (float): value of the statistic for k and n number of observations
        n (int): number of observations
    Return:
        (float): returns the normalised value using normalisation equation in Pettitt

    '''
    # The body previously referenced an undefined name `U` (NameError).
    return 1/n*np.sqrt(3/(n+1))*U_k_n

def p_value(U_k_n, n):
    '''
    Approximate p-value for the unnormalised statistic U_k_n: 2 * exp(-6 U^2 / (n^3 + n^2)).

    Args:
        U_k_n (float): unnormalised statistic (its sign does not matter)
        n (int): number of observations
    Returns:
        (float): the approximate p-value
    '''
    scale = n ** 3 + n ** 2
    exponent = -(6 * U_k_n ** 2) / scale
    return 2 * np.exp(exponent)
