#!/usr/bin/env python3
# -*- coding: utf-8 -*-

####################################################################################################
#
#    metatranscriptomic reads mapping on contigs and/or MAGs
#
####################################################################################################

def input_cmd(wildcards):
    """Return the metaT read input(s) for mapping ``wildcards.src``.

    - ``single_assembly``: the read files of every run registered for
      ``wildcards.src`` in ``metaTreads2use``, concatenated in run order.
    - ``co_assembly``: the path of a sample-list file. This is
      ``<tmpdir>/samples.txt`` when ``simka_type == "None"``. Otherwise it is
      the per-cluster file ``assembly/co_assembly/clusters/<src>.txt`` under
      ``intermediate_results_dir``.

    Raises:
        ValueError: if ``wildcards.assembly`` is neither ``single_assembly``
            nor ``co_assembly``.
    """
    if wildcards.assembly == "single_assembly":
        runs = metaTreads2use[wildcards.src]
        # Flatten {run: [read files]} into one list, preserving run order.
        return [read for run in runs for read in runs[run]]
    if wildcards.assembly == "co_assembly":
        if simka_type == "None":
            return os.path.join(tmpdir, "samples.txt")
        # Fill {src} explicitly instead of relying on Snakemake to expand
        # wildcards inside input-function return values.
        return os.path.join(
            intermediate_results_dir,
            "assembly/co_assembly/clusters",
            f"{wildcards.src}.txt",
        )
    raise ValueError(
        f"Unknown assembly type {wildcards.assembly!r}; "
        "expected 'single_assembly' or 'co_assembly'"
    )

    ####################################################################################################
    #
    #    metaT reads mapping against all genomes/MAGs
    #
    ####################################################################################################

def metaT_resolve_final_target(wildcards):
    """List the final metaT abundance tables required by metaT_mapping_target.

    One coverm genome-abundance table is requested per entry of
    ``dRep_levels``. The coverm gene-abundance table is appended only when
    ``config["genes_collection"]`` is truthy.
    """
    genome_tables = expand(
        os.path.join(working_dir, "metaT/tables/coverm_genomes_abundance.{level}_metaT"),
        level=dRep_levels,
    )
    if not config["genes_collection"]:
        return list(genome_tables)
    gene_table = os.path.join(working_dir, "metaT/tables/coverm_genes_abundance_metaT")
    return [*genome_tables, gene_table]

rule metaT_mapping_target:
    """
    Touch a flag file once every final metaT abundance table returned by
    metaT_resolve_final_target exists.
    """
    output:
        os.path.join(tmpdir, "finished_metaT_mapping"),
    input:
        # Input function: genome tables per dRep level (+ gene table if enabled)
        metaT_resolve_final_target,
    shell:
        "touch {output}"

def metaT_expand_bam_coverm_genome(wildcards):
    """Collect the filtered, sorted metaT BAMs for every requested assembly type.

    Single-assembly BAMs come first, then co-assembly BAMs. Each keeps the
    ``{level}`` wildcard (escaped as ``{{level}}`` for ``expand``). For
    co-assemblies, the cluster ids are read from the ``cluster_simka``
    checkpoint output. As a side effect, they are stored in
    ``assembly_dict["co_assembly"]``.
    """
    bam_pattern = os.path.join(
        "metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}",
        "{src}##genomes_{{level}}.filtered.sorted.bam",
    )
    collected = []
    for request in ("single_assembly", "co_assembly"):
        if request not in assembly_request:
            continue
        if request == "co_assembly":
            # Cluster ids are only known once the simka checkpoint has run.
            simka_dir = checkpoints.cluster_simka.get(**wildcards).output[0]
            assembly_dict[request] = glob_wildcards(
                os.path.join(simka_dir, "{clusterid}.txt")
            ).clusterid
        collected.extend(
            expand(bam_pattern, src=assembly_dict.get(request), assembly=request)
        )
    return collected

rule metaT_coverm_genomes_abundances:
    """
    Compute per-genome metaT abundances for one dereplication level with
    `coverm genome`. It uses every filtered, sorted BAM returned by
    metaT_expand_bam_coverm_genome, and reads the genomes as *.fa files from
    genomes_collection/dereplicated_at_{level}.
    """
    output:
        os.path.join(working_dir, "metaT/tables/coverm_genomes_abundance.{level}_metaT"),
    input:
        metaT_expand_bam_coverm_genome,
    params:
        # NOTE(review): directory() is normally an input/output flag; here it only
        # wraps a plain (workdir-relative) path string passed to coverm.
        genome_directory = directory("genomes_collection/dereplicated_at_{level}"),
        # coverm --methods value, defined elsewhere in the workflow
        method = coverm_genome_calcultation,
    conda:
        os.path.join(CONDAENV, "coverm.yaml")
    # NOTE(review): no threads/log directives and coverm gets no -t option,
    # so it runs with Snakemake's default single thread and logs to the console.
    shell:
        "coverm genome --methods {params.method} --bam-files {input} --genome-fasta-directory {params.genome_directory} --genome-fasta-extension fa -o {output}"

