#!/usr/bin/env python3
# -*- coding: utf-8 -*-

################################################################################
#
#    Declaration
#        mags : directory where mags are stored.
#        dRep_levels : level of dereplication.
#        dRep_tables : expected tables generated by dRep
#
################################################################################
# Project root (from the configuration) and its scratch directory.
working_dir = config['project']
tmpdir = os.path.join(working_dir, "tmp")

# First entry of the configured MAG filters.
bins_collection = config["Binning"]["filter_mags"][0]
# Dereplication levels (nucmer ANI thresholds) requested in the configuration.
dRep_levels = config["clustering_parameters"]["nucmer_ani"]

# The level pattern is a raw string so that "\d" is not treated as a (invalid)
# Python string escape, and the dot is escaped so that it only matches a
# literal decimal point (e.g. "0.95"), not any character.
wildcard_constraints:
    level = r"0\.\d\d",
    identifyoralign = "identify|align",

# Level at which genomes are annotated, and GTDB-Tk batch size.
annotation_level = config["GTDBTK"]["level"]
GTDBTK_batch_size = config["GTDBTK"]["batch_size"]

# `simka_type` is not defined in this section; it is expected to be set
# before this file is evaluated.
# NOTE(review): directory() is normally an output flag — confirm how
# `input_of_target_simka` is consumed.
if simka_type != 'None':
    input_of_target_simka = directory("intermediate_results/assembly/co_assembly")
else:
    input_of_target_simka = 'tmp/simka_skip.txt' #skip cluster_simka checkpoint

# Marker file of the functional annotation; empty string when it is disabled.
if config["eggNOGG"]["enable"]:
    annotate_fct = "tmp/mags_functionnal_annotation.done"
else:
    annotate_fct = ""

# Marker file of the taxonomic annotation; empty string when it is disabled.
if config["GTDBTK"]["enable"]:
    annotate_taxo = "tmp/mags_taxonomic_annotation.done"
else:
    annotate_taxo = ""

# Value passed to `coverm genome --methods`.
coverm_genome_calcultation = config["coverm_genome"]["calculation_options"]

def input_cmd(wildcards):
    """Return the read inputs matching the requested assembly type.

    Args:
        wildcards: object exposing the ``assembly`` and ``src`` wildcards.

    Returns:
        For ``single_assembly``: a flat list with the reads of every run
        registered for ``wildcards.src`` in ``reads2use``.
        For ``co_assembly``: the path of ``samples.txt`` in ``tmpdir`` when
        ``simka_type`` is ``"None"``, otherwise the path of the cluster file,
        returned with its ``{src}`` placeholder left as is.

    Raises:
        ValueError: if ``wildcards.assembly`` is neither ``single_assembly``
            nor ``co_assembly``.
    """
    assembly = wildcards.assembly
    if assembly == "single_assembly":
        runs = reads2use[wildcards.src]
        # Flatten the per-run read lists into a single list.
        return [read for run in runs for read in runs[run]]
    if assembly == "co_assembly":
        if simka_type == "None":
            return os.path.join(tmpdir, 'samples.txt')
        return os.path.join(intermediate_results_dir, "assembly/co_assembly/clusters/{src}.txt")
    raise ValueError(
        "Unknown assembly type {!r}: expected 'single_assembly' or 'co_assembly'".format(assembly)
    )

# Names of the dRep tables expected by this workflow (one per entry).
# 'Bdb' is deliberately left out of the list.
dRep_tables = [
#    'Bdb',  # Sequence locations and filenames
    'Cdb',  # Genomes and cluster designations
    'Mdb',  # Raw results of MASH comparisons
    'Ndb',  # Raw results of ANIn comparisons
    'Sdb',  # Scoring information
    'Wdb',  # Winning genomes
    'Widb', # Winning genomes' checkM information
]

################################################################################
#
#    Functions to handle checkpoints and inputs
#
################################################################################

def get_genomes(wildcards):
    """List the dereplicated genomes available at the annotation level.

    The bins are discovered from the output directory of the
    ``move_dereplicated_genomes`` checkpoint, then mapped to their path under
    ``genomes_collection/dereplicated_at_<annotation_level>``.
    """
    level = str(annotation_level)
    moved_dir = checkpoints.move_dereplicated_genomes.get(level=level).output[0]
    bin_names = glob_wildcards(os.path.join(moved_dir, "{bin}.fa")).bin
    genome_pattern = os.path.join("genomes_collection/dereplicated_at_" + level, "{bin}.fa")
    return expand(genome_pattern, bin=bin_names)


def get_function_tables(wildcards):
    """Return one eggNOG-mapper annotation table path per dereplicated bin.

    Bins are discovered from the ``*.fa`` files found in the output directory
    of the ``resolve_dereplicated_bins`` checkpoint.
    """
    bins_dir = checkpoints.resolve_dereplicated_bins.get(**wildcards).output[0]
    bin_names = glob_wildcards(os.path.join(bins_dir, "{bin}.fa")).bin
    table_pattern = os.path.join(
        "genomes_collection", "annotations", "functions", "{bin}", "{bin}.emapper.annotations"
    )
    return expand(table_pattern, bin=bin_names)


def get_genome_eggNOGG(wildcards):
    """Return the eggNOG-mapper annotation file expected for every bin.

    The list of bins comes from the ``*.fa`` files present in the output
    directory of the ``resolve_dereplicated_bins`` checkpoint.
    """
    bins_dir = checkpoints.resolve_dereplicated_bins.get(**wildcards).output[0]
    fasta_pattern = os.path.join(bins_dir, "{bin}.fa")
    annotation_pattern = os.path.join(
        "genomes_collection", "annotations", "functions", "{bin}", "{bin}.emapper.annotations"
    )
    return expand(annotation_pattern, bin=glob_wildcards(fasta_pattern).bin)


def resolve_final_target(wildcards):
    """Collect every file required by the genomes-catalogue target rule.

    Args:
        wildcards: wildcards of the calling rule (not used here).

    Returns:
        list: per-level abundance, bp-covered and CoverM tables, the genome
        length table, the abricate/AMRFinder tables and the geNomad marker,
        plus the GTDB-Tk and eggNOG tables when the corresponding step is
        enabled in the configuration.
    """
    tables_dir = os.path.join("genomes_collection", "tables")
    targets = []
    # One table per dereplication level. The abundance table used to be
    # requested twice; it is now listed only once.
    targets.extend(expand(os.path.join(tables_dir, "genomes_abundance.{level}.tsv"), level=dRep_levels))
    targets.extend(expand(os.path.join(tables_dir, "bp_covered.{level}.tsv"), level=dRep_levels))
    targets.append(os.path.join(tables_dir, "genomes_length.tsv"))
    targets.extend(expand(os.path.join(working_dir, "genomes_collection/tables/coverm_genomes_abundance.{level}"), level=dRep_levels))
    # Resistance / virulence screening tables and geNomad marker file.
    targets.append(os.path.join(tables_dir, "genomes_abricate_functions.tsv"))
    targets.append(os.path.join(tables_dir, "genomes_amrfinder_functions.tsv"))
    targets.append(os.path.join(tables_dir, "abricate_summary.tsv"))
    targets.append(os.path.join("tmp", "genomad_bin.done"))
    # Optional annotations, driven by the configuration.
    if config["GTDBTK"]["enable"]:
        targets.append(os.path.join(tables_dir, "gtdbtk.ar53.bac120.summary.tsv"))
        targets.append(os.path.join("tmp", "mags_taxonomic_annotation.done"))
    if config["eggNOGG"]["enable"]:
        targets.append(os.path.join(tables_dir, "genomes_functions.tsv"))
    return targets

################################################################################
#
#    final target of genomes collection
#
################################################################################


rule target_Genomes_catalogue:
    """
    Final target of the genomes collection: touch a marker file once every
    file returned by resolve_final_target is available.
    """
    output:
        os.path.join(tmpdir, "genome_catalogue.checkpoint"),
    input:
        resolve_final_target,
    shell:
        "touch {output}"


