#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
import sys
import re
import pandas as pd
import logging


def _read_contig_lengths(path):
    """Return the ``length`` column of a tab-separated coverage table as a list.

    The first column of the file is used as the row index (contig name) and
    the first line as the header.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if the table has no ``length`` column.
    """
    try:
        # ``squeeze=True`` is intentionally not used: it was removed in
        # pandas 2.0 and turned single-column tables into a Series, which
        # broke the per-contig "length" lookup.
        table = pd.read_csv(path, header=0, index_col=0, sep="\t")
    except FileNotFoundError as err:
        raise FileNotFoundError("{} not found".format(path)) from err
    return table["length"].tolist()


def _summarise_lengths(lengths, long_threshold, verylong_threshold):
    """Compute total length, contig counts, N50 and L50 for contig lengths.

    Returns a dict with the un-prefixed statistic names used by
    ``contigstats`` ("total_length", "n_contigs", "n_longcontigs",
    "n_verylongcontigs", "N50", "L50").
    """
    ordered = sorted(lengths, reverse=True)
    total = sum(ordered)

    # N50 is the length of the contig at which the running sum of the
    # length-sorted contigs first reaches half of the total; L50 is the number
    # of contigs needed to get there (1-based count, not a list index).
    # Both stay 0 when there are no contigs at all.
    n50, l50 = 0, 0
    running = 0
    for rank, length in enumerate(ordered, start=1):
        running += length
        if running >= total / 2:
            n50, l50 = length, rank
            break

    return {
        "total_length": total,
        "n_contigs": len(ordered),
        "n_longcontigs": sum(1 for length in ordered if length > long_threshold),
        "n_verylongcontigs": sum(
            1 for length in ordered if length > verylong_threshold
        ),
        "N50": n50,
        "L50": l50,
    }


def contigstats(cov, covfilt=None, *, long_threshold=10000,
                verylong_threshold=50000):
    """Compute assembly statistics from coverage table(s).

    Args:
        cov: path to a tab-separated coverage table with a header line, the
            contig identifier in the first column and a ``length`` column.
        covfilt: optional path to a table of the same layout describing the
            contigs kept after filtering. When given, the ``post_filter_*``
            entries of the result are filled in; otherwise they stay ``None``.
        long_threshold: contigs strictly longer than this count as "long"
            (default 10000). The same cutoff is applied before and after
            filtering.
        verylong_threshold: contigs strictly longer than this count as
            "very long" (default 50000).

    Returns:
        dict mapping each statistic name ("total_length", "n_contigs",
        "n_longcontigs", "n_verylongcontigs", "N50", "L50") and its
        ``post_filter_``-prefixed counterpart to its value. N50 and L50 are 0
        for a table without contigs.

    Raises:
        FileNotFoundError: if ``cov`` or ``covfilt`` does not exist.
        KeyError: if a table has no ``length`` column.
    """
    stat_names = (
        "total_length",
        "n_contigs",
        "n_longcontigs",
        "n_verylongcontigs",
        "N50",
        "L50",
    )

    # Keep the interleaved key order (raw, post_filter, raw, ...) because it
    # determines the column order of the table built from this dict.
    data = {}
    for name in stat_names:
        data[name] = None
        data["post_filter_" + name] = None

    data.update(
        _summarise_lengths(
            _read_contig_lengths(cov), long_threshold, verylong_threshold
        )
    )

    if covfilt is not None:
        filtered_stats = _summarise_lengths(
            _read_contig_lengths(covfilt), long_threshold, verylong_threshold
        )
        for name, value in filtered_stats.items():
            data["post_filter_" + name] = value

    return data




def main():
    """Compute assembly statistics for one coverage table and write them out.

    Inputs are taken from a global ``snakemake`` object when one is defined
    (``params.unfiltered``, ``params.filtered``, ``params.src`` and
    ``output``); otherwise they are parsed from the command line.

    The result is a single-row, tab-separated table whose index column is
    labelled "ID". The row name is the given index name or, by default, the
    part of the coverage file's basename before the first underscore.
    """
    # Without configuration the root logger drops INFO records. The default
    # handler writes to stderr, so messages never mix with a table on stdout.
    logging.basicConfig(level=logging.INFO)

    # parsing script arguments
    try:
        covstats = snakemake.params.unfiltered
        filtered = snakemake.params.filtered
        out = str(snakemake.output)
        index_name = snakemake.params.src
    except NameError:
        # No ``snakemake`` global: fall back to the command-line interface.
        parser = argparse.ArgumentParser(
            prog='contigstats',
            description='Estimate assembly statistics (N50, length, etc) '
            'from metacovest results.')
        parser.add_argument(
            'covstats', type=str,
            help='coverage statistics file (tab-separated)')
        parser.add_argument(
            '-f', '--filtered', type=str,
            default=None,
            help='(Optional) coverage statistics file after filtering step '
            '(tab-separated)')
        parser.add_argument(
            '-o', '--out', type=str,
            default=sys.stdout,
            help='(Optional) file to store results (default is stdout)')
        parser.add_argument(
            '--index-name', type=str,
            default=None,
            help='(Optional) row name, (default is filename)')
        args = parser.parse_args()

        covstats = args.covstats
        filtered = args.filtered
        out = args.out
        index_name = args.index_name

    # Row label: explicit name if given, else the filename prefix before "_".
    if index_name is not None:
        row_index = index_name
    else:
        row_index = os.path.basename(covstats).split("_")[0]

    logging.info("assembly statistics running for {} ".format(covstats))

    data = {row_index: contigstats(covstats, covfilt=filtered)}

    # One row per sample, one column per statistic; written to ``out``
    # (default stdout).
    df = pd.DataFrame(data).transpose()
    df.to_csv(out, sep="\t", index_label="ID")


# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