rule metaT_sample2genomes_sort_and_index_post_filtering:
    """
    Post-process the quality-filtered BAM with samtools:
        - sort BAM file by position
        - index BAM file for later easy read extraction

    The commands run in a subshell so that the log captures stderr from
    samtools view, sort and index, not just the final index call.
    """
    output:
        bam = os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.filtered.sorted.bam"),
        bai = os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.filtered.sorted.bam.bai"),
    input:
        os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "filtered_{src}##genomes_{level}.sorted.bam"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    log:
        os.path.join("logs", "metaT", "genomes_collection", "mapping", "{assembly}", "sorting_{src}_{level}.filtered.log"),
    shell:
        "(samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output.bam} "
        "&& samtools index "
        "{output.bam}) &> {log}"    # redirect applies to the whole group

rule metaT_quality_filter_reads:
    """
    Filter mapped metaT reads with ../scripts/bamprocess.py. The thresholds
    come from config["sample2genomes"]: min_quality, min_identity, min_len
    and properly_paired.
    The result is re-sorted and indexed by
    metaT_sample2genomes_sort_and_index_post_filtering.
    NOTE(review): this output was described as temporary, but it is not
    wrapped in temp(), so Snakemake keeps it on disk.
    """
    output:
        os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "filtered_{src}##genomes_{level}.sorted.bam"),
    input:
        os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam"),
    log:
        os.path.join("logs", "metaT", "genomes_collection", "mapping", "{assembly}", "{src}_{level}.filtering.log")
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    params:
        # Filtering thresholds shared through the config "sample2genomes" section
        min_mapq = config["sample2genomes"]["min_quality"],
        min_idt = config["sample2genomes"]["min_identity"],
        min_len = config["sample2genomes"]["min_len"],
        pp = config["sample2genomes"]["properly_paired"],
    script:
        "../scripts/bamprocess.py"

rule metaT_sample2genomes_sort_and_index:
    """Convert the raw bowtie2 SAM to BAM with samtools:
        - sort BAM file by position
        - index BAM file for later easy read extraction

    The commands run in a subshell so that the log captures stderr from
    samtools view, sort and index, not just the final index call.
    """
    output:
        bam = os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam"),
        bai = os.path.join("metaT", "genomes_collection", "mapping", "{assembly}", "bam", "{src}", "{src}##genomes_{level}.sorted.bam.bai"),
    input:
        os.path.join(tmp, "metaT", "sample2genomes", "{assembly}", "{src}##genomes_{level}.sam"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    log:
        os.path.join("logs", "metaT", "genomes_collection", "mapping", "{assembly}", "sorting_{src}_{level}.log"),
    shell:
        "(samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output.bam} "
        "&& samtools index "
        "{output.bam}) &> {log}"    # redirect applies to the whole group


rule metaT_sample2genomes_mapping:
    """
    Map all metaT reads of `{src}` (every run returned by input_cmd) with
    bowtie2 against the genome-collection index of dereplication level
    `{level}`.
    The SAM output is temp(): it is converted to a sorted, indexed BAM by
    metaT_sample2genomes_sort_and_index and removed afterwards.
    """

    output:
        temp(os.path.join(tmp, "metaT", "sample2genomes", "{assembly}", "{src}##genomes_{level}.sam"))
    input:
        # Only the forward large-index files .1-.4.bt2l are declared as inputs.
        index = expand(os.path.join("genomes_collection", "mapping", "index{{level}}", "all_genomes.{id}.bt2l"), id=range(1, 5) ),
        reads = input_cmd,
        # QC of metaT reads must have completed before mapping
        qc_metaT = os.path.join(tmpdir, "QC_metaT.checkpoint")
    log:
        os.path.join("logs", "metaT", "genomes_collection", "mapping", "{assembly}", "{src}##{level}.log")
    threads: 14
    priority: 80
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    params:
        # Strip ".1.bt2l" (7 characters) from the first index file to get the bowtie2 -x basename
        dbbasename = lambda wildcards, input: input[0][:-7],
        # Read arguments for bowtie2, built by cmdparser from the reads input
        cmd = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, metaTreads2use, "bowtie2").cmd # wildcards ==  sample in this case
    shell:
        "bowtie2 "
        "-p {threads} "             # number of parallel threads
        "--no-unal "                # remove unmapped reads (decrease size)
        "-x {params.dbbasename} "   # index for mapping
        "{params.cmd} "
        "-S {output} "
        "2> {log} "



    ####################################################################################################
    #
    #    metaT reads mapping against all genes
    #
    ####################################################################################################

