#!/usr/bin/env python3
# -*- coding: utf-8 -*-

####################################################################################################
#
#    Declaration
#
#
####################################################################################################
# Assembler name; also the sub-directory under intermediate_results/assembly/single_assembly
# from which the per-sample contigs are read.
assembler = config["Assembly"]["Assembler"]

# Contig file suffix: the filtered contigs are used only when assembly filtering is enabled.
assembly = (
    "post_filtering.contigs.fa"
    if config["Assembly"]["filter_assembly"]
    else "pre_filtering.contigs.fa"
)

strategies = config["binning_strategies"]

# Assembly types needed by the selected binning strategies, in a fixed order:
# SASB/SACB trigger single-sample assemblies, CASB/CACB trigger a co-assembly.
assembly_request = [
    request
    for request, triggers in (
        ("single_assembly", ("SASB", "SACB")),
        ("co_assembly", ("CASB", "CACB")),
    )
    if any(tag in strategies for tag in triggers)
]

def input_cmd(wildcards):
    """Return every read file registered for sample ``wildcards.src``.

    ``reads2use`` is a module-level mapping that is not defined in this section.
    It is indexed as ``reads2use[src][run]``, and each entry is an iterable of
    read paths. The read lists of all runs of the sample are concatenated in
    iteration order. A sample with no runs yields an empty list.
    """
    list_reads = []
    runs = reads2use[wildcards.src]
    for run in runs:
        list_reads.extend(runs[run])
    # Return only after every run has been collected (not from inside the loop).
    return list_reads

# Coverage method(s) handed to `coverm contig --methods` (see rule coverm_genes_abundances).
# The misspelled identifier ("calcultation") is referenced by name; keep it as-is.
coverm_contig_calcultation = (
    config["coverm_contig"]["calculation_options"]
)

####################################################################################################
#
#    final target of genes collection
#
####################################################################################################

