from generators import *
import graph
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import stats
import scipy
import seaborn as sns
import csv
import multiprocessing

# Silence FutureWarning messages emitted by libraries that are not fully
# up to date, so they do not clutter the experiment output.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def subdict(keys,dictio):
    """
    Return a new dict restricted to ``keys``, taking values from ``dictio``.

    Raises KeyError if any requested key is missing from ``dictio``.
    """
    return {key: dictio[key] for key in keys}

def values(N, seed, generators):
    """
    Draw ``N`` normalized values from each generator, all reseeded with ``seed``.

    :param N: number of values requested from every generator
    :param seed: seed applied to each generator right before it draws
    :param generators: iterable of generator objects exposing ``seed`` and
                       ``batchRandomNormalized``
    :return: dict mapping each generator's class name to the result of its
             ``batchRandomNormalized(N)``; a later generator of the same
             class overwrites an earlier one
    """
    def draw(rng, label):
        # Report progress, reset the generator to the shared seed, draw a batch.
        print(f"processing {label}...")
        rng.seed(seed)
        return rng.batchRandomNormalized(N)

    batches = {}
    for rng in generators:
        label = type(rng).__name__
        batches[label] = draw(rng, label)
    return batches

def iteratives(N, seed, generator):
    """
    Draw batches of several sizes from one generator, reseeding before each.

    Because the generator is reseeded with ``seed`` before every batch, each
    batch starts from the same seed.

    :param N: iterable of batch sizes
    :param seed: seed applied before each batch
    :param generator: generator object exposing ``seed`` and
                      ``batchRandomNormalized``
    :return: dict mapping "<ClassName>_<size>" to the corresponding batch, in
             the order of ``N``; a repeated size overwrites the earlier batch
             stored under the same key
    """
    prefix = type(generator).__name__ + "_"
    batches = {}
    for size in N:
        label = prefix + str(size)
        print(f"processing {label}...")
        generator.seed(seed)
        batches[label] = generator.batchRandomNormalized(size)
    return batches

