#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys, os
import subprocess
def _magneto_location():
    """
    Return the installation directory reported by `pip show magneto`.

    pip is run through the interpreter executing this workflow
    (``sys.executable -m pip``), so the answer refers to the same Python
    environment that will import the modules found there. The "Location:"
    field is parsed in Python rather than via a `grep` pipe; splitting on the
    first ":" keeps paths that contain spaces intact.

    Raises:
        subprocess.CalledProcessError: pip exits with an error (for example
            when magneto is not installed).
        RuntimeError: pip's output has no "Location:" field.
    """
    pip_show = subprocess.run(
        (sys.executable, "-m", "pip", "show", "magneto"),
        stdout=subprocess.PIPE,
        check=True,
        encoding="utf-8",
    )
    for line in pip_show.stdout.splitlines():
        field, sep, value = line.partition(":")
        if sep and field.strip() == "Location":
            return value.strip()
    raise RuntimeError("`pip show magneto` did not report a Location field")


location = _magneto_location()
# Put magneto/scripts first on sys.path so the modules shipped there can be imported.
scripts_dir = os.path.join(location, 'magneto/scripts')

sys.path.insert(0, scripts_dir)

import yaml

from inspect import getmembers, isfunction

import parserconfig as conf

#############
# Set paths #
#############

env_dir = "../envs"  # also defined in the main Snakefile; repeated here for readability
working_dir = config["project"]
tmpdir = os.path.join(working_dir, "tmp")  # also defined in the main Snakefile; repeated here for readability
logdir = os.path.join(working_dir, "logs")

submodule = config["target"]
# Diagnostics go to stderr: anything a Snakefile prints on stdout is mixed into
# output Snakemake itself writes there (e.g. `snakemake --dag | dot`).
print("submodule", submodule, file=sys.stderr)

############
# QC metaG #
############
############
# QC metaT #
############

def _split_runs_by_layout(sample_runs):
    """
    Split a ``{sample: {run: files}}`` mapping by sequencing layout.

    Runs with exactly two files are treated as paired-end, runs with exactly
    one file as single-end; runs with any other number of files end up in
    neither mapping.

    Returns:
        ``(PE_samples, SE_samples, all_paths)`` where the first two have the
        same ``{sample: {run: files}}`` shape and ``all_paths`` lists every
        file of every run in iteration order.
    """
    paired, single, all_paths = {}, {}, []
    for sample_name in sample_runs:
        for run_id in sample_runs[sample_name]:
            run_files = sample_runs[sample_name][run_id]
            all_paths.extend(run_files)
            if len(run_files) == 2:
                paired.setdefault(sample_name, {})[run_id] = run_files
            elif len(run_files) == 1:
                single.setdefault(sample_name, {})[run_id] = run_files
    return paired, single, all_paths


def _first_data_dir(paths):
    """
    Return the part before the last '/' of the first path in ``paths``.

    Raises:
        ValueError: ``paths`` is empty (no input read files were configured).
    """
    if not paths:
        raise ValueError("No input read files found in the configuration; nothing to run QC on.")
    return paths[0].rsplit('/', 1)[0]


# metaT results live under <project>/metaT, everything else directly under <project>.
qc_root = os.path.join(working_dir, "metaT") if submodule == "metaT" else working_dir

# all pre-processing reports
reports_dir = os.path.join(qc_root, "reports/pre_processing")
# results from the different programs
intermediate_results_dir_qc = os.path.join(qc_root, "intermediate_results")

# html report from multiqc
multiqc_reports_dir = os.path.join(reports_dir, "multiqc_data")
# reports from fastP
fastp_reports_dir = os.path.join(reports_dir, "fastP")
# reports from FQscreen
fatsqcreenreports_dir = os.path.join(reports_dir, "FQscreen")
# filtered reads obtained with QC
reads_PE_dir = os.path.join(intermediate_results_dir_qc, "reads/PE")
reads_SE_dir = os.path.join(intermediate_results_dir_qc, "reads/SE")