def input_cmd_genes(wildcards):
    """Return every metaT read file registered for ``wildcards.src``.

    The runs in ``metaTreads2use[wildcards.src]`` are visited in order, and
    their read lists are concatenated into a single list. An empty run
    mapping yields ``[]``, not ``None``.
    """
    runs = metaTreads2use[wildcards.src]
    list_reads = []
    for run in runs:
        list_reads.extend(runs[run])
    # Return only after every run is collected: returning inside the loop
    # silently dropped all runs except the first one.
    return list_reads

if config["genes_collection"]:
    # The gene-catalogue metaT rules are only defined when genes_collection is enabled.
    rule metaT_coverm_genes_abundances:
        """
        Compute per-gene metaT abundances with `coverm contig` over the sorted
        gene-catalogue BAM of every sample in `samples`.
        """
        output:
            os.path.join(working_dir, "metaT/tables/coverm_genes_abundance_metaT"),
        input:
            # NOTE(review): iterates `samples`, whereas the mapping rule looks up
            # metaTreads2use[{src}]; a sample without metaT reads would raise
            # KeyError there. Assumes all samples have metaT reads — confirm.
            bams = expand(os.path.join(working_dir, "metaT/genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam"), sample=samples),
        params:
            # coverm --methods value, defined elsewhere in the workflow
            method = coverm_contig_calcultation,
        conda:
            os.path.join(CONDAENV, "coverm.yaml")
        shell:
            "coverm contig --methods {params.method} --bam-files {input.bams} -o {output}"

    rule metaT_sort_and_index:
        """
        Convert the gene-catalogue SAM of `{sample}` into a position-sorted,
        indexed BAM with samtools.
        The commands run in a subshell so that the log captures stderr from
        samtools view, sort and index, not just the final index call.
        """
        output:
            os.path.join(working_dir, "metaT/genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam"),
            os.path.join(working_dir, "metaT/genes_collection/mapping/bams/{sample}/{sample}##genes.sorted.bam.bai"),
        input:
            os.path.join(tmpdir, "metaT_tmp/genes_collection/{sample}##genes.sam"),
        log:
            "logs/metaT/genes_collection/mapping/metaT_{sample}_sorting_indexing.log"
        threads: 10
        # Single priority directive (a conflicting duplicate `priority:5` was dropped).
        priority: 100
        conda:
            os.path.join(CONDAENV, "samtools.yaml")
        shell:
            "(samtools view -u {input} | "
            "samtools sort "
            "-@ {threads} "             # number of threads used
            "-o {output[0]} "
            "&& samtools index "
            "{output[0]}) "
            "&> {log} "                 # redirect applies to the whole group

    rule metaT_bowtie2_map:
        """
            Map all metaT reads of `{src}` (every run returned by input_cmd_genes)
            with bowtie2 against the nucleotide gene-catalogue index
            all_prot.nucleotide.c95. The SAM output is temp() and is consumed by
            metaT_sort_and_index.
        """
        output:
            temp(os.path.join(tmpdir, "metaT_tmp/genes_collection/{src}##genes.sam"))
        input:
            # NOTE(review): a bowtie2 large index consists of .1-.4.bt2l and .rev.1-.rev.2.bt2l,
            # but only .1-.3 and .rev.1 are declared here (the genome index above declares .1-.4).
            # Widen these ranges if the index-building rule declares all six files.
            index1 = expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.{id}.bt2l"), id=range(1, 4)),
            index2 = expand(os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95.rev.{id}.bt2l"), id=range(1, 2)),
            reads = input_cmd_genes,
            # QC of metaT reads must have completed before mapping
            qc_metaT = os.path.join(tmpdir, "QC_metaT.checkpoint")
        log:
            "logs/metaT/genes_collection/mapping/metaT_{src}##genes.log"
        threads: 10
        priority: 80
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        params:
            # bowtie2 -x basename of the gene-catalogue index
            prefix = os.path.join(working_dir, "genes_collection/mapping/index/all_prot.nucleotide.c95"),
            # Read arguments for bowtie2, built by cmdparser from the reads input
            cmd = lambda wildcards,input : cmdparser.cmd(wildcards.src, input.reads, metaTreads2use, "bowtie2").cmd,
        shell:
            "bowtie2 "
            "-p {threads} "             # number of parallel threads
            "--no-unal "                # remove unmapped reads (decrease size)
            "-x {params.prefix} "       # index for mapping
            "{params.cmd} "
            "-S {output} "
            "&> {log} "