################################################################################
#
#    generate tables :
#        v- genomes x functions
#        v- genomes x taxo
#        v- genomes x length
#        v- genomes x sample , reads count edition
#        v- genomes x sample , bp covered edition
#        v- genomes x sample , relative abundance edition
#
################################################################################
def expand_sample2genomes(wildcards):
    """List the filtered, sorted BAM files for every source and every level.

    Sources come from ``assembly_dict``. For co-assemblies the cluster ids are
    first read from the output of the ``cluster_simka`` checkpoint and stored
    in ``assembly_dict["co_assembly"]`` (module-level side effect).
    """
    bam_pattern = os.path.join(
        "genomes_collection", "mapping", "{assembly}", "bam", "{src}",
        "{src}##genomes_{level}.filtered.sorted.bam",
    )
    bam_files = []
    if "single_assembly" in assembly_request:
        bam_files += expand(
            bam_pattern,
            src=assembly_dict.get("single_assembly"),
            level=dRep_levels,
            assembly="single_assembly",
        )
    if "co_assembly" in assembly_request:
        simka_dir = checkpoints.cluster_simka.get(**wildcards).output[0]
        cluster_files = os.path.join(simka_dir, "{clusterid}.txt")
        assembly_dict["co_assembly"] = glob_wildcards(cluster_files).clusterid
        bam_files += expand(
            bam_pattern,
            src=assembly_dict.get("co_assembly"),
            level=dRep_levels,
            assembly="co_assembly",
        )
    return bam_files

def expand_bam_coverm_genome(wildcards):
    """List the filtered, sorted BAM files of every source for one level.

    The ``{level}`` placeholder is kept in the returned paths (it is escaped
    in the pattern), so only ``src`` and ``assembly`` are expanded here. For
    co-assemblies the cluster ids are read from the output of the
    ``cluster_simka`` checkpoint and stored in ``assembly_dict["co_assembly"]``
    (module-level side effect).
    """
    bam_pattern = os.path.join(
        "genomes_collection", "mapping", "{assembly}", "bam", "{src}",
        "{src}##genomes_{{level}}.filtered.sorted.bam",
    )
    bam_files = []
    if "single_assembly" in assembly_request:
        bam_files += expand(
            bam_pattern,
            src=assembly_dict.get("single_assembly"),
            assembly="single_assembly",
        )
    if "co_assembly" in assembly_request:
        simka_dir = checkpoints.cluster_simka.get(**wildcards).output[0]
        cluster_files = os.path.join(simka_dir, "{clusterid}.txt")
        assembly_dict["co_assembly"] = glob_wildcards(cluster_files).clusterid
        bam_files += expand(
            bam_pattern,
            src=assembly_dict.get("co_assembly"),
            assembly="co_assembly",
        )
    return bam_files

rule coverm_genomes_abundances:
    """
    Run `coverm genome` on every BAM file of a dereplication level, using the
    calculation method(s) set in the configuration.
    """
    output:
        os.path.join(working_dir, "genomes_collection/tables/coverm_genomes_abundance.{level}"),
    input:
        expand_bam_coverm_genome,
    params:
        # Plain path: directory() is an output flag and is not meant for params.
        genome_directory = "genomes_collection/dereplicated_at_{level}",
        method = coverm_genome_calcultation,
    conda:
        os.path.join(CONDAENV, "coverm.yaml")
    shell:
        "coverm genome --methods {params.method} --bam-files {input} --genome-fasta-directory {params.genome_directory} --genome-fasta-extension fa -o {output}"

rule genomes_abundances:
    """
    Divide each row of the bp-covered table by the length of the
    corresponding genome and write the result as a tab-separated table.
    """
    output:
        os.path.join("genomes_collection", "tables", "genomes_abundance.{level}.tsv"),
    input:
        bp_covered = os.path.join("genomes_collection", "tables", "bp_covered.{level}.tsv"),
        length = os.path.join("genomes_collection", "tables", "genomes_length.tsv"),
    run:
        import pandas as pd
        # Both tables are indexed by their first column.
        bp_df = pd.read_csv(input.bp_covered, header=0, index_col=0, sep="\t")
        length = pd.read_csv(input.length, sep="\t", index_col=0, header=0, usecols=["genomes", "length"])

        # Map each index value of the length table to its length.
        length_dict = length.to_dict().get("length")
        # Build the divisors in the row order of the bp-covered table.
        # NOTE(review): a genome missing from the length table yields None
        # here — confirm that this cannot happen.
        data = []
        for i in bp_df.index:
            data.append(length_dict.get(i))
        # Row-wise division: every row is divided by its genome length.
        ab_df = bp_df.divide(data, axis=0)
        ab_df.to_csv(str(output), sep="\t", header=True, index=True)


rule genomes_raw_counts:
    """
    Build the reads-count, bp-covered and horizontal-coverage tables of one
    dereplication level from the BAM files, using generate_raw_tables.py.
    """
    output:
        reads_count = os.path.join("genomes_collection", "tables", "reads_counts.{level}.tsv"),
        bp_covered = os.path.join("genomes_collection", "tables", "bp_covered.{level}.tsv"),
        horizontal_coverage = os.path.join("genomes_collection", "tables", "horizontal_coverage.{level}.tsv"),
    input:
        expand_bam_coverm_genome,
    params:
        # Filtering options taken from the "genomes_tables" configuration section.
        min_baseq = config["genomes_tables"]["min_baseq"],
        min_depth_bp = config["genomes_tables"]["min_depth"],
        only_covered = config["genomes_tables"]["only_covered"],
        properly_paired = config["genomes_tables"]["properly_paired"],
        # Flag passed to the script; its exact effect is defined there.
        is_genomes = True
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    script:
        "../scripts/generate_raw_tables.py"


rule genomes_length_table:
    """
    Build the genome length table with genomes_length.py, from the genomes
    dereplicated at the annotation level.
    """
    output:
        os.path.join("genomes_collection", "tables", "genomes_length.tsv"),
    input:
        # Marker file: the rule only runs once the genomes collection is finished.
        sample2genome = "tmp/genomes_collection_finished.checkpoint",
    params:
        genomes_dir = "genomes_collection/dereplicated_at_{}".format(annotation_level),
        ambiguous = True # also count ambiguous bases
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    script:
         "../scripts/genomes_length.py"

rule concate_function_tables:
    """
    Concatenate the eggNOG-mapper annotation tables of all genomes into one
    table: the header comes from the first table, the other lines from all.
    """
    output:
        os.path.join("genomes_collection", "tables", "genomes_functions.tsv")
    input:
        # Marker file of the functional annotation step (not concatenated).
        "tmp/mags_functionnal_annotation.done",
        tables = get_function_tables
    shell:
        # Only the named input is used, so the marker file is left out.
        "head -n 1 {input.tables[0]} > {output} && "
        "tail -q -n +2 {input.tables} >> {output} "

####################################################################################################
#
#    target : genomes collection final target
#    below : eggNOGG annotation, GTDBTK annotation, samples2genomes mapping, dereplication and QC with checkm
#
#
####################################################################################################

rule target_annotations:
    """
    Touch a marker file once the enabled annotation steps are finished.
    """
    output:
        "tmp/genomes_annotation_finished.checkpoint"
    input:
        # Each of these is a marker file path, or an empty string when the
        # corresponding annotation is disabled in the configuration.
        # NOTE(review): confirm that an empty string is accepted as an input.
        annotate_fct,
        annotate_taxo,
    shell:
        "touch {output}"

####################################################################################################
#
#    above :
#    target : MAGS functional annotation using eggNOGG-mapper.
#    below : eggNOGGmapper , read-back-mapping , GTDBTK annotation , dreplication , QC
#
####################################################################################################

rule target_mags_function:
    """
    Create a temporary .done file once the eggNOG-mapper annotation file of
    every genome (as listed by get_genome_eggNOGG) has been produced.
    """
    output:
        temp("tmp/mags_functionnal_annotation.done")
    input:
        get_genome_eggNOGG,
    shell:
        "touch {output} "

rule search_orthology_GENOME:
    """
    Annotate the seed orthologs of a genome with eggNOG-mapper
    (`--annotate_hits_table`), producing `<bin>.emapper.annotations`.
    """
    output:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.emapper.annotations"),

    input:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.emapper.seed_orthologs"),
        os.path.join(db_dir, "eggNOG"),
    params:
        # Output prefix: emapper appends ".emapper.annotations" itself.
        prefix = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}")
    conda:
        os.path.join(CONDAENV, "eggnog.yaml")
    threads: 10
    shell:
        # Use {threads} so the CPU count follows what Snakemake actually grants.
        "emapper.py --annotate_hits_table {input[0]} --no_file_comments -o {params.prefix} --cpu {threads} --data_dir {input[1]} --override "