# NOTE: the target rules use input *functions* (lambdas). Snakemake evaluates
# them when building the DAG, i.e. after the whole Snakefile has been parsed,
# so get_PE_output/get_SE_output may be defined later in this file.
if submodule == "metaT":
    metaTreads2use, metaTsamples, deconta = conf.parse_metaT(config)
    print("metaTsamples", metaTsamples, file=sys.stderr)
    print("metaTreads2use", metaTreads2use, file=sys.stderr)
    PE_samples, SE_samples, list_path = _split_runs_by_layout(metaTsamples)

    localrules: end_qc_metaT

    rule end_qc_metaT:
        """
        Final output of this module (metaT).
        """
        output:
            os.path.join(tmpdir, "QC_metaT.checkpoint")
        input:
            PE = lambda wildcards: get_PE_output(PE_samples),
            SE = lambda wildcards: get_SE_output(SE_samples),
            checkpoint = os.path.join(multiqc_reports_dir, "multiqc.checkpoint"),
        priority: 5
        shell:
            "touch {output}"

else:
    reads2use, samples, deconta = conf.parse(config)
    print("samples", samples, file=sys.stderr)
    PE_samples, SE_samples, list_path = _split_runs_by_layout(samples)
    print("SE_samples", SE_samples, file=sys.stderr)

    localrules: target_QC

    rule target_QC:
        """
        Final output of this module.
        """
        output:
            os.path.join(tmpdir, "QC.checkpoint")
        input:
            PE = lambda wildcards: get_PE_output(PE_samples),
            SE = lambda wildcards: get_SE_output(SE_samples),
            checkpoint = os.path.join(multiqc_reports_dir, "multiqc.checkpoint"),
        priority: 5
        shell:
            "touch {output}"

# Keys of the contamination-screening dictionary returned by the parser.
list_deconta = list(deconta)
# Directory part of the first configured read file.
path_to_data = _first_data_dir(list_path)

def get_PE_output(dico_samples):
    """
    List the filtered paired-end FASTQ files expected for every run.

    Args:
        dico_samples: ``{sample: {run: ...}}``; only the sample and run keys
            are used.

    Returns:
        Three paths per run (R1, R2, then unpaired), ordered by sample, then
        run: ``reads_PE_dir/<sample>/<sample>_<run>_<mate>.filtered.fastq.gz``.
    """
    return [
        os.path.join(reads_PE_dir,
            f'{sample}/{sample}_{run}_{mate}.filtered.fastq.gz')
        for sample in dico_samples
        for run in dico_samples[sample]
        for mate in ('R1', 'R2', 'unpaired')
    ]

def get_SE_output(dico_samples):
    """
    List the filtered single-end FASTQ files expected for every run.

    Args:
        dico_samples: ``{sample: {run: ...}}``; only the sample and run keys
            are used.

    Returns:
        One path per run, ordered by sample, then run:
        ``reads_SE_dir/<sample>/<sample>_<run>_SE.filtered.fastq.gz``.
    """
    return [
        os.path.join(reads_SE_dir,
            f'{sample}/{sample}_{run}_SE.filtered.fastq.gz')
        for sample in dico_samples
        for run in dico_samples[sample]
    ]

####################################################################################################
#
#    merging paired-end reads if required --> this part is currently not working
#
####################################################################################################


#rule target_merging:
#    '''
#    EXPLAIN !
#    '''
#    output:
#        temp(os.path.join(tmpdir, "merging.checkpoint"))
#    input:
#        expand(os.path.join(reads_PE_dir, "{PE}/{PE}_{run}.merged.fastq.gz"),
#        PE = PE_samples,
#        run = PE_samples[wildcards.PE]),
#        expand(os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R1.unmerged.fastq.gz"),
#        PE = PE_samples,
#        run = PE_samples[wildcards.PE]),
#        expand(os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R2.unmerged.fastq.gz"),
#        PE = PE_samples,
#        run = PE_samples[wildcards.PE]),
#    priority: 5
#    shell:
#        "touch {output}"


def merging_input(wildcards):
    """
    Input function of the `merge` rule: the R1 and R2 files of run
    ``wildcards.run`` of paired-end sample ``wildcards.PE``.

    When ``config["skip_QC"]`` is truthy, the first two raw files listed in
    PE_samples for that run are returned. Otherwise the filtered R1/R2 paths
    under reads_PE_dir are returned as unexpanded ``{PE}``/``{run}`` patterns.
    """
    if not config["skip_QC"]:
        return [
            os.path.join(reads_PE_dir, f"{{PE}}/{{PE}}_{{run}}_{mate}.filtered.fastq.gz")
            for mate in ("R1", "R2")
        ]
    raw_pair = PE_samples[wildcards.PE][wildcards.run]
    return [raw_pair[0], raw_pair[1]]