if config["genes_collection"]:
    rule target_Gene_catalogue:
        """
            Aggregation target of the gene-catalogue workflow: it requires the gene tables,
            the per-sample geNomad results and the per-database taxonomy tables, then
            touches a checkpoint file in tmpdir.

            NOTE(review): the geNomad paths hard-code "{sample}_post_filtering.contigs_*",
            while genomad_metag_sample reads "{sample}_" + assembly (pre_filtering.contigs.fa
            when filter_assembly is false) — confirm these names match in that configuration.
        """
        input:
            os.path.join(working_dir, "genes_collection/tables/genes_length.tsv"),
            os.path.join(working_dir, "genes_collection/tables/genes_functions.tsv"),
            os.path.join(working_dir, "genes_collection/tables/coverm_genes_abundance"),
            # geNomad result tables, one set per sample (nested lists are flattened by Snakemake).
            [
                expand(
                    os.path.join(
                        working_dir,
                        "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_" + leaf,
                    ),
                    sample=samples,
                )
                for leaf in (
                    "summary/{sample}_post_filtering.contigs_virus_summary.tsv",
                    "aggregated_classification/{sample}_post_filtering.contigs_aggregated_classification.tsv",
                    "annotate/{sample}_post_filtering.contigs_taxonomy.tsv",
                    "find_proviruses/{sample}_post_filtering.contigs_provirus_genes.tsv",
                    "marker_classification/{sample}_post_filtering.contigs_marker_classification.tsv",
                    "nn_classification/{sample}_post_filtering.contigs_nn_classification.tsv",
                )
            ],
            # Taxonomy tables, one pair per configured mmseqs reference database.
            expand(os.path.join(working_dir, "genes_collection/tables/genes_taxo.{db_request}.tsv"), db_request=config["genes_annotation"]["taxo_db"]),
            expand(os.path.join(working_dir, "genes_collection/tables/genes_taxo.{db_request}.classified.tsv"), db_request=config["genes_annotation"]["taxo_db"]),
        output:
            os.path.join(tmpdir, "gene_catalogue.checkpoint"),
        priority: 5
        shell:
            "touch {output}"


    ####################################################################################################
    #
    #    generate tables :
    #        - genes x functions
    #        - genes x taxo
    #        v- genes x length
    #        v- genes x sample , reads count edition
    #        v- genes x sample , bp covered edition
    #        v- genes x sample , relative abundance edition
    #
    ####################################################################################################

    rule coverm_genes_abundances:
        """
            Gene x sample abundance table computed by `coverm contig` over every sample's
            sorted gene-catalogue BAM, using the method(s) from config["coverm_contig"].
        """
        input:
            bams = expand(
                os.path.join(working_dir, "genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam"),
                sample=samples,
            ),
        output:
            os.path.join(working_dir, "genes_collection/tables/coverm_genes_abundance"),
        params:
            method = coverm_contig_calcultation,
        conda:
            os.path.join(CONDAENV, "coverm.yaml")
        shell:
            "coverm contig "
            "--methods {params.method} "
            "--bam-files {input.bams} "
            "-o {output}"

    rule genes_length_table:
        """
            Gene length table with one "ID;length" line per sequence of the clustered nucleotide
            catalogue. It assumes the prodigal header layout ">ID # start # end # strand # ..."
            survives clustering. Prodigal coordinates are 1-based and inclusive, so
            length = end - start + 1.
        """
        output:
            os.path.join(working_dir, "genes_collection/tables/genes_length.tsv")
        input:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.c95.fnn"),
        # Only FASTA header lines are read. '>' and spaces are stripped from the output line.
        # NOTE(review): the separator is ';' even though the file is named .tsv; it is kept
        # unchanged for downstream readers.
        shell:
            "grep '^>' {input} "
            "| awk -F '#' '{{print $1 \";\" ($3 - $2 + 1)}}' "
            "| sed 's/[> ]//g' > {output}"

    rule genes_functions_table:
        """
            Publish the eggNOG-mapper annotation table (from search_orthology) unchanged as the
            gene x function table.
        """
        input:
            os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95.emapper.annotations"),
        output:
            os.path.join(working_dir, "genes_collection/tables/genes_functions.tsv")
        shell:
            "cp "
            "{input} "
            "{output}"


    ####################################################################################################
    #
    #    Taxonomic annotation of proteins using mmseqs
    #
    ####################################################################################################


    rule genes_taxonomy:
        """
            Assign taxonomy to the catalogue proteins with `mmseqs taxonomy` against the
            {db_request} reference and export the result as TSV. A second table is written
            without the rows containing 'no rank'.
        """
        output:
            tsv_all = os.path.join(working_dir, "genes_collection/tables/genes_taxo.{db_request}.tsv"),
            tsv_class = os.path.join(working_dir, "genes_collection/tables/genes_taxo.{db_request}.classified.tsv"),
            # Holds both the mmseqs result DB (params.taxo) and mmseqs' temporary files.
            tmp = temp(directory(os.path.join(tmpdir, "genes_collection/mmseq_taxo/{db_request}"))),
        input:
            prot = os.path.join(tmpdir, "genes_collection/mmseq_prot/DBprot"),
            ref = os.path.join(db_dir, "mmseqs_db/{db_request}/{db_request}"),
            dir_temp = directory(os.path.join(tmpdir, "genes_collection/mmseq_prot"))
        conda:
            os.path.join(CONDAENV, "linclust.yaml")
        threads: 10
        params:
            taxo = os.path.join(tmpdir, "genes_collection/mmseq_taxo/{db_request}/genes_taxo.{db_request}"),
        log:
            # Distinct from mmseqs_reference's create_{db_request}.log so neither job overwrites the other.
            "logs/genes_collection/taxonomy/assign_{db_request}.log"
        shell:
            "mkdir -p {output.tmp} && "
            # The search step is the expensive one: honour the declared thread count.
            "mmseqs taxonomy {input.prot} {input.ref} {params.taxo} {output.tmp} --threads {threads} > {log} 2>&1 && "
            "mmseqs createtsv {input.prot} {params.taxo} {output.tsv_all} --threads {threads} >> {log} 2>&1 && "
            "grep -v 'no rank' {output.tsv_all} > {output.tsv_class} "


    rule mmseqs_reference:
        """
            Download the {db_request} sequence database with `mmseqs databases`, then add
            taxonomic information to it with `mmseqs createtaxdb`. Both steps use the declared
            thread count and write their output to the log.
        """
        output:
            ref = os.path.join(db_dir, "mmseqs_db/{db_request}/{db_request}"),
            tmp = temp(directory(os.path.join(db_dir, "mmseqs_db/{db_request}/tmp"))),
        conda:
            os.path.join(CONDAENV, "linclust.yaml")
        threads:
            10
        log:
            "logs/genes_collection/taxonomy/create_{db_request}.log"
        shell:
            "mmseqs databases {wildcards.db_request} {output.ref} {output.tmp} --threads {threads} > {log} 2>&1 && "
            "mmseqs createtaxdb {output.ref} {output.tmp} --threads {threads} >> {log} 2>&1"

    rule format_mmseqs_input:
        """
            Convert the translated gene catalogue into an mmseqs amino-acid database
            (`--dbtype 1`) inside a temporary directory. createdb output goes to the log.
        """
        output:
            protdb = temp(os.path.join(tmpdir, "genes_collection/mmseq_prot/DBprot")),
            dir_temp = temp(directory(os.path.join(tmpdir, "genes_collection/mmseq_prot"))),
        input:
            os.path.join(working_dir, "genes_collection/all_prot.aa.c95.faa"),
        conda:
            os.path.join(CONDAENV, "linclust.yaml")
        log:
            "logs/genes_collection/taxonomy/format_input.log"
        shell:
            "mkdir -p {output.dir_temp} && "
            "mmseqs createdb {input[0]} {output.protdb} --dbtype 1 > {log} 2>&1"

    rule translate_95:
        """
            Translate the clustered nucleotide catalogue into protein sequences with
            `seqkit translate`. The taxonomy step works on protein databases, while
            clustering, functional annotation (emapper --translate) and read mapping all use
            the nucleotide file.
        """
        input:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.c95.fnn"),
        output:
            os.path.join(working_dir, "genes_collection/all_prot.aa.c95.faa"),
        conda:
            os.path.join(CONDAENV, "seqkit.yaml")
        shell:
            "seqkit translate "
            "{input} "
            "-w 0 "               # 0 = do not wrap sequence lines
            "-o {output}"


    ####################################################################################################
    #
    #    Functional annotation of proteins using eggNOG-mapper
    #
    ####################################################################################################


    rule search_orthology:
        """
            Second eggNOG-mapper pass: annotate the seed orthologs found by search_homology
            (`--annotate_hits_table`). emapper writes <prefix>.emapper.annotations, which is
            the declared output.
        """
        input:
            os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95.emapper.seed_orthologs"),
            os.path.join(db_dir, "eggNOG/"),
        output:
            os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95.emapper.annotations"),
        params:
            prefix = os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95")
        threads: 10
        conda:
            os.path.join(CONDAENV, "eggnog.yaml")
        shell:
            "emapper.py "
            "--annotate_hits_table {input[0]} "
            "--no_file_comments "
            "-o {params.prefix} "
            "--cpu {threads}  "
            "--data_dir {input[1]} "
            "--override "
            "--dbmem "            # load the eggNOG annotation database into memory

    rule search_homology:
        """
            First eggNOG-mapper pass: DIAMOND search of the clustered CDS catalogue
            (translated by emapper itself, `--itype CDS --translate`) without annotation
            (`--no_annot`). emapper writes <prefix>.emapper.seed_orthologs.
        """
        input:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.c95.fnn"),
            os.path.join(db_dir, "eggNOG/"),
        output:
            os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95.emapper.seed_orthologs"),
        params:
            prefix = os.path.join(working_dir, "genes_collection/annotations/functions/all_prot.nucleotide.c95")
        threads: 18
        conda:
            os.path.join(CONDAENV, "eggnog.yaml")
        # No log directive: emapper output goes to the job's stdout/stderr.
        shell:
            "emapper.py "
            "-m diamond "
            "--itype CDS "
            "--translate "
            "--dmnd_ignore_warnings "
            "--no_annot "
            "--no_file_comments "
            "--cpu {threads} "
            "-i {input[0]} "
            "-o {params.prefix} "
            "--data_dir {input[1]} "
            "--override "

    ####################################################################################################
    #
    #    Read back-mapping of each sample against the clustered gene catalogue
    #
    ####################################################################################################

    rule gene_sort_and_index:
        """
            Convert one sample's gene-catalogue SAM into a coordinate-sorted BAM and index it.
            The stderr of every samtools step (view, sort, index) is captured in the log.
        """
        output:
            os.path.join(working_dir, "genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam"),
            os.path.join(working_dir, "genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam.bai"),
        input:
            os.path.join(tmpdir, "single_assembly/{sample}##genes.sam"),
        log:
            "logs/genes_collection/mapping/{sample}_sorting_indexing.log"
        threads: 10
        # Single priority value, above gene_bowtie2_map (80), so the temporary SAM files are
        # consumed early. NOTE(review): the rule used to declare both 100 and 5; confirm 100.
        priority: 100
        conda:
            os.path.join(CONDAENV, "samtools.yaml")
        shell:
            "( samtools view -u {input} | "
            "samtools sort "
            "-@ {threads} "             # number of threads used
            "-o {output[0]} "
            "&& samtools index "
            "{output[0]} ) "
            "&> {log} "

    rule gene_bowtie2_map:
        """
            Map one sample's reads onto the clustered gene catalogue with bowtie2. Unaligned
            reads are dropped and the SAM is kept only until it has been sorted.
        """
        input:
            index1 = expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.{id}.bt2l"), id=range(1, 4)),
            index2 = expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.rev.{id}.bt2l"), id=range(1, 2)),
            # All read files of every run of the sample (see input_cmd).
            reads = input_cmd,
        output:
            temp(os.path.join(tmpdir, "single_assembly/{src}##genes.sam"))
        params:
            prefix = os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95"),
            # Read arguments (e.g. -1/-2/-U) built by the project's cmdparser helper.
            cmd = lambda wildcards, input: cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
        log:
            "logs/genes_collection/mapping/single_assembly_{src}##genes.log"
        threads: 10
        priority: 80
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        shell:
            "bowtie2 -p {threads} --no-unal "
            "-x {params.prefix} {params.cmd} "
            "-S {output} &> {log} "

    rule gene_bowtie2_index:
        """
            Build a large (.bt2l) Bowtie2 index of the clustered gene catalogue
            (all_prot.nucleotide.c95.fnn). All six index files bowtie2-build writes
            (.1-.4 and .rev.1-.rev.2) are declared, so Snakemake detects an incomplete index.
        """
        output:
            expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.{id}.bt2l"), id=range(1, 5)),
            expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.rev.{id}.bt2l"), id=range(1, 3)),
        input:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.c95.fnn"),
        log:
            "logs/genes_collection/mapping/index/all_prot.nucleotide.c95.indexing.log"
        threads: 10
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        params:
            prefix = os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95")
        priority:5
        shell:
            "bowtie2-build "
            "--large-index "
            "--threads {threads} "          # Number of parallel threads
            "{input} "                      # Clustered gene catalogue (nucleotide FASTA)
            "{params.prefix} "              # Basename of the index files
            "&> {log}"

    ####################################################################################################
    #
    #    Concatenate the predicted gene sequences of all samples and cluster them at 95% identity
    #    with linclust (MMseqs2)
    #
    ####################################################################################################

    rule gene_95:
        """
            Cluster all predicted gene sequences with `mmseqs linclust` at the configured
            minimum identity, then write the cluster representatives as FASTA. Intermediate
            mmseqs databases live in a temporary directory.
        """
        input:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.fnn"),
        output:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.c95.fnn"),
            temp(directory(os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn"))),
        params:
            seqid = config["linclust"]["seqid"],
            tmp = os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn"),
            tmpDB = os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn/DB"),
            tmpDB_clust = os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn/DB_clust"),
            tmpDB_rep = os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn/DB_rep"),
            tmpfolder = os.path.join(tmpdir, "genes_collection/linclust.nucleotide.fnn/temp"),
        threads: 10
        conda:
            os.path.join(CONDAENV, "linclust.yaml")
        shell:
            "mkdir -p {params.tmp} && "
            "mmseqs createdb {input} {params.tmpDB} && "
            "mmseqs linclust {params.tmpDB} {params.tmpDB_clust} {params.tmpfolder} "
            "--min-seq-id {params.seqid} --threads {threads} && "
            "mmseqs createsubdb {params.tmpDB_clust} {params.tmpDB} {params.tmpDB_rep} && "
            "mmseqs convert2fasta {params.tmpDB_rep} {output[0]} "


    rule concat_genes:
        """
            Concatenate the per-sample prodigal outputs into one nucleotide file (.fnn) and one
            protein file (.faa). The outputs are truncated (`>`) rather than appended to, so a
            leftover output can never duplicate the catalogue.
        """
        output:
            os.path.join(working_dir, "genes_collection/all_prot.nucleotide.fnn"),
            os.path.join(working_dir, "genes_collection/all_prot.translation.faa"),
        input:
            fnn = expand(os.path.join(working_dir, "intermediate_results/genes/{sample}/{sample}.prot.nucleotide.fnn"), sample=samples),
            faa = expand(os.path.join(working_dir, "intermediate_results/genes/{sample}/{sample}.prot.translation.faa"), sample=samples),
        shell:
            "cat {input.fnn} > {output[0]} && "
            "cat {input.faa} > {output[1]} "


    ####################################################################################################
    #
    #    Search CDS per sample using prodigal
    #
    ####################################################################################################

    rule fetch_CDS:
        """
            Predict CDS on one sample's single-assembly contigs with prodigal in metagenome
            mode. It writes the GFF coordinates, nucleotide gene sequences (-d) and protein
            translations (-a).
        """
        input:
            os.path.join(working_dir, "intermediate_results/assembly/single_assembly", assembler, "{sample}/contigs/{sample}_" + assembly),
        output:
            coordinates = os.path.join(working_dir, "intermediate_results/genes/{sample}/{sample}.coord.gff"),
            nucleo = os.path.join(working_dir, "intermediate_results/genes/{sample}/{sample}.prot.nucleotide.fnn"),
            transl = os.path.join(working_dir, "intermediate_results/genes/{sample}/{sample}.prot.translation.faa"),
        log:
            os.path.join(working_dir, "logs/genes_collection/{sample}/{sample}.CDS.log")
        conda:
            os.path.join(CONDAENV, "prodigal.yaml")
        # NOTE(review): "-p meta" was flagged as "not consistent with the doc"; confirm the
        # intended prodigal mode.
        shell:
            'prodigal -i {input} '
            '-o {output.coordinates} -f gff '
            '-d {output.nucleo} '
            '-a {output.transl} '
            '-p meta  &> {log} '

    rule genomad_metag_sample:
        """
            Run `genomad end-to-end --cleanup` on one sample's single-assembly contigs and
            write the results under genes_collection/annotations/genomad/{sample}. The thread
            count is declared so the scheduler reserves the cores geNomad actually uses.

            NOTE(review): the declared outputs hard-code "{sample}_post_filtering.contigs_*",
            whereas the input is "{sample}_" + assembly (pre_filtering.contigs.fa when
            config["Assembly"]["filter_assembly"] is false). If geNomad names its outputs after
            the input file, the paths will not match in that case — confirm.
        """
        output:
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_aggregated_classification/{sample}_post_filtering.contigs_aggregated_classification.tsv"),
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_annotate/{sample}_post_filtering.contigs_taxonomy.tsv"),
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_find_proviruses/{sample}_post_filtering.contigs_provirus_genes.tsv"),
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_marker_classification/{sample}_post_filtering.contigs_marker_classification.tsv"),
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_nn_classification/{sample}_post_filtering.contigs_nn_classification.tsv"),
            os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}/{sample}_post_filtering.contigs_summary/{sample}_post_filtering.contigs_virus_summary.tsv"),
        input:
            os.path.join(working_dir, "intermediate_results/assembly/single_assembly", assembler, "{sample}/contigs/{sample}_" + assembly),
            # A file inside the database directory; the directory itself is passed via params.db.
            os.path.join(db_dir, "genomad_db/genomad_db"),
        params:
            db = os.path.join(db_dir, "genomad_db"),
            output_dir = os.path.join(working_dir, "genes_collection/annotations/genomad/{sample}"),
        threads: 10
        conda:
            os.path.join(CONDAENV, "genomad.yaml")
        shell:
            'mkdir -p {params.output_dir} && '
            'genomad end-to-end --cleanup --threads {threads} {input[0]} {params.output_dir} {params.db} '