rule search_homology_GENOME:
    """
    Search seed orthologs for the clustered proteins of a genome with
    eggNOG-mapper in diamond mode, without annotation (`--no_annot`).
    """
    output:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.emapper.seed_orthologs"),
    input:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot_c95.faa"),
        os.path.join(db_dir, "eggNOG"),
    params:
        # Output prefix of emapper; the output above is <prefix>.emapper.seed_orthologs
        prefix = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}")#.emapper.seed_orthologs
    conda:
        os.path.join(CONDAENV, "eggnog.yaml")
    threads: 16
    shell:
        "emapper.py -m diamond --no_annot --no_file_comments --cpu {threads} -i {input[0]} -o {params.prefix} --data_dir {input[1]} --override "

def get_amrfinder_function_tables(wildcards):
    """Return one AMRFinderPlus table path per dereplicated bin.

    Bins are discovered from the ``*.fa`` files found in the output directory
    of the ``resolve_dereplicated_bins`` checkpoint.
    """
    bins_dir = checkpoints.resolve_dereplicated_bins.get(**wildcards).output[0]
    bin_names = glob_wildcards(os.path.join(bins_dir, "{bin}.fa")).bin
    table_pattern = os.path.join(
        "genomes_collection", "annotations", "functions", "{bin}", "{bin}_amrfinder.tsv"
    )
    return expand(table_pattern, bin=bin_names)

rule concate_amrfinder_tables:
    """
    Concatenate the AMRFinderPlus tables of all genomes into one table: the
    first line of the first table, then every table from its second line on.
    """
    output:
        os.path.join("genomes_collection", "tables", "genomes_amrfinder_functions.tsv"),
    input:
        tables = get_amrfinder_function_tables,
    shell:
        "head -n 1 {input.tables[0]} > {output} && "
        "tail -q -n +2 {input.tables} >> {output} "

rule amrfinderplus:
    """
    Find acquired antimicrobial resistance genes and point mutations in protein and/or assembled nucleotide sequences.

    Here AMRFinderPlus is run on the clustered protein set of a genome; the
    second input only ensures the database download rule has run first.
    """
    output:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}_amrfinder.tsv")
    input:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot_c95.faa"),
        "tmp/amrfinderplus_db.download"
    conda:
        os.path.join(CONDAENV, "amrfinderplus.yaml")
    shell:
        'amrfinder --protein {input[0]} -o {output} '

rule amrfinderplus_db:
    """
    Download the latest version of the AMRFinderPlus database to the default location (location of the AMRFinderPlus binaries/data),
    then touch a temporary marker file.
    """
    output:
        temp("tmp/amrfinderplus_db.download")
    conda:
        os.path.join(CONDAENV, "amrfinderplus.yaml")
    shell:
        'amrfinder -u && '
        'touch {output} ' 

def get_abricate_function_tables(wildcards):
    """Return one abricate table path per dereplicated bin.

    Bins are discovered from the ``*.fa`` files found in the output directory
    of the ``resolve_dereplicated_bins`` checkpoint.
    """
    bins_dir = checkpoints.resolve_dereplicated_bins.get(**wildcards).output[0]
    bin_names = glob_wildcards(os.path.join(bins_dir, "{bin}.fa")).bin
    table_pattern = os.path.join(
        "genomes_collection", "annotations", "functions", "{bin}", "{bin}_abricate.tsv"
    )
    return expand(table_pattern, bin=bin_names)

rule concate_abricate_tables:
    """
    Summarise the abricate tables of all genomes (`abricate --summary`) and
    concatenate them into one table with a single header line.
    """
    output:
        concate = os.path.join("genomes_collection", "tables", "genomes_abricate_functions.tsv"),
        summary = os.path.join("genomes_collection", "tables", "abricate_summary.tsv")
    input:
        tables = get_abricate_function_tables,
    conda:
        os.path.join(CONDAENV, "abricate.yaml")
    shell:
        "abricate --summary {input.tables} > {output.summary} && "
        "head -n 1 {input.tables[0]} > {output.concate} && "
        "tail -q -n +2 {input.tables} >> {output.concate} "

rule abricate:
    """
    Mass screening of contigs for antimicrobial resistance or virulence genes.

    Here abricate is run on the clustered nucleotide gene set of a genome.
    """
    output:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}_abricate.tsv")
    input:
        os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.nucl_c95.fnn"),
    conda:
        os.path.join(CONDAENV, "abricate.yaml")
    shell:
        'abricate {input} > {output} '

# representative set of genes
rule prot_95_GENOME:
    """
    Build representative protein and nucleotide gene sets of a genome by
    clustering the Prodigal predictions with `mmseqs linclust` at 95% minimum
    sequence identity.
    """
    output:
        prot_95 = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot_c95.faa"),
        nucl_95 = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.nucl_c95.fnn"),
        lin_dir = temp(directory(os.path.join("tmp", "linclust_MAGS", "{bin}"))),
    input:
        prod_prot = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot.translation.faa"),
        prod_nucl = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot.nucleotides.fnn"),
    conda:
        os.path.join(CONDAENV, "linclust.yaml")
    threads: 10
    params:
        # NOTE(review): these paths use the `tmp` variable while `lin_dir`
        # above uses the literal "tmp" — confirm both point to the same place.
        tmp = os.path.join(tmp, "linclust_MAGS"),
        tmpDB = os.path.join(tmp, "linclust_MAGS", "{bin}", "DB"),
        tmpDB_clust = os.path.join(tmp, "linclust_MAGS", "{bin}", "DB_clust"),
        tmpDB_rep = os.path.join(tmp, "linclust_MAGS", "{bin}", "DB_rep"),
        tmpfolder = os.path.join(tmp, "linclust_MAGS", "{bin}", "temp"),
        seqid=0.95,
    shell:
        # First pass: cluster the protein sequences and export the representatives.
        "mkdir -p {params.tmp}/{wildcards.bin} && "
        "mmseqs createdb {input.prod_prot} {params.tmpDB} && "
        "mmseqs linclust {params.tmpDB} {params.tmpDB_clust} {params.tmpfolder} --min-seq-id {params.seqid} --threads {threads} && "
        "mmseqs createsubdb {params.tmpDB_clust} {params.tmpDB} {params.tmpDB_rep} && "
        "mmseqs convert2fasta {params.tmpDB_rep} {output.prot_95} && "
        # Empty the working directory so the same database names can be reused.
        "rm -r {params.tmp}/{wildcards.bin}/* && "
        # Second pass: same steps on the nucleotide sequences.
        "mmseqs createdb {input.prod_nucl} {params.tmpDB} && "
        "mmseqs linclust {params.tmpDB} {params.tmpDB_clust} {params.tmpfolder} --min-seq-id {params.seqid} --threads {threads} && "
        "mmseqs createsubdb {params.tmpDB_clust} {params.tmpDB} {params.tmpDB_rep} && "
        "mmseqs convert2fasta {params.tmpDB_rep} {output.nucl_95};"

## CDS_search
rule fetch_CDS_GENOME:
    """
    Run Prodigal to search for CDS in a genome; writes the gene coordinates
    (GFF), the nucleotide sequences and the translated proteins.
    """
    output:
        coordinates = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.coords.gff"),
        nucleo = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot.nucleotides.fnn"),
        transl = os.path.join("genomes_collection", "annotations", "functions", "{bin}", "{bin}.prot.translation.faa"),
    input:
        os.path.join("tmp", "annotation_genomes", "{bin}.fa"),
    conda:
        os.path.join(CONDAENV, "prodigal.yaml")
    shell:
        'prodigal -i {input} '
        '-o {output.coordinates} '
        '-f gff '
        '-d {output.nucleo} '   # nucleotide sequence
        '-a {output.transl} '   # predicted proteins (used in next step)

def aggregate_genomad_summary(wildcards):
    """Return the geNomad virus summary path expected for every bin.

    Bins are discovered from the ``*.fa`` files found in the output directory
    of the ``resolve_dereplicated_bins`` checkpoint.
    """
    bins_dir = checkpoints.resolve_dereplicated_bins.get(**wildcards).output[0]
    bin_names = glob_wildcards(os.path.join(bins_dir, "{bin}.fa")).bin
    summary_pattern = os.path.join(
        working_dir,
        "genomes_collection/annotations/genomad/{bin}/{bin}_summary/{bin}_virus_summary.tsv",
    )
    return expand(summary_pattern, bin=bin_names)

rule aggregate_genomad:
    """
    Touch a temporary marker file once the geNomad virus summary of every
    bin (as listed by aggregate_genomad_summary) is available.
    """
    output:
        temp(os.path.join("tmp", "genomad_bin.done")),            
    input:
        aggregate_genomad_summary
    shell:
        'touch {output} '