rule merge:
    """
    Merge overlapping read pairs of one paired-end run with BBMerge (bbmerge.sh).
    Input comes from merging_input: raw reads when config["skip_QC"] is set,
    filtered reads otherwise. Pairs that cannot be merged are written to the
    unmerged R1/R2 outputs; bbmerge's console output is captured in the log.
    """
    input:
        merging_input,
    output:
        # "{run}.merged" (no stray "_" before ".merged"), in line with the
        # "{run}_R1.unmerged" / "{run}_R2.unmerged" names below.
        merged = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}.merged.fastq.gz"),
        unmerged_R1 = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R1.unmerged.fastq.gz"),
        unmerged_R2 = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R2.unmerged.fastq.gz"),
    params:
        minoverlap = config["minoverlap"]
    conda:
        os.path.join(env_dir, "bbtools.yaml")
    log:
        os.path.join(logdir, "merging/{PE}_{run}_merge.log"),
    priority: 5
    shell:
        "bbmerge.sh in1={input[0]} "
        "in2={input[1]} "
        "out={output.merged} "
        "outu1={output.unmerged_R1} "
        "outu2={output.unmerged_R2} "
        "minoverlap={params.minoverlap} "
        "&> {log}"


################################################################################
#
#    rules to compute QC and contamination screening against fastq files
#    target definition
#
################################################################################


################################################################################
#
#    Multiqc report
#
################################################################################


rule target_multiQC:
    """
    Checkpoint file signalling that the MultiQC HTML report exists.
    """
    input:
        os.path.join(multiqc_reports_dir, "multiqc_report.html"),
    output:
        os.path.join(multiqc_reports_dir, "multiqc.checkpoint")
    priority: 0
    shell:
        "touch {output}"

rule multiqc:
    """
    Aggregate the pre-processing reports found under reports_dir (fastP and
    FQscreen sub-directories) into a single MultiQC HTML report.
    The filtered reads are declared as input only so that the report is built
    after every QC job has finished.
    The analysed directory is reports_dir itself (under the project directory,
    and under metaT/ for the metaT submodule). A hard-coded relative
    "reports/pre_processing/" would only work when run from the project directory.
    """
    input:
        PE = get_PE_output(PE_samples),
        SE = get_SE_output(SE_samples),
    output:
        report = os.path.join(multiqc_reports_dir, "multiqc_report.html"),
    params:
        reports = reports_dir,
        out = multiqc_reports_dir,
    conda:
        os.path.join(env_dir, "multiqc.yaml"),
    priority: 0
    threads : 10
    shell:
        "multiqc -f {params.reports} "  # -f: overwrite any previous report
        "-o {params.out} "

################################################################################
#
#    file management
#
################################################################################

rule rename_SE:
    """
    Move the FastQ Screen-filtered reads of a single-end run out of tmpdir into
    reads_SE_dir/<sample>/, under their final "_SE.filtered" name, so that
    single-end results are kept apart from paired-end ones.
    """
    input:
        os.path.join(tmpdir, "{SE}_{run}_SE.tagged_filter.fastq.gz")
    output:
        os.path.join(reads_SE_dir, "{SE}/{SE}_{run}_SE.filtered.fastq.gz"),
    priority: 5
    shell:
        "mv {input} {output}"