def write_stats(data, name, path='stats.csv'):
    """
    Append one row of descriptive statistics for ``data`` to a CSV file.

    Columns: name, quantil25, quantil50, quantil75, mean, variance
    (population variance, ddof=0), skew, and exckurt (Fisher definition,
    i.e. excess kurtosis). Every value is written as its ``str()`` form.

    No header row is written. The file is opened in append mode, so writing
    a header on each call would repeat it.

    :param data: sequence or array of numeric samples
    :param name: label written in the "name" column
    :param path: CSV file to append to (default 'stats.csv', relative to the
                 current working directory)
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not make
    # ``scipy.stats`` available on SciPy releases without lazy submodule loading.
    from scipy import stats as scipy_stats

    fields = ["name", "quantil25", "quantil50", "quantil75", "mean",
              "variance", "skew", "exckurt"]
    # Convert once instead of letting each numpy/scipy call below re-convert
    # a potentially large Python list.
    samples = np.asarray(data)
    # One call computes all three quartiles from a single partition of the data.
    quantil25, quantil50, quantil75 = np.quantile(samples, [0.25, 0.5, 0.75])
    row = {
        "name": name,
        "quantil25": str(quantil25),
        "quantil50": str(quantil50),
        "quantil75": str(quantil75),
        "mean": str(np.mean(samples)),
        "variance": str(np.var(samples)),                # ddof=0
        "skew": str(scipy_stats.skew(samples)),
        "exckurt": str(scipy_stats.kurtosis(samples)),   # Fisher: excess kurtosis
    }
    with open(path, 'a', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fields).writerow(row)

def write_stats_chi(data, name, RES, path='chi.csv'):
    """
    Append the chi-squared uniformity result for ``data`` to a CSV file.

    The computation is delegated to ``stats.chisquared_uniform(data, RES)``,
    from the project's local ``stats`` module. The first two entries of its
    result are written as the "chi" and "p" columns next to ``name``, each as
    its ``str()`` form. No header row is written, because the file is opened
    in append mode.

    :param data: samples to test
    :param name: label written in the "name" column
    :param RES: passed straight through to ``stats.chisquared_uniform``
                (presumably the number of bins — confirm against that helper)
    :param path: CSV file to append to (default 'chi.csv', relative to the
                 current working directory)
    """
    result = stats.chisquared_uniform(data, RES)
    # Presumably (statistic, p-value); only the first two entries are used.
    chi, p = result[0], result[1]
    with open(path, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["name", "chi", "p"])
        writer.writerow({"name": name, "chi": str(chi), "p": str(p)})

def mean_evolves(generator, number, iterN, base_seed=None):
    """
    Repeat the multi-size draw ``number`` times and collect the batch means.

    Each run picks a seed and calls ``iteratives(iterN, seed, generator)``,
    which returns one batch per size in ``iterN``. The run then records the
    mean of every batch.

    :param generator: generator object accepted by ``iteratives``
    :param number: how many runs (one seed per run) to perform
    :param iterN: list of batch sizes passed to ``iteratives``
    :param base_seed: optional int. When given, run ``i`` uses seed
        ``base_seed + i``, which makes the experiment reproducible. When None
        (the default), run ``i`` uses ``int(<current timestamp> + i)``.
    :return: dict mapping ``str(seed)`` to the list of batch means, in the
        order of the batches returned by ``iteratives`` (repeated sizes in
        ``iterN`` collapse into one entry)
    """
    datas = {}
    for i in range(number):
        if base_seed is None:
            seed = int(datetime.now().timestamp() + i)
        else:
            seed = base_seed + i
        batches = iteratives(iterN, seed, generator)
        datas[str(seed)] = [np.mean(batch) for batch in batches.values()]
    return datas


def pipeline(generator, N=1000000, resolution=100, iterN=None,
             iterNmean=None, runs=500):
    """
    Run the whole analysis for a single random generator.

    All stages share a seed taken from the current timestamp:
      1. Draw ``N`` values via ``values``, pass them to
         ``graph.hist_distributivity_graph``, and append statistics with
         ``write_stats`` and ``write_stats_chi``.
      2. Draw one batch per size in ``iterN`` via ``iteratives`` and plot
         them with ``graph.compare``.
      3. Collect batch means over ``runs`` seeds via ``mean_evolves`` and plot
         them with ``graph.lineplot_mean_alot``.

    Used as a ``multiprocessing.Process`` target by the ``__main__`` block.

    :param generator: generator object to analyse
    :param N: sample size of stage 1
    :param resolution: passed as the resolution argument to the graph and
                       chi-squared helpers
    :param iterN: batch sizes of stage 2
                  (default [100, 10000, 1000000, 100000000])
    :param iterNmean: batch sizes of stage 3
                      (default [100, 1000, 10000, 100000, 1000000])
    :param runs: number of seeds used in stage 3
    """
    # None sentinels avoid shared mutable default lists.
    if iterN is None:
        iterN = [100, 10000, 1000000, 100000000]
    if iterNmean is None:
        iterNmean = [100, 1000, 10000, 100000, 1000000]
    seed = int(datetime.now().timestamp())
    name_ = type(generator).__name__

    sns.set_theme(style="darkgrid")

    # Stage 1: one large sample, its plot and its statistics.
    data = values(N, seed, [generator])
    graph.hist_distributivity_graph(N, resolution, seed, data)
    for name in data:
        write_stats(data[name], name)
        write_stats_chi(data[name], name, resolution)

    # Stage 2: the same seed with increasing batch sizes.
    data_iter = iteratives(iterN, seed, generator)
    # NOTE(review): this plot used to be produced twice with identical
    # arguments. If a different second plot was intended (e.g. an ECDF),
    # add that call explicitly.
    graph.compare(0, resolution, seed, data_iter)

    # Stage 3: how the batch means evolve with size across many seeds.
    data_alot = mean_evolves(generator, runs, iterNmean)
    graph.lineplot_mean_alot(data_alot, "Size", "Mean",
                             "Mean values evolution for " + name_, iterNmean)


if __name__=="__main__":
    # Run each generator's pipeline in its own process.
    generators = [ParkMiller(), KnuthLewis(), Marsaglia(),
                  LavauxJenssens(), Haynes(), MitchelMoore(),
                  MersenneTwister(), BlumBlumShub()]
    processes = []
    for gen in generators:
        proc = multiprocessing.Process(target=pipeline, args=(gen,))
        proc.start()
        processes.append(proc)

    # Wait for every pipeline and report failures. Without this, a crashed
    # child gives no indication in the parent.
    for gen, proc in zip(generators, processes):
        proc.join()
        if proc.exitcode != 0:
            print(type(gen).__name__ + " pipeline failed (exit code "
                  + str(proc.exitcode) + ")")