rule genomad_on_bin:
    """
    Run `genomad end-to-end` (with --cleanup) on one bin and declare the
    result tables of each geNomad step as outputs.
    """
    output:
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_aggregated_classification/{bin}_aggregated_classification.tsv"),
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_annotate/{bin}_taxonomy.tsv"),
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_find_proviruses/{bin}_provirus_genes.tsv"),
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_marker_classification/{bin}_marker_classification.tsv"),
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_nn_classification/{bin}_nn_classification.tsv"),
        os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}/{bin}_summary/{bin}_virus_summary.tsv"),            
    input:
        os.path.join("tmp", "annotation_genomes", "{bin}.fa"),
        # Only declared as a dependency; the command uses params.db instead.
        os.path.join(db_dir, "genomad_db/genomad_db"),
    params:
        db = os.path.join(db_dir, "genomad_db"),
        output_dir = os.path.join(working_dir, "genomes_collection/annotations/genomad/{bin}"),
    conda:
        os.path.join(CONDAENV, "genomad.yaml")
    shell:
        'mkdir -p {params.output_dir} && '
        'genomad end-to-end --cleanup {input[0]} {params.output_dir} {params.db} '

checkpoint resolve_dereplicated_bins:
    """
    Copy the genomes dereplicated at the annotation level into a temporary
    directory, so that downstream input functions can discover the bins.
    """
    output:
        temp(directory("tmp/annotation_genomes")),
    input:
        # Marker file: the genomes collection must be finished first.
        "tmp/genomes_collection_finished.checkpoint"
    params:
        wgs = os.path.join("genomes_collection", "dereplicated_at_" + str(annotation_level)),
    shell:
        "mkdir -p {output} && "
        "cp {params.wgs}/* {output} "

####################################################################################################
#
#    perform GTDBTK annotation
#    Taxonomic annotation of genomes at a user-defined level
#
####################################################################################################

rule target_mags_taxo:
    """
    Create a temporary .done file once the merged GTDB-Tk summary table has
    been produced, i.e. once the taxonomic annotation has finished.
    """
    output:
        temp("tmp/mags_taxonomic_annotation.done"),
    input:
        os.path.join("genomes_collection", "tables", "gtdbtk.ar53.bac120.summary.tsv"),
    shell:
        "touch {output} "

rule concatenate_archaea_and_bacteria_results:
    """
    Concatenate .tsv output files from GTDB-tk for bacteria and archaea taxa
    into one unique file.

    The header is taken from the first non-empty table: the classify rule
    touches its outputs, so the archaeal summary can be an empty file, and
    taking the header from it unconditionally would drop the header.
    """
    output:
        os.path.join("genomes_collection", "tables", "gtdbtk.ar53.bac120.summary.tsv")
    input:
        bac = os.path.join("genomes_collection", "annotations", "taxonomy", \
                     "classify", "gtdbtk.bac120.summary.tsv"),
        ar = os.path.join("genomes_collection", "annotations", "taxonomy", \
                     "classify", "gtdbtk.ar53.summary.tsv"),
    shell:
        "awk 'FNR == 1 && seen_header {{ next }} "   # skip the header of every table after the first non-empty one
        "{{ seen_header = 1; print }}' "             # print everything else
        "{input.ar} {input.bac} "
        "> {output} "

rule GTDBTK_Classify:
    """ GTDBTk classify step on the merged results of all batches.

    From documentation:
        Finally, the classify step uses pplacer to find the maximum-likelihood
        placement of each genome in the GTDB-Tk reference tree. GTDB-Tk
        classifies each genome based on its placement in the reference tree,
        its relative evolutionary divergence, and/or average nucleotide
        identity (ANI) to reference genomes.

    Outputs are touched at the end so that tables GTDB-Tk did not write
    still exist (as empty files).
    """
    output:
        os.path.join("genomes_collection", "annotations", "taxonomy",
                     "classify", "gtdbtk.ar53.summary.tsv"),
        os.path.join("genomes_collection", "annotations", "taxonomy",
                     "classify", "gtdbtk.ar53.classify.tree"),
        os.path.join("genomes_collection", "annotations", "taxonomy",
                     "classify", "gtdbtk.bac120.summary.tsv"),
        os.path.join("genomes_collection", "annotations", "taxonomy",
                     "classify", "gtdbtk.bac120.classify.tree"),
    input:
        ##align
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.ar53.msa.fasta.gz"),
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.ar53.user_msa.fasta.gz"),
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.bac120.msa.fasta.gz"),
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.bac120.user_msa.fasta.gz"),
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.bac120.filtered.tsv"),
        os.path.join("genomes_collection", "annotations", "taxonomy","align", "gtdbtk.ar53.filtered.tsv"),
        ##identify
        os.path.join("genomes_collection", "annotations", "taxonomy","identify", "gtdbtk.bac120.markers_summary.tsv"),
        os.path.join("genomes_collection", "annotations", "taxonomy","identify", "gtdbtk.translation_table_summary.tsv"),
        os.path.join("genomes_collection", "annotations", "taxonomy","identify", "gtdbtk.ar53.markers_summary.tsv"),
    conda:
        os.path.join(CONDAENV, "gtdbtk.yaml")
    threads: 20
    log:
        os.path.join("logs", "genomes_collection", "taxonomy",  "gtdbtk_classify_all.log")
    params:
        all_dir = os.path.join("genomes_collection", "annotations", "taxonomy"),
        # Follow the configured annotation level, like the other annotation
        # rules, instead of a hard-coded "dereplicated_at_0.99".
        genomes_dir = os.path.join("genomes_collection", "dereplicated_at_" + str(annotation_level)),
        gtdbtk_data = os.path.abspath(os.path.join(db_dir, "release220")),
    shell:
        # "\\;" yields the "\;" terminator expected by find -exec.
        "find genomes_collection/annotations/taxonomy/align -type f -size -2k -name '*.gz' -exec rm -v {{}} \\; " #remove fasta.gz files that are empty (ar53 most of the time)
        "&& export GTDBTK_DATA_PATH={params.gtdbtk_data}; "
        "gtdbtk classify "
        "--genome_dir {params.genomes_dir} "
        "--extension fa "
        "--skip_ani_screen "
        "--align_dir {params.all_dir} "
        "--out_dir {params.all_dir} "
        "--cpus {threads} "
        "&>> {log} "
        "&& touch {output} "


def aggregate_tsv_identify(wildcards):
    """Return, for every batch, the path of one GTDB-Tk ``identify`` table.

    Batches are discovered from the output of the
    ``symlink_dereplicated_genomes_into_batches`` checkpoint. The
    ``{resultfile}`` placeholder is kept in the returned paths.
    """
    batches_dir = checkpoints.symlink_dereplicated_genomes_into_batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_dir, "batch_{batch}/genomes")).batch
    table_pattern = os.path.join(
        tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "identify", "gtdbtk.{{resultfile}}.tsv"
    )
    return expand(table_pattern, batch=batch_ids)

rule merge_batches_tsv_identify:
    """
    Merge the per-batch `identify` tables into one table: the first line of
    the first batch table, then every table from its second line on.
    """
    output:
        os.path.join("genomes_collection", "annotations", "taxonomy", "identify", "gtdbtk.{resultfile}.tsv"),
    input:
        aggregate_tsv_identify,
    shell:
        "head -n 1 {input[0]} > {output} && "
        "tail -q -n +2 {input} >> {output} "

def aggregate_tsv_align(wildcards):
    """Return, for every batch, the path of one GTDB-Tk ``align`` table.

    Batches are discovered from the output of the
    ``symlink_dereplicated_genomes_into_batches`` checkpoint. The
    ``{resultfile}`` placeholder is kept in the returned paths.
    """
    batches_dir = checkpoints.symlink_dereplicated_genomes_into_batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_dir, "batch_{batch}/genomes")).batch
    table_pattern = os.path.join(
        tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.{{resultfile}}.tsv"
    )
    return expand(table_pattern, batch=batch_ids)

rule merge_batches_tsv_align:
    """
    Merge the per-batch `align` tables into one table: the first line of the
    first batch table, then every table from its second line on.
    """
    output:
        os.path.join("genomes_collection", "annotations", "taxonomy", "align", "gtdbtk.{resultfile}.tsv"),
    input:
        aggregate_tsv_align,
    shell:
        "head -n 1 {input[0]} > {output} && "
        " tail -q -n +2 {input} >> {output} "

