#!/usr/bin/env python3
# -*- coding: utf-8 -*-

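'''
Cluster samples from a pairwise distance matrix.

Runs a hierarchical clustering (Ward linkage) on the matrix and cuts the tree
either at a user-supplied height or at the height giving the best average
silhouette score. Clusters, the table of scored partitions and optional plots
are written to an output directory.

Can be run through Snakemake (``snakemake`` object in scope) or from the
command line (see ``main``).
'''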

import pandas as pd
import os
import sys
import argparse
import random


from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster, set_link_color_palette

import matplotlib; matplotlib.use('agg')
#import matplotlib
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn import manifold
import numpy as np


class SimkaCluster():
    '''
    Hierarchical clustering (HCA) of samples with silhouette-based cut selection.

    For now only HCA (Ward linkage) and silhouette scores are implemented.

    Attributes:
        matrice: square pairwise distance matrix. Its ``index`` is read to
            name the samples, so a pandas DataFrame is expected.
        HCA: linkage matrix computed by ``funHCA`` (None until then).
        silscore: best silhouette score found so far.
        partitions: dict of parallel lists ("silscores", "k_clusters",
            "height") describing every partition that could be scored.
        cluster: dict with cluster id as key and list of members as value.
        HCA_threshold: cut height associated with the best silhouette score,
            or the height supplied by the caller.
        HCA_labels: cluster id of each sample, in matrix index order.
    '''
    def __init__(self,matrice):#,method):
        self.matrice = matrice
        # self.method = method
        self.HCA = None
        self.silscore = 0
        self.partitions = {}
        self.cluster = {}
        self.HCA_threshold = None
        self.HCA_labels = None


    def funHCA(self):
        '''
        Compute the Ward linkage of the distance matrix and store it in self.HCA.
        '''
        # linkage() expects the condensed (upper-triangle) form of the matrix.
        condensed_df = squareform(self.matrice)
        self.HCA = linkage(condensed_df, method="ward")

    def _score(self, labels):
        '''
        Average silhouette score of `labels` on the distance matrix.

        The matrix already holds pairwise distances, so it must be passed as
        "precomputed"; otherwise its rows are treated as feature vectors and
        the score is computed on the wrong geometry.
        '''
        return silhouette_score(self.matrice, labels, metric="precomputed")

    def silhouette(self, threshold=None, step=0.01):
        '''
        Cut the HCA tree and build the clusters.

        threshold: height at which the tree is cut. When None, heights from 0
            up to (excluding) the maximum merge height are scanned and the one
            with the best average silhouette score is kept.
        step: increment between two scanned heights (only used when
            threshold is None). Must be positive.

        Return the dict {cluster id: [members]} (also stored in self.cluster).
        Raise ValueError if step is not positive.
        '''
        # if self.method == "HCA":
            # run HCA
        self.funHCA()

        if threshold is None:
            if step <= 0:
                raise ValueError("step must be a positive number")
            self._search_best_height(step)
        else:
            self._cut_at(threshold)

        # Build a dict with cluster id as keys and list of members as values.
        # Labels and matrix index are in the same order.
        self.cluster = {}
        for cluster_id, sample in zip(self.HCA_labels, self.matrice.index):
            self.cluster.setdefault(cluster_id, []).append(sample)

        return self.cluster

    def _cut_at(self, threshold):
        '''Cut the tree at a caller-supplied height and score that partition.'''
        labels = fcluster(self.HCA, t=threshold, criterion='distance')
        try:
            # Scoring fails when there is a single cluster or one cluster per
            # sample; the partition is still kept in that case.
            self.silscore = self._score(labels)
        except ValueError:
            self.silscore = None
        self.HCA_threshold = threshold
        self.HCA_labels = labels

    def _search_best_height(self, step):
        '''Scan cut heights and keep the partition with the best silhouette.'''
        self.threshold = 0  # kept for backward compatibility; not read here

        # HCA max height (third column of the linkage matrix)
        max_height = max(self.HCA[:,2])

        # Fallback partition, used when no scanned height improves the score.
        print("Start test")
        print("Self.HCA", self.HCA)
        labels = fcluster(self.HCA, t=0, criterion='distance')
        print("Labels :", labels)
        ncluster = set(labels)
        print("ncluster", ncluster)
        unique_cluster = list(ncluster)
        print("unique_cluster", unique_cluster)
        print("self.matrice", self.matrice)
        try:
            # Keep going even if this partition cannot be scored.
            self.silscore = self._score(labels)
        except ValueError:
            self.silscore = 0
        self.HCA_threshold = None
        print("self.HCA_threshold", self.HCA_threshold)
        self.HCA_labels = labels
        print("self.HCA_labels", self.HCA_labels)

        # One entry per scanned height whose partition could be scored.
        silscores = []
        k = []
        heights = []
        print("Threshold is undefined - searching for best avg silhouette score by iterating over height")
        print("Max_height : {}".format(max_height))
        for height in np.arange(0, max_height, step):
            labels = fcluster(self.HCA, t=height, criterion='distance')

            ncluster = set(labels)
            print("Nombre cluster :", ncluster)
            unique_cluster = list(ncluster)
            print("unique cluster :", unique_cluster)
            n_clusters = len(unique_cluster)
            print("Number of unique cluster for height={}, k = {}".format(height, n_clusters))

            # The silhouette score is undefined for a single cluster and for
            # one cluster per sample.
            if not 1 < n_clusters <= len(labels) - 1:
                continue

            print("Labels :", labels)
            it_silscore = self._score(labels)
            print("silscore", it_silscore)
            print("Silhouette score for height= {}, sil = {}".format(height, it_silscore))
            k.append(n_clusters)
            print("k_clusters= {}".format(n_clusters))
            silscores.append(it_silscore)
            heights.append(height)

            if it_silscore > self.silscore:
                print("Avg silscore improved for height={}, new best silscores = {}".format(height, it_silscore) )
                self.silscore = it_silscore
                self.HCA_threshold = height
                self.HCA_labels = labels

        self.partitions = {"silscores":silscores , "k_clusters":k, "height":heights }

    def plot_cluster(self, outdir):
        '''
        Save three figures in `outdir`:
            cluster_dendrogram.png : HCA dendrogram, coloured below the cut
            sample_heatmap.pdf : sample-to-sample clustered heatmap
            cluster_size_distribution.png : histogram of cluster sizes
        '''
        if getattr(self, "HCA", None) is None:
            # Plotting was requested before any clustering: compute the tree.
            self.funHCA()

        # Plot Dendrogram for HCA #
        # One colour per cluster, shuffled so neighbouring clusters differ.
        cmap = cm.plasma(np.linspace(0, 1, len(self.cluster)))
        colors = [matplotlib.colors.rgb2hex(rgb[:3]) for rgb in cmap]
        set_link_color_palette(random.sample(colors, len(colors)))
        try:
            sample_names = list(self.matrice.index)
            if self.HCA_threshold is None:
                # no cut height was selected, so nothing to highlight
                dendrogram(self.HCA, leaf_rotation=90,
                    leaf_font_size=3, labels=sample_names)
            else:
                dendrogram(self.HCA,
                    color_threshold=self.HCA_threshold,
                    leaf_rotation=100,
                    leaf_font_size=3,
                    labels=sample_names,
                    above_threshold_color='lightgrey')
                plt.axhline(linestyle='--', y=self.HCA_threshold, color="black")

            plt.savefig(os.path.join(outdir, "cluster_dendrogram.png"))
        finally:
            plt.close()
            # The palette is module-global in scipy: restore the default so
            # later dendrograms are not affected.
            set_link_color_palette(None)

        # plot samples heatmap #
        sns.clustermap(self.matrice,
            method="ward",
            cmap="mako",
            xticklabels=False,
            yticklabels=False)
        plt.savefig(os.path.join(outdir, "sample_heatmap.pdf"))
        plt.close()

        # plot cluster size distribution #
        clusters_size = [len(members) for members in self.cluster.values()]
        counts, bins = np.histogram(clusters_size)
        plt.hist(bins[:-1], bins, weights=counts, color='steelblue')
        plt.title('Cluster size distribution')
        plt.xlabel('Cluster size')
        plt.ylabel('counts')
        plt.savefig(os.path.join(outdir, "cluster_size_distribution.png"))
        plt.close()

    def top_partition(self):
        """
        Return a df sorted from highest to lowest silhouette score, with the
        associated number of clusters and cut height (duplicates removed).

        The df is empty (same columns) when no partition has been scored.
        """
        silscores = self.partitions.get("silscores") or []
        k_clusters = self.partitions.get("k_clusters") or []
        heights = self.partitions.get("height") or []

        # set() drops identical (score, k, height) rows before sorting.
        rows = sorted(set(zip(silscores, k_clusters, heights)), reverse=True)

        return pd.DataFrame(rows, columns=["silscores", "k_cluster", "height"])

    def plot_silhouette_distri(self, outdir):
        '''
        Save a scatter plot of silhouette score against number of clusters
        as Silscores_distribution.png in `outdir`.

        Nothing is written when no partition has been scored.
        '''
        df = pd.DataFrame(self.partitions)
        if df.empty:
            # A scatter plot of an empty table raises in pandas.
            return

        df.plot(x="k_clusters", y="silscores", kind="scatter")
        plt.savefig(os.path.join(outdir, "Silscores_distribution.png"))
        plt.close()

    @staticmethod
    def _sample_name(sample):
        '''
        Return the first half (round(len/2) characters) of a sample label.

        NOTE(review): presumably the matrix index carries each sample name
        twice; confirm with the producer of the matrix.
        Labels too short to be halved (length 0 or 1) are returned unchanged.
        '''
        half = round(len(sample) / 2)
        if half == 0:
            return sample
        return sample[:half]

    def write_cluster(self,outdir,split):
        '''
        Write the clusters in `outdir`.

        clusters.tsv always lists "<cluster id>\\t<sample>" for every sample.
        When `split` is true, each cluster is also written to its own
        cluster_<id>.txt file (one sample per line).
        Those files will be used for co-assembly.
        '''
        summary = os.path.join(outdir,"clusters.tsv")
        with open(summary,'w') as f_summary:
            for clusterid, members in self.cluster.items():
                names = [self._sample_name(sample) for sample in members]
                f_summary.writelines(f"{clusterid}\t{name}\n" for name in names)
                if split:
                    clusterfile = os.path.join(outdir, f"cluster_{clusterid}.txt")
                    with open(clusterfile, 'w') as f_cluster:
                        f_cluster.writelines(f"{name}\n" for name in names)



def main():
    '''
    Entry point: read the distance matrix, cluster it and write the results.

    Parameters come from the `snakemake` object when the script is run by
    Snakemake (plots and per-cluster files always on, ";" separator),
    otherwise from the command line.

    Outputs, when an output directory is given:
        silscores_partition.tsv (only when the height is searched),
        clusters.tsv, optional cluster_<id>.txt files and optional plots.
    Without an output directory only the table of scored partitions is
    produced, on stdout.

    Raise FileNotFoundError if the matrix cannot be read (the original error
    is chained as the cause).
    '''
    try:
        smk = snakemake
    except NameError:
        # Not run through Snakemake: fall back to the command line.
        smk = None

    if smk is not None:
        matrix_path = smk.input[0]
        # cluster_strat = snakemake.params.clustering
        outdir = smk.output[0]
        plot = True
        split = True
        raw_height = smk.params.height
        # The param may arrive as the string "None" or as a number stored
        # as text; the tree cut needs a real number.
        height = None if raw_height in (None, "None") else float(raw_height)
        sep = ";"
    else:
        parser = argparse.ArgumentParser(
            prog='clustersample',
            description='use a similarity matrix to cluster samples using HAC. \
                Return HAC partition with the best silhouette score')
        parser.add_argument(
            'matrix', type=str,
            help='(required) similarity matrix .csv, gzip or not')
        # parser.add_argument(
        #     '-m','--method', type=str,
        #     default="HCA",
        #     help='(optional) clustering strategy, only HCA avalaible for now. default = HCA')
        parser.add_argument(
            '-o','--outdir', type=str,
            default=None,
            help='(optional) Output directory where clusters will be saved, default stdout')
        parser.add_argument(
            '-p','--plot', action='store_true',
            help='(optional) plots HCA and scatterplot (MDS)')
        parser.add_argument(
            '-s','--split', action='store_true',
            help='(optional) split main results by cluster, produce one file per cluster')
        parser.add_argument(
            '--height', type=float,
            default=None,
            help='(optional) threshold to apply when forming clusters')
        parser.add_argument(
            '--sep', type=str,
            default = "\t",
            help='(optional) separator for csv file (default: tab)')

        args = parser.parse_args()

        matrix_path = args.matrix
        # cluster_strat = args.method
        outdir = args.outdir
        plot = args.plot
        split = args.split
        height = args.height
        sep = args.sep

    #open input
    compression = 'gzip' if matrix_path.endswith('.gz') else 'infer'
    try:
        df = pd.read_csv(matrix_path, compression=compression,
                         index_col=0, header=0, sep=sep)
    except Exception as err:
        # Same exception type as before, but the real cause is kept instead
        # of being silently discarded.
        raise FileNotFoundError(
            f"unable to read distance matrix '{matrix_path}': {err}") from err

    M_Sim = SimkaCluster(df)#,cluster_strat)

    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        partitions_target = os.path.join(outdir, "silscores_partition.tsv")
    else:
        # No output directory: the partition table goes to stdout
        # (to_csv(None) would only return a string and write nothing).
        partitions_target = sys.stdout

    print("height value : {}".format(height))

    if height is None:
        clusters = M_Sim.silhouette()
        several_clusters = len(clusters) > 1
        if several_clusters:
            partition = M_Sim.top_partition()
            partition.to_csv(partitions_target, header=True, index=False, sep="\t")

        if plot and outdir is not None:
            if several_clusters:
                M_Sim.plot_silhouette_distri(outdir)
            M_Sim.plot_cluster(outdir)
    else:
        M_Sim.silhouette(height)
        if plot and outdir is not None:
            M_Sim.plot_cluster(outdir)

    if outdir is not None:
        M_Sim.write_cluster(outdir, split)


# Run the entry point only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