rule extract_and_filter:
    """
    Re-pair the FastQ Screen-filtered R1/R2 files of a paired-end run with
    BBTools repair.sh, see
    https://jgi.doe.gov/data-and-tools/bbtools/bb-tools-user-guide/repair-guide/

    FastQ Screen filters each file on its own, so a read can survive in R1 while
    its mate is removed from R2 (or vice versa). repair.sh writes the reads that
    still have a mate to the R1/R2 outputs and the orphans to the unpaired output.
    The unpaired reads coming from fastp (after FastQ Screen) are then appended
    to that same unpaired file.
    """
    input:
        pair1 = os.path.join(tmpdir, "{PE}_{run}_R1.tagged_filter.fastq.gz"),
        pair2 = os.path.join(tmpdir, "{PE}_{run}_R2.tagged_filter.fastq.gz"),
        unpaired = os.path.join(tmpdir, "{PE}_{run}_unpaired.tagged_filter.fastq.gz"),
    output:
        filtered_R1 = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R1.filtered.fastq.gz"),
        filtered_R2 = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_R2.filtered.fastq.gz"),
        unpaired = os.path.join(reads_PE_dir, "{PE}/{PE}_{run}_unpaired.filtered.fastq.gz"),
    log:
        os.path.join(logdir, "repair_paired_end/extract_and_filter_{PE}_{run}.log")
    conda:
        os.path.join(env_dir, "bbtools.yaml")
    priority: 5
    shell:
        "repair.sh -in={input[0]} -in2={input[1]} "  # R1 / R2 to re-pair (pair1, pair2)
        "-out={output.filtered_R1} -out2={output.filtered_R2} "  # reads that still have their mate
        "-outs={output.unpaired} "  # orphan reads
        "-Xmx10g &> {log};"  # JVM heap limit for repair.sh
        # Append fastp's unpaired reads; concatenated gzip members remain a valid gzip file.
        "cat {input.unpaired} >> {output.unpaired}; "

################################################################################
#
#    contaminant screening using FastQscreen
#
################################################################################

rule fastQ_screen:
    """
    Screen one FASTQ file of a run for contaminants with FastQ Screen
    (https://www.bioinformatics.babraham.ac.uk/projects/fastq_screen/).
    With --nohits, the reads kept by fastq_screen are written next to its
    reports and then moved to a temporary file in tmpdir. The txt/html screen
    reports stay in fatsqcreenreports_dir/<sample>/<run>/<id>/.
    """
    input:
        conf = os.path.join(tmpdir, "fastq_screen.conf"),
        fastq = os.path.join(tmpdir, "{sample}_{run}_QC_{id}.fastq.gz"),
    output:
        reads = temp(os.path.join(tmpdir, "{sample}_{run}_{id}.tagged_filter.fastq.gz")),
        #img = os.path.join(fatsqcreenreports_dir, "{sample}/{run}/{id}/{sample}_{run}_QC_{id}_screen.png"),
        txt = os.path.join(fatsqcreenreports_dir, "{sample}/{run}/{id}/{sample}_{run}_QC_{id}_screen.txt"),
        html = os.path.join(fatsqcreenreports_dir, "{sample}/{run}/{id}/{sample}_{run}_QC_{id}_screen.html"),
    params:
        fastQscreen = os.path.join(fatsqcreenreports_dir, "{sample}", "{run}", "{id}"),
    threads: 5
    priority: 5
    conda:
        os.path.join(env_dir, "pre-processing.yaml")
    log: os.path.join(logdir, "FQscreen/{sample}_{run}_{id}.log")
    benchmark:
        "benchmark/FQscreen/{sample}_{run}_{id}.fastqscreen.performances.txt"
    shell:
        # Empty input: create empty placeholder outputs instead of running fastq_screen.
        # NOTE(review): `touch` creates a zero-byte file, which is not a valid gzip stream.
        "if [ $(gunzip -c {input.fastq} | wc -l) -eq 0 ]; "
        "then "
        "touch {output.reads} {output.txt} {output.html}; "
        "else "
        "fastq_screen --nohits {input.fastq} "
        "--force "
        "--aligner=Bowtie2 "
        "--conf={input.conf} "
        "--outdir {params.fastQscreen} "
        #"--illumina1_3 " error with quality score, no alignement so 0 contamination
        "--threads {threads} "
        "&>{log}; "
        "mv {params.fastQscreen}/{wildcards.sample}_{wildcards.run}_QC_{wildcards.id}.tagged_filter.fastq.gz {output.reads}; "
        "fi "