def aggregate_fastagz_align(wildcards):
    """Return, for every batch, the path of one gzipped ``align`` FASTA file.

    Batches are discovered from the output of the
    ``symlink_dereplicated_genomes_into_batches`` checkpoint. The
    ``{resultfile}`` placeholder is kept in the returned paths.
    """
    batches_dir = checkpoints.symlink_dereplicated_genomes_into_batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_dir, "batch_{batch}/genomes")).batch
    fasta_pattern = os.path.join(
        tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.{{resultfile}}.fasta.gz"
    )
    return expand(fasta_pattern, batch=batch_ids)

rule merge_batches_fastagz_align:
    """
    Merge the per-batch fasta files from the previous steps, so that there is
    only one tree for all the genomes (and not one tree per batch) after
    the Classify step.
    """
    output:
        os.path.join("genomes_collection", "annotations", "taxonomy", "align", "gtdbtk.{resultfile}.fasta.gz"),
    input:
        aggregate_fastagz_align,
    params:
        # Uncompressed file written first, then gzipped into the output.
        intermediate = os.path.join("genomes_collection", "annotations", "taxonomy", "align", "gtdbtk.{resultfile}.fasta"),
    shell:
        "zcat {input} --force > {params.intermediate} "
        "&& gzip {params.intermediate} "

rule GTDBTK_Align:
    """
    GTDBTk align step on a given batch of genomes.

    From documentation:
        The align step concatenates the aligned marker genes and filters the
        concatenated Multiple Sequence Alignments (MSA) to approximately
        5,000 amino acids.
    """
    output:
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.ar53.msa.fasta.gz"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.ar53.user_msa.fasta.gz"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.ar53.filtered.tsv"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.bac120.msa.fasta.gz"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.bac120.user_msa.fasta.gz"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "align", "gtdbtk.bac120.filtered.tsv"),
    input:
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "identify", "gtdbtk.bac120.markers_summary.tsv"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "identify", "gtdbtk.translation_table_summary.tsv"),
        os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK", "identify", "gtdbtk.ar53.markers_summary.tsv"),
    conda:
        os.path.join(CONDAENV, "gtdbtk.yaml")
    threads: 5
    log:
        os.path.join("logs", "genomes_collection", "taxonomy", "gtdbtk_align_{batch}.log")
    params:
        # The same directory holds the identify results and receives the align results.
        batch_dir = os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK"),
        gtdbtk_data = os.path.abspath(os.path.join(db_dir, "release220")),
    shell:
        "export GTDBTK_DATA_PATH={params.gtdbtk_data}; "
        "gtdbtk align "
        "--identify_dir {params.batch_dir} "
        "--out_dir {params.batch_dir} "
        "--cpus {threads} "
        "&>> {log} "
        "&& touch {output} " #touch output to avoid missing file i.e ar53 missing files in most cases

def resolve_GTDBTK_batch_content(wildcards):
    """
    Input function listing the genome files of one GTDB-Tk batch.

    Requesting the output of the ``symlink_dereplicated_genomes_into_batches``
    checkpoint makes Snakemake wait for the batches to exist before the batch
    directory is listed.

    Args:
        wildcards: rule wildcards; ``wildcards.batch`` selects the batch.

    Returns:
        Sorted list with the path of every entry found in
        ``<checkpoint output>/batch_<batch>/genomes``.
    """
    batches_root = checkpoints.symlink_dereplicated_genomes_into_batches.get(**wildcards).output[0]
    # Build the path from the checkpoint output rather than a hard-coded
    # "tmp/..." string so that it always follows the checkpoint declaration.
    genomes_dir = os.path.join(batches_root, "batch_{}".format(wildcards.batch), "genomes")
    # os.listdir returns entries in arbitrary order; sort for a stable input list.
    return [os.path.join(genomes_dir, name) for name in sorted(os.listdir(genomes_dir))]

rule GTDBTK_Identify:
    """
    Run the GTDB-Tk identify step on one batch of genomes.

    From the GTDB-Tk documentation:
        The identify step calls genes using Prodigal, and uses HMM models and
        the HMMER package to identify the 120 bacterial and 53 archaeal marker
        genes used for phylogenetic inference. Multiple sequence alignments
        (MSA) are obtained by aligning marker genes to their respective HMM
        model.

    Genomes are read from the batch "genomes" directory (".fa" extension) and
    results are written to the batch "GTDBTK" directory.
    """
    input:
        resolve_GTDBTK_batch_content,
        os.path.join(db_dir, "gtdbtk.downloaded"),
    output:
        expand(
            os.path.join(tmp, "GTDBTK_batches", "batch_{{batch}}", "GTDBTK", "identify", "{summary}"),
            summary=[
                "gtdbtk.bac120.markers_summary.tsv",
                "gtdbtk.translation_table_summary.tsv",
                "gtdbtk.ar53.markers_summary.tsv",
            ],
        ),
    log:
        os.path.join("logs", "genomes_collection", "taxonomy", "gtdbtk_identify_{batch}.log")
    conda:
        os.path.join(CONDAENV, "gtdbtk.yaml")
    threads: 10
    params:
        gtdbtk_data = os.path.abspath(os.path.join(db_dir, "release220")),
        genome_dir = os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "genomes"),
        batch_dir = os.path.join(tmp, "GTDBTK_batches", "batch_{batch}", "GTDBTK"),
    shell:
        "export GTDBTK_DATA_PATH={params.gtdbtk_data}; "
        "gtdbtk identify --genome_dir {params.genome_dir} "
        "--extension fa --out_dir {params.batch_dir} "
        "--cpus {threads} &>> {log} "

checkpoint symlink_dereplicated_genomes_into_batches:
    """
    Distribute the dereplicated genomes into per-batch directories for GTDB-Tk.

    Despite the rule name, genomes are copied (shutil.copy), not symlinked.
    Every ".fa" file found in params.genomes_dir is placed in
    "<output>/batch_<i>/genomes", with at most params.batch_size genomes per
    batch.
    """
    output:
        temp(directory(os.path.join(tmp, "GTDBTK_batches"))),
    input:
        sample2genome = "tmp/genomes_collection_finished.checkpoint"
    params:
        genomes_dir = "genomes_collection/dereplicated_at_{}/".format(str(annotation_level)),
        batch_size = GTDBTK_batch_size,
    run:
        import os
        import shutil
        batches_root = str(output)
        os.makedirs(batches_root)
        # Sorted so that batch membership is reproducible between runs
        # (os.listdir returns entries in arbitrary order).
        genomes = sorted(
            os.path.join(params.genomes_dir, fasta)
            for fasta in os.listdir(params.genomes_dir)
            if fasta.endswith(".fa")
        )
        batch_starts = range(0, len(genomes), params.batch_size)
        for batch_index, start in enumerate(batch_starts):
            destination = os.path.join(batches_root, "batch_" + str(batch_index), "genomes")
            os.makedirs(destination)
            for genome in genomes[start:start + params.batch_size]:
                shutil.copy(genome, os.path.join(destination, os.path.basename(genome)))

####################################################################################################
#
#    above : GTDBTK & eggNOGG mapper annotations
#    target read back mapping against genomes collection
#    below : read-back mapping, dereplication, QC
#
####################################################################################################

rule genomes_collection:
    """
    Marker rule for the genomes collection: the flag file is touched once the
    CheckM2, read back-mapping and dereplication flag files all exist.
    """
    input:
        "tmp/checkm2_by_batches.done",      # CheckM2 quality tables
        "tmp/sample2genomes.checkpoint",    # samples mapped against the genomes
        "tmp/dereplication_by_batches.done"
    output:
        temp("tmp/genomes_collection_finished.checkpoint")
    shell:
        "touch {output} "

rule target_sample2genomes:
    """
    Marker rule: touch the flag file once every file returned by
    expand_sample2genomes is available.
    """
    input:
        expand_sample2genomes,
    output:
        temp("tmp/sample2genomes.checkpoint")
    shell:
        "touch {output}"

rule sample2genomes_sort_and_index_post_filtering:
    """
    Post-process the quality-filtered mapping with samtools:
        - sort the BAM file by position
        - index the sorted BAM file for later easy read extraction

    The whole pipeline runs in a subshell so that the messages of
    `samtools view`, `samtools sort` and `samtools index` all end up in the
    log file (previously only the `samtools index` output was captured).
    """
    output:
        bam = os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.filtered.sorted.bam"),
        bai = os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.filtered.sorted.bam.bai"),
    input:
        os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "filtered_{src}##genomes_{level}.sorted.bam"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    log:
        os.path.join("logs", "genomes_collection", "mapping", "{assembly}", "sorting_{src}_{level}.filtered.log"),
    shell:
        "(samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output.bam} "
        "&& samtools index "        # writes {output.bam}.bai, i.e. output.bai
        "{output.bam}) &> {log}"


rule quality_filter_reads:
    """
    Filter the mapped reads of a sorted BAM with scripts/bamprocess.py.

    The thresholds handed to the script come from the "sample2genomes"
    section of the configuration (minimum quality, identity and length, and
    the properly-paired setting). The filtered BAM is the input of
    sample2genomes_sort_and_index_post_filtering.
    """
    input:
        os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam"),
    output:
        os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "filtered_{src}##genomes_{level}.sorted.bam"),
    params:
        min_mapq = config["sample2genomes"]["min_quality"],
        min_idt = config["sample2genomes"]["min_identity"],
        min_len = config["sample2genomes"]["min_len"],
        pp = config["sample2genomes"]["properly_paired"],
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    log:
        os.path.join("logs", "genomes_collection", "mapping", "{assembly}", "{src}_{level}.filtering.log")
    script:
        "../scripts/bamprocess.py"


rule sample2genomes_sort_and_index:
    """
    Convert the raw mapping (SAM) into a sorted, indexed BAM with samtools:
        - sort the alignments by position
        - index the sorted BAM file for later easy read extraction

    The whole pipeline runs in a subshell so that the messages of
    `samtools view`, `samtools sort` and `samtools index` all end up in the
    log file (previously only the `samtools index` output was captured).
    """
    output:
        bam = os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam"),
        bai = os.path.join("genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam.bai"),
    input:
        os.path.join(tmp, "sample2genomes", "{assembly}", "{src}##genomes_{level}.sam"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    log:
        os.path.join("logs", "genomes_collection", "mapping", "{assembly}", "sorting_{src}_{level}.log"),
    shell:
        "(samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output.bam} "
        "&& samtools index "        # writes {output.bam}.bai, i.e. output.bai
        "{output.bam}) &> {log}"
"""
rule sample2genomes_sort_and_index_coa:
    Perform 2 modifications on the mapping results using samtools:
        - sort BAM file by position
        - index BAM file for later easy read extraction
    
    output:
        bam = os.path.join("genomes_collection", "mapping", "bam", "cluster_{number}", "cluster_{number}##genomes_{level}.sorted.bam"),
        bai = os.path.join("genomes_collection", "mapping", "bam", "cluster_{number}", "cluster_{number}##genomes_{level}.sorted.bam.bai"),
    input:
        os.path.join(tmp, "sample2genomes", "co_assembly", "cluster_{number}##genomes_{level}.sam"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    log:
        os.path.join("logs", "genomes_collection", "mapping", "sorting_cluster_{number}_{level}.log"),
    shell:
        "samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output.bam} "
        "&& samtools index "
        "{output.bam} &> {log}"
"""
rule sample2genomes_mapping:
    """
    Map the reads of one source ({src}) with Bowtie2 against the index of the
    genomes dereplicated at {level}.

    The SAM output is declared temp(): it is converted into a sorted BAM by
    sample2genomes_sort_and_index.
    """
    input:
        # The index files must stay first: params.dbbasename is derived from input[0].
        index = expand(
            os.path.join("genomes_collection", "mapping", "index{{level}}", "all_genomes.{id}.bt2l"),
            id=[1, 2, 3, 4],
        ),
        reads = input_cmd,
    output:
        temp(os.path.join(tmp, "sample2genomes", "{assembly}", "{src}##genomes_{level}.sam"))
    params:
        # Drop the trailing ".1.bt2l" (7 characters) to get the index basename.
        dbbasename = lambda wildcards, input: input[0][:-7],
        # wildcards.src plays the role of the sample here
        cmd = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd
    log:
        os.path.join("logs", "genomes_collection", "mapping", "{assembly}", "{src}##{level}.log")
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    threads: 14
    priority: 80
    shell:
        "bowtie2 -p {threads} "     # number of parallel threads
        "--no-unal "                # remove unmapped reads (decrease size)
        "-x {params.dbbasename} "   # index for mapping
        "{params.cmd} -S {output} "
        "2> {log} "
"""
rule sample2genomes_mapping_coa:
    
    Map the reads from a run against an index of genomes.
    Output is temporary because it will be merged with other
    runs from this biological sample and compressed into BAM.
    

    output:
        temp(os.path.join(tmp, "sample2genomes", "co_assembly", "cluster_{number}##genomes_{level}.sam"))
    input:
        index= expand(os.path.join("genomes_collection", "mapping", "index{{level}}", "all_genomes.{id}.bt2l"), id=range(1, 5) ),
        reads = input_cmd_coa,
    log:
        os.path.join("logs", "genomes_collection", "mapping", "cluster_{number}##{level}.log")
    threads: 14
    priority: 80
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    params:
        dbbasename = lambda wildcards, input: input[0][:-7],
        cmd = lambda wildcards, input : cmdparser.cmd(wildcards.sample, input.reads, reads2use, "bowtie2").cmd # wildcards ==  sample in this case
    shell:
        "bowtie2 "
        "-p {threads} "             # number of parallel threads
        "--no-unal "                # remove unmapped reads (decrease size)
        "-x {params.dbbasename} "   # index for mapping
        "{params.cmd} "
        "-S {output} "
        "2> {log} "
"""
rule allgenomes_index:
    """
    Build a large (64-bit) Bowtie2 index from the concatenated genomes file
    of one dereplication level.
    """
    input:
        os.path.join("genomes_collection", "all_genomes.{level}.fa.gz"),
    output:
        # Forward parts 1-4 followed by reverse parts 1-2; the first file must
        # stay "all_genomes.1.bt2l" because params.prefix is derived from it.
        expand(
            os.path.join("genomes_collection", "mapping", "index{{level}}", "all_genomes.{part}.bt2l"),
            part=["1", "2", "3", "4", "rev.1", "rev.2"],
        ),
    params:
        # Drop the trailing ".1.bt2l" (7 characters) to get the index basename.
        prefix = lambda wildcards, output: output[0][:-7],
    log:
        os.path.join("logs", "genomes_collection", "genome_collections_bt2{level}_index", "all_genomes.indexing.log")
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    threads: 10
    shell:
        "echo {params.prefix} ;"
        "bowtie2-build --large-index "  # Force to build a 64-bit index
        "--threads {threads} "          # Number of parallel threads
        "{input} {params.prefix} "      # Genomes in a single file, then index basename
        "&> {log}"


def get_genome_collection(wildcards):
    """
    Return the fasta paths of the genomes dereplicated at ``wildcards.level``.

    The genome names are discovered in the output directory of the
    ``move_dereplicated_genomes`` checkpoint, so the list is only evaluated
    once that checkpoint has run.
    """
    genomes_dir = checkpoints.move_dereplicated_genomes.get(**wildcards).output[0]
    genome_names = glob_wildcards(os.path.join(genomes_dir, "{genome}.fa")).genome
    fasta_pattern = os.path.join("genomes_collection", "dereplicated_at_{level}", "{genome}.fa")
    return expand(fasta_pattern, level=wildcards.level, genome=genome_names)


rule concat_genomes:
    """
    Concatenate every genome dereplicated at {level} into a single file.

    NOTE(review): the genomes are concatenated with plain `cat`, so the
    output is not gzip-compressed although its name ends in ".fa.gz".
    """
    output:
        os.path.join("genomes_collection", "all_genomes.{level}.fa.gz"),
    input:
        "genomes_collection/dereplicated_at_{level}/",
        get_genome_collection
    params:
        wgs = get_genome_collection,
    shell:
        # Truncate instead of appending so that content already present in the
        # output file can never be duplicated.
        "cat {params.wgs} > {output}"


####################################################################################################
####################################################################################################
####################################################################################################
'''
    Above : annotation and abundance
    Final target for genome collection. At this time, we have dereplicated genomes at {level}
    Below : QC genomes and dereplication
'''
####################################################################################################
####################################################################################################
####################################################################################################


def aggregate_genomes_collections(wildcards):
    """
    Return the fasta paths of the genomes dereplicated at ``wildcards.level``.

    The bin names are discovered in the output directory of the
    ``move_dereplicated_genomes`` checkpoint. The level is resolved here from
    the wildcards (as get_genome_collection does) so the returned paths never
    contain an unformatted "{level}" placeholder.
    """
    genomes_dir = checkpoints.move_dereplicated_genomes.get(**wildcards).output[0]
    bin_names = glob_wildcards(os.path.join(genomes_dir, "{bin}.fa")).bin
    return expand(
        os.path.join("genomes_collection/dereplicated_at_{level}", "{bin}.fa"),
        level=wildcards.level,
        bin=bin_names,
    )

rule expand_target_dereplication:
    """
    Marker rule: touch the flag file once the per-level flag files exist for
    every level listed in dRep_levels.
    """
    input:
        expand("tmp/dereplicated_at_{level}.done", level=dRep_levels),
    output:
        "tmp/dereplication.done"
    shell:
        "touch {output}"


rule target_dereplication:
    """
    Marker rule: touch the flag file of one level once all the genomes
    returned by aggregate_genomes_collections are available.
    """
    input:
        aggregate_genomes_collections
    output:
        "tmp/dereplicated_at_{level}.done",
    shell:
        "touch {output} "

checkpoint move_dereplicated_genomes:
    """
    Copy the genomes selected by dRep at {level} into the final collection
    directory and tag their fasta headers with the genome identifier.

    For each copied file, the identifier is the file name without its last
    "."-delimited field. Each header ">contig" becomes ">id|contig" and a
    tab followed by the identifier is appended to the header line.

    The second perl call previously targeted "{output}$i" (missing "/"),
    a file that does not exist, so the header suffix was never applied.
    """
    output:
        directory("genomes_collection/dereplicated_at_{level}"),
    input:
        expand(os.path.join("genomes_collection", "data", "drep", "drep{{level}}", "data_tables", "{table}.csv"), table=dRep_tables),
    params:
        bins = "genomes_collection/data/drep/drep{level}/dereplicated_genomes/",
    shell:
        'mkdir -p  {output} && '
        'cp {params.bins}* {output} && '
        'for fasta in {output}/*; '                                            #iterate over paths instead of parsing ls output
        'do id=$(basename "$fasta" | rev | cut -d"." -f2-  | rev); '            #get all filename field except last one delimited by "."
        'perl -i -p -e "s/>/>$id|/" "$fasta"; '                                 #Use perl because more portable than sed
        'perl -i -p -e "s/$/\t$id/ if />.*/" "$fasta"; '                        #append the genome id to header lines only
        'done '

rule gzip_allmags:
    """
    Archive the directory of recovered MAGs into a gzip-compressed tarball,
    then delete the directory.
    """
    input:
        dir = os.path.join("intermediate_results", "binning/{}".format(bins_collection)),
    output:
        os.path.join("intermediate_results", "binning", "{}.tar.gz".format(bins_collection)),
    shell:
        "tar -zcvf {output} {input.dir} "
        "&& rm -r {input.dir}"

def aggregate_final_output(wildcards):
    """
    List the fasta files of all bins found in the ``binning`` checkpoint output.

    The bin identifiers are discovered in the checkpoint output directory and
    mapped onto "intermediate_results/binning/<filter_mags[0]>/<binsid>.fa".
    """
    binning_dir = checkpoints.binning.get(**wildcards).output[0]
    bin_ids = glob_wildcards(os.path.join(binning_dir, "{binsid}.fa")).binsid
    mags_dir = os.path.join(
        "intermediate_results", "binning", config["Binning"]["filter_mags"][0]
    )
    return expand(os.path.join(mags_dir, "{binsid}.fa"), binsid=bin_ids)


####################################################################################################
#
#    genomes dereplication by batches
#
####################################################################################################

rule target_dereplication_by_batches:
    """
    Marker rule: touch a temporary .done file once the final dRep tables
    exist for every dereplication level.
    """
    input:
        expand(
            os.path.join("genomes_collection", "data", "drep", "drep{level}", "data_tables", "{table}.csv"),
            table=dRep_tables,
            level=dRep_levels,
        ),
    output:
        temp("tmp/dereplication_by_batches.done")
    shell:
        "touch {output} "

rule dRep_final:
    """
    Final dereplication at {level}, run on the genome list written by
    dRep_final_genBdb and using the CheckM2-derived quality table.

    Both the primary (-pa) and the secondary (-sa) ANI thresholds are set to
    the {level} wildcard.
    """
    input:
        location = os.path.join("genomes_collection", "data", "drep", "drep{level}_input.csv"),
        quality = os.path.join("genomes_collection", "data", "checkM2", "checkM2_statistics_simplified.csv")
    output:
        expand(os.path.join("genomes_collection", "data", "drep", "drep{{level}}", "data_tables", "{table}.csv"), table=dRep_tables),
        directory("genomes_collection/data/drep/drep{level}/dereplicated_genomes/"),
    params:
        outdir = os.path.join("genomes_collection", "data", "drep", "drep{level}"),
        completeness = config["clustering_parameters"]["completeness"],
        contamination = config["clustering_parameters"]["contamination"],
        mash_ani = "{level}",    # instead of config["clustering_parameters"]["mash_ani"]
        nucmer_ani = "{level}",  # instead of config["clustering_parameters"]["nucmer_ani"]
        align_frac = config["clustering_parameters"]["min_aligned_fraction"],
    conda:
        os.path.join(CONDAENV, "drep_2020.yaml")
    benchmark:
        os.path.join("benchmark", "dRep_batches", "benchmark{level}.txt")
    threads: 10
    shell:
        "dRep dereplicate {params.outdir} "
        "--genomeInfo {input.quality} -p {threads} "
        "-comp {params.completeness} -con {params.contamination} "
        "-pa {params.mash_ani} -g {input.location} "
        "-sa {params.nucmer_ani} -nc {params.align_frac} "

def aggregate_drep_batch(wildcards):
    """
    Return, for every batch created by the ``genomes2batches`` checkpoint, the
    directory holding the genomes dRep kept at ``wildcards.level``.

    Batches are recognised by the presence of "batch_<id>/summary.txt" in the
    checkpoint output directory.
    """
    batches_root = checkpoints.genomes2batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_root, "batch_{batch}/summary.txt")).batch
    return [
        "tmp/batches/batch_{}/drep{}/dereplicated_genomes".format(batch_id, wildcards.level)
        for batch_id in batch_ids
    ]