rule fastQ_screen_conf:
    """
    Write the configuration file read by the fastQ_screen rule: a THREADS line,
    then one DATABASE line per targeted species giving its genome index prefix
    under db_dir. See the fastq_screen documentation for the file format.
    """
    input:
        # files returned by conf.index_contamination for the deconta dictionary
        # (the genomes' index files)
        **conf.index_contamination(config, deconta),
    output:
        os.path.join(tmpdir, "fastq_screen.conf"),
    params:
        threads = 10,
        # targeted species = keys of the index_contamination result
        species = conf.index_contamination(config, deconta).keys(),
        # species -> index prefix (relative to db_dir)
        prefix = deconta,
    priority: 5
    run:
        conf_lines = [f"THREADS\t{params.threads}\n", "########\n\n"]
        conf_lines.extend(
            f"\nDATABASE\t{spec}\t{db_dir}/{params.prefix.get(spec)}\n########"
            for spec in params.species
        )
        with open(str(output), "w") as handle:
            handle.writelines(conf_lines)


if config["dl_genomes"]:
    rule get_genomes:
        """
        Copy the genomes downloaded into db_dir/FastQ_Screen_Genomes/ to
        db_dir/FQ_Screen_Genomes/, to avoid directory conflicts caused by
        Snakemake creating output directories itself.
        Declared outputs are the Bowtie2 index files (.1-.4.bt2 and .rev.1-.2.bt2)
        for every index prefix in deconta.values().
        The download directory is a temp() output of fastq_screen_database.
        Snakemake therefore deletes it once this copy has succeeded, replacing an
        explicit `rm -r`. It is also not downloaded again just because it is
        missing while the copied index files are up to date.
        """
        input:
            directory(os.path.join(db_dir, "FastQ_Screen_Genomes/")),
        output:
            expand(os.path.join(db_dir, "{path_to_index}.{id}.bt2"),
                id=range(1, 5),
                path_to_index = deconta.values()),
            expand(os.path.join(db_dir, "{path_to_index}.rev.{id}.bt2"),
                id=range(1, 3),
                path_to_index = deconta.values()),
        params:
            outdir = db_dir,
        shell:
            "cp -a {params.outdir}/FastQ_Screen_Genomes/. {params.outdir}/FQ_Screen_Genomes/"

    rule fastq_screen_database:
        """
        Download the FastQ Screen genome database if required
        (config["dl_genomes"]); see fastq_screen --get_genomes.
        This database includes species usually concerned by contamination screening.
        Marked temp(): only needed until get_genomes has copied it.
        """
        output:
            temp(directory(os.path.join(db_dir, "FastQ_Screen_Genomes/"))),
        conda:
            os.path.join(env_dir, "pre-processing.yaml")
        params:
            outdir = db_dir
        priority: 5
        shell:
            "fastq_screen --get_genomes --force --outdir {params.outdir}"

################################################################################
#
#    Reads QC and filtering using fastP
#
################################################################################

def get_SE_read(wildcards):
    """
    Input function for fastP_SE: return the first existing raw file listed in
    SE_samples for run ``wildcards.run`` of sample ``wildcards.SE``.

    Every candidate is checked before giving up. Previously the `raise` sat
    inside the loop, so only the first candidate was ever examined, and an
    empty list silently returned None.

    Raises:
        FileNotFoundError: none of the listed files exists.
    """
    candidates = SE_samples[wildcards.SE][wildcards.run]
    for file_path in candidates:
        if os.path.exists(file_path):
            return file_path
    missing = ", ".join(candidates)
    raise FileNotFoundError(
        f"No file {missing} found for sample : {wildcards.SE} run : {wildcards.run}")

rule fastP_SE:
    """
    Single-end quality control and filtering of one run with fastp
    (https://github.com/OpenGene/fastp). The filtered reads are a temporary
    file in tmpdir; the JSON/HTML reports go to fastp_reports_dir/<sample>/.
    """
    input:
        get_SE_read,
    output:
        filtered_R1 = temp(os.path.join(tmpdir, "{SE}_{run}_QC_SE.fastq.gz")),
        json =  os.path.join(fastp_reports_dir, "{SE}/{SE}_{run}_fastp.json"),
        html = os.path.join(fastp_reports_dir, "{SE}/{SE}_{run}_fastp.html"),
    params:
        # trimming and read-filtering options taken from the config
        filters = (
            f"-f {config['trimming_R1_front']} -t {config['trimming_R1_tail']} "
            f"--poly_g_min_len {config['polyG_min_len']} -q {config['min_phred_score']} "
            f"-u {config['unqualified_bases']} -n {config['max_N']} "
            f"-e {config['average_Phred']} -l {config['minimum_reads_length']}"
        ),
    threads: 1
    priority: 5
    conda:
        os.path.join(env_dir, "pre-processing.yaml")
    log:
        os.path.join(logdir, "fastp/{SE}_{run}.fastP.txt")
    benchmark:
        "benchmark/fastp/{SE}_{run}.fastP.performances.txt"
    shell:
        "fastp "
        "-i {input} "
        "-z 4 "
        "-o {output.filtered_R1} "
        "--json {output.json} "
        "--html {output.html} "
        "{params.filters} "
        "--thread {threads} 2> {log};"