def aggregate_drep_batch_table(wildcards):
    """
    Return the dRep data tables expected for every batch at ``wildcards.level``.

    Batches are recognised by the presence of "batch_<id>/summary.txt" in the
    output directory of the ``genomes2batches`` checkpoint; one path is
    produced per (batch, table in dRep_tables) combination.
    """
    batches_root = checkpoints.genomes2batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_root, "batch_{batch}/summary.txt")).batch
    table_pattern = os.path.join(
        tmp, "batches", "batch_{batch}", "drep{level}", "data_tables", "{table}.csv"
    )
    return expand(table_pattern, batch=batch_ids, table=dRep_tables, level=wildcards.level)

rule dRep_final_genBdb:
    """
    Generate a file containing all genome locations to avoid "Argument list
    too long" in the following dRep command.
    See: https://github.com/MrOlm/drep/issues/42
    Since dRep 2.5, a genome list can be passed to dRep. This input is a file
    with one path/to/genome per line.

    The entries of every per-batch "dereplicated_genomes" directory are
    written in sorted order so the list is identical from one run to the next.
    """
    output:
        os.path.join("genomes_collection", "data", "drep", "drep{level}_input.csv"),
    input:
        aggregate_drep_batch_table
    params:
        BinsList = aggregate_drep_batch,
    run:
        with open(output[0], 'w') as outfile:
            for repo in params.BinsList:
                # os.listdir order is arbitrary; sort for a reproducible list.
                outfile.writelines(
                    os.path.join(repo, genome) + "\n" for genome in sorted(os.listdir(repo))
                )

rule dRep:
    """
    Dereplicate the genomes of one batch with dRep, which decides whether two
    genomes are duplicates based on pairwise sequence identity.

    The genome list comes from the batch "summary.txt" and the genome
    qualities from the CheckM2-derived table. The secondary ANI threshold
    (-sa) is the {level} wildcard; the primary one (-pa) comes from the
    configuration.
    """
    input:
        location = os.path.join(tmp, "batches", "batch_{batch}", "summary.txt"),
        quality = os.path.join("genomes_collection", "data", "checkM2", "checkM2_statistics_simplified.csv")
    output:
        expand(os.path.join(tmp, "batches", "batch_{{batch}}", "drep{{level}}", "data_tables", "{table}.csv"), table=dRep_tables),
    params:
        outdir = os.path.join(tmp, "batches", "batch_{batch}", "drep{level}"),
        completeness = config["clustering_parameters"]["completeness"],
        contamination = config["clustering_parameters"]["contamination"],
        mash_ani = config["clustering_parameters"]["mash_ani"],
        nucmer_ani = "{level}",  # instead of config["clustering_parameters"]["nucmer_ani"]
        align_frac = config["clustering_parameters"]["min_aligned_fraction"],
    log:
        os.path.join("logs", "genomes_collection", "dereplication{level}", "{batch}.log")
    benchmark:
        os.path.join("benchmark", "dRep_batches", "batch", "benchmark{level}_{batch}.txt")
    conda:
        os.path.join(CONDAENV, "drep_2020.yaml")
    threads: 20
    shell:
        "dRep dereplicate {params.outdir} "
        "--genomeInfo {input.quality} -p {threads} "
        "-comp {params.completeness} -con {params.contamination} "
        "-pa {params.mash_ani} -sa {params.nucmer_ani} "
        "-g {input.location} "
        "-nc {params.align_frac} &> {log}"