def get_PE_read(wildcards, mate):
    """
    Input function for fastP_PE: return the raw FASTQ file of mate ``mate``
    (1 or 2) for run ``wildcards.run`` of paired-end sample ``wildcards.PE``.

    The first file in PE_samples whose name ends with ``R<mate>.fastq.gz``,
    ``<mate>.fastq.gz``, ``R<mate>.fastq`` or ``<mate>.fastq`` is selected.
    No debug output is written to stdout, where it would be mixed into
    Snakemake's own output.

    Raises:
        FileNotFoundError: the selected file does not exist, or no file of
            the run matches the mate suffixes (previously None was returned).
    """
    # NOTE(review): names like *_R1_001.fastq.gz and *_R2_001.fastq.gz both end
    # with "1.fastq.gz", so these suffixes cannot tell such mates apart.
    suffixes = (f"R{mate}.fastq.gz", f"{mate}.fastq.gz", f"R{mate}.fastq", f"{mate}.fastq")
    run_files = PE_samples[wildcards.PE][wildcards.run]
    for file_path in run_files:
        if file_path.endswith(suffixes):
            if os.path.exists(file_path):
                return file_path
            raise FileNotFoundError(
                f"No file {file_path} found for sample : {wildcards.PE} run : {wildcards.run}")
    raise FileNotFoundError(
        f"No file matching mate {mate} among {list(run_files)} "
        f"for sample : {wildcards.PE} run : {wildcards.run}")


rule fastP_PE:
    """
    Paired-end quality control and filtering of one run with fastp
    (https://github.com/OpenGene/fastp). Filtered R1/R2 reads and the shared
    unpaired file (--unpaired1 and --unpaired2 point to the same output) are
    temporary files in tmpdir; JSON/HTML reports go to fastp_reports_dir.
    """
    input:
        r1 = lambda wildcards: get_PE_read(wildcards, 1),
        r2 = lambda wildcards: get_PE_read(wildcards, 2),
    output:
        filtered_R1 = temp(os.path.join(tmpdir, "{PE}_{run}_QC_R1.fastq.gz")),
        filtered_R2 = temp(os.path.join(tmpdir, "{PE}_{run}_QC_R2.fastq.gz")),
        unpaired = temp(os.path.join(tmpdir, "{PE}_{run}_QC_unpaired.fastq.gz")),
        json =  os.path.join(fastp_reports_dir, "{PE}_{run}_{run}_fastp.json"),
        html = os.path.join(fastp_reports_dir, "{PE}_{run}_{run}_fastp.html"),
    params:
        # per-mate trimming and read-filtering options taken from the config
        filters = (
            f"-f {config['trimming_R1_front']} -t {config['trimming_R1_tail']} "
            f"-F {config['trimming_R2_front']} -T {config['trimming_R2_tail']} "
            f"--poly_g_min_len {config['polyG_min_len']} -q {config['min_phred_score']} "
            f"-u {config['unqualified_bases']} -n {config['max_N']} "
            f"-e {config['average_Phred']} -l {config['minimum_reads_length']}"
        ),
    threads: 1
    priority: 5
    conda:
        os.path.join(env_dir, "pre-processing.yaml")
    log:
        os.path.join(logdir, "fastp/{PE}_{run}.log")
    benchmark:
        "benchmark/fastp/{PE}_{run}.fastP.performances.txt"
    shell:
        "fastp "
        "-i {input.r1} "
        "-I {input.r2} "
        "-o {output.filtered_R1} "
        "-O {output.filtered_R2} "
        "--unpaired1 {output.unpaired} "
        "--unpaired2 {output.unpaired} "
        "--json {output.json} "
        "--html {output.html} "
        "--thread {threads} "
        "-z 4 "
        "{params.filters} "
        "&> {log};"