# dRep input is produced below, under the rule "genomes2batches"

rule dRep_genQuality:
    """
    Write the genome quality table handed to dRep (--genomeInfo) from the
    CheckM2 quality report, so that genome quality is not computed again.

    Only the Name, Completeness and Contamination columns are kept; they are
    relabelled genome, completeness and contamination, and ".fa" is appended
    to every genome name.
    """
    input:
        "tmp/checkm2_by_batches.done"
    output:
        os.path.join("genomes_collection", "data", "checkM2", "checkM2_statistics_simplified.csv")
    params:
        os.path.join("genomes_collection", "data", "checkM2", "checkM2_quality_report.tsv")
    run:
        import pandas as pd
        kept_columns = ['Name', 'Completeness', 'Contamination']
        report = pd.read_csv(params[0], sep='\t', usecols=kept_columns)
        report.columns = ['genome', 'completeness', 'contamination']
        report['genome'] = report['genome'] + '.fa'
        report.to_csv(output[0], sep=',', index=False)

####################################################################################################
#
#    genome quality assessment using CheckM2
#
####################################################################################################

rule target_checkm2_statistics:
    """
    Marker rule: touch a temporary .done file once the merged CheckM2
    quality report exists.
    """
    input:
        os.path.join("genomes_collection", "data", "checkM2", "checkM2_quality_report.tsv"),
    output:
        temp("tmp/checkm2_by_batches.done")
    shell:
        "touch {output}"

def aggregate_batches_checkm2(wildcards):
    """
    Return the CheckM2 quality report expected for every batch.

    Batches are recognised by the presence of "batch_<id>/summary.txt" in the
    output directory of the ``genomes2batches`` checkpoint.
    Used as input of rule concat_tables_checkm2.
    """
    batches_root = checkpoints.genomes2batches.get(**wildcards).output[0]
    batch_ids = glob_wildcards(os.path.join(batches_root, "batch_{batch}/summary.txt")).batch
    report_pattern = os.path.join(tmp, "batches", "batch_{batch}", "checkm2", "quality_report.tsv")
    return expand(report_pattern, batch=batch_ids)

rule concat_tables_checkm2:
    """
    Merge the per-batch CheckM2 quality reports into a single table: the
    header line is taken from the first report, then the data lines of every
    report are appended.
    """
    input:
        aggregate_batches_checkm2,
    output:
        os.path.join("genomes_collection", "data", "checkM2", "checkM2_quality_report.tsv")
    params:
        # Not referenced by the shell command below.
        first = lambda wildcards, input: input[0],
        others = lambda wildcards, input: input[1:],
    shell:
        "head -n 1 {input[0]} > {output} "
        "&& tail -q -n +2 {input} >> {output} "

def resolve_batch_content(wildcards):
    """
    Input function listing the genome files of one batch.
    Used as input of rule checkM2.

    Requesting the output of the ``genomes2batches`` checkpoint makes
    Snakemake wait for the batches to exist before the directory is listed.

    Args:
        wildcards: rule wildcards; ``wildcards.batch`` selects the batch.

    Returns:
        Sorted list with the path of every entry found in
        ``<checkpoint output>/batch_<batch>/genomes``.
    """
    batches_root = checkpoints.genomes2batches.get(**wildcards).output[0]
    # Build the path from the checkpoint output rather than repeating the
    # "tmp/batches" literal, so it always follows the checkpoint declaration.
    genomes_dir = os.path.join(batches_root, "batch_{}".format(wildcards.batch), "genomes")
    # os.listdir returns entries in arbitrary order; sort for a stable input list.
    return [os.path.join(genomes_dir, name) for name in sorted(os.listdir(genomes_dir))]

rule checkM2:
    """
    Main CheckM2 rule: run `checkm2 predict` on the genomes of one batch and
    write the results next to the batch, in its "checkm2" directory.
    """
    input:
        # Function that returns the locations of genomes associated with {batch}
        resolve_batch_content,
        os.path.join(db_dir, "checkm2_data.dl")
    output:
        tsv = os.path.join(tmp, "batches", "batch_{batch}", "checkm2", "quality_report.tsv"),
    params:
        tmp_bin_dir = os.path.join(tmp, "batches", "batch_{batch}", "genomes"),
        checkm_dir = lambda wildcards, output: os.path.dirname(output.tsv),
        checkm_data = os.path.join(db_dir, "CheckM2_database/uniref100.KO.1.dmnd"),
    log:
        os.path.join("logs", "CheckM2-Results", "checkM2_{batch}.log")
    conda:
        os.path.join(CONDAENV, "checkm2.yaml")
    threads: 5
    shell:
        "checkm2 predict -t {threads} "
        "-x fa "                    # extension of the bin files
        "-i {params.tmp_bin_dir} "  # (input) directory containing the bin files
        "-o {params.checkm_dir} "   # (output) directory where to store the results
        "--database_path {params.checkm_data} "
        "&>> {log} "

rule checkM2data:
    """
    Download the CheckM2 database into the database directory and touch a
    flag file once the download has succeeded.
    """
    output:
        checkout = os.path.join(db_dir, "checkm2_data.dl"),
        checkm_data = directory(os.path.join(db_dir, "CheckM2_database")),
    params:
        db_storage = db_dir,
    conda:
        os.path.join(CONDAENV, "checkm2.yaml")
    shell:
        "mkdir -p {params.db_storage} "
        "&& checkm2 database --download --path {params.db_storage} "
        "&& touch {output.checkout} "


################################################################################
#    merge external db if existing with newly recovered MAGs
#    from {binning_strategy} and split all genomes into batches
#    to allow parallelization.
###############################################################################

checkpoint genomes2batches:
    """
    Split bins into batches to perform parallelization.

    The genomes are the recovered MAGs (rule input) preceded, when
    params.db is not empty, by the ".fa" files of the external database
    directory. They are cut into chunks of params.batch_size; for each chunk
    "<output>/batch_<i>/genomes" receives a copy of every genome and
    "<output>/batch_<i>/summary.txt" lists the original genome paths, one per
    line.
    """
    output:
        temp(directory("tmp/batches")),
    input:
        aggregate_final_output,
    params:
        batch_size = config["batch_size"],
        db = config["External_DB"],
    run:
        import os
        import shutil
        batches_root = str(output)
        os.makedirs(batches_root)
        # Make a list of all genome paths. If the input turns out to be a
        # single directory rather than a list of fasta files, fall back to
        # listing the ".fa" files it contains.
        input_path = str(input)
        if os.path.isdir(input_path):
            print("checkpoint failed - trying to resolve checkpoint output content ...")
            mags = [os.path.join(input_path, fasta) for fasta in sorted(os.listdir(input_path)) if fasta.endswith(".fa")]
        else:
            mags = list(input)
        if params.db != "":
            # Sorted so that batch membership does not depend on listing order.
            external = [os.path.join(params.db, fasta) for fasta in sorted(os.listdir(params.db)) if fasta.endswith(".fa")]
            genomes = external + mags
        else:
            genomes = mags
        batches = [genomes[x : x + params.batch_size] for x in range(0, len(genomes), params.batch_size)]
        print("Batches", batches)
        for batch_index, batch in enumerate(batches):
            batch_dir = os.path.join(batches_root, "batch_" + str(batch_index))
            destination = os.path.join(batch_dir, "genomes")
            os.makedirs(destination)
            # The context manager closes the summary even if a copy fails.
            with open(os.path.join(batch_dir, "summary.txt"), 'w') as f_summary:
                for genome in batch:
                    shutil.copy(genome, os.path.join(destination, os.path.basename(genome)))
                    f_summary.write(genome + "\n")