#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import yaml
from snakemake.logging import logger
import os
import gzip

#############
# Set paths #
#############
## TODO: get all paths parameterizable in config file
# NOTE(review): `config` is not defined in this file; it is presumably the
# global injected by Snakemake when this file is included from the main
# Snakefile — confirm before importing this module elsewhere.

#env_dir = "../envs" # Also defined in the main Snakefile; repeated here for readability
working_dir = config["project"]
tmpdir = f"{working_dir}/tmp" # Also defined in the main Snakefile; repeated here for readability
logdir = f"{working_dir}/logs"

# Stores all pre-processing reports
reports_dir = f"{working_dir}/reports/pre_processing"
# Stores results produced by the different programs
intermediate_results_dir = f"{working_dir}/intermediate_results"
# Store the raw-read symlinks (created by initreads) and the reads obtained
# after QC, split by layout: paired-end (PE) and single-end (SE).
# Note: resolveReads also places single-end *filtered* reads under reads_PE_dir.
reads_PE_dir = f"{intermediate_results_dir}/reads/PE"
reads_SE_dir = f"{intermediate_results_dir}/reads/SE"

# Input file
# Former approach, kept for reference: split samples by their number of read files
#PE_samples = dict(filter(lambda samples: len(samples[1]) == 2,samples.items()))
#SE_samples = dict(filter(lambda samples: len(samples[1]) == 1,samples.items()))

########################################################################################################

def fastq_is_empty(fastq):
    '''
    Tell whether a gzip-compressed FASTQ file contains no data.

    Only the first decompressed byte is read, so the check stays cheap even
    on large files. Handy for unpaired-read files, which may represent
    anywhere from 0 up to 90 percent of a paired-end dataset.

    fastq : path to a .fastq.gz file
    Returns True when decompression yields no byte at all, False otherwise.
    '''
    with gzip.open(fastq) as handle:  # default mode is 'rb'
        first_byte = handle.read(1)
    return not first_byte

########################################################################################################

def parse_samples(config):
    '''
    Load the samples YAML file (path given by config["samples"]) into a
    python dictionary.

    The rest of this module indexes the result as
    samples[sample][run] -> list of read files (1 = single-end, 2 = paired-end).

    config : pipeline configuration dictionary
    Returns the loaded samples dictionary.
    Raises ValueError if the file is empty or its top level is not a mapping.
    '''
    with open(config["samples"], "r", encoding="utf-8") as samplefile:
        # NOTE(review): FullLoader can build some Python objects; a plain
        # sample sheet only needs yaml.safe_load — consider switching if the
        # file may come from an untrusted source.
        samples = yaml.load(samplefile, Loader=yaml.FullLoader)
    # An empty file loads as None; fail here with a clear message instead of
    # an opaque TypeError on len() or later while creating symlinks.
    if not isinstance(samples, dict):
        raise ValueError(
            "Samples file {} must contain a mapping of samples, got {}".format(
                config["samples"], type(samples).__name__))
    logger.info("Found {} biological samples".format(len(samples)))
    return samples

########################################################################################################

def _link_raw_read(src, dst):
    '''
    Create the symlink `dst` -> `src`, creating the parent directory of `dst`
    if needed. An already existing `dst` is kept, so this is safe to call on
    every workflow invocation.
    '''
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    try:
        os.symlink(src, dst)
    except FileExistsError:
        # Link already created by a previous invocation: nothing to do.
        pass


def initreads(samples):
    '''
    Create symlinks to the raw read files under the "reads" folder of the project,
    so raw and pre-processed data are stored in a single directory per sample.
    It is an alternative to the samples file for retrieving read files.

    Layout (link -> raw file):
        single-end : <reads_SE_dir>/<sample>/<run>
                         -> samples[sample][run][0]
        paired-end : <reads_PE_dir>/<sample>/<run>/<run>_R1 and _R2
                         -> samples[sample][run][0] and [1]

    samples : dict {sample: {run: [read file(s)]}} as returned by
              parse_samples; a run with one file is treated as single-end,
              otherwise as paired-end.

    Raises any error from os.makedirs / os.symlink other than an already
    existing link, after logging a warning naming the run.
    '''
    for sample, runs in samples.items():
        for run, reads in runs.items():
            try:
                if len(reads) == 1:  # SE
                    # The link itself is named after the run: only its parent
                    # directory is created (a directory at the link path would
                    # make os.symlink fail), and the link targets the single
                    # file, not the list holding it.
                    _link_raw_read(reads[0], os.path.join(reads_SE_dir, sample, run))
                else:  # PE
                    run_dir = f"{reads_PE_dir}/{sample}/{run}"
                    # Each mate is linked independently, so an existing R1
                    # link no longer prevents the R2 link from being created.
                    _link_raw_read(reads[0], f"{run_dir}/{run}_R1")
                    _link_raw_read(reads[1], f"{run_dir}/{run}_R2")
            except Exception:
                logger.warning(f"Unfound run : {run} !")
                raise

########################################################################################################

def _resolve_run(sample, run, reads, skip_qc, merging):
    '''
    Return the list of read files downstream rules should use for one run.
    Only meaningful when raw reads are NOT used as-is, i.e. unless
    (skip_qc and not merging) — resolveReads handles that case itself.
    '''
    prefix = f"{reads_PE_dir}/{sample}/{run}/{run}"
    if len(reads) == 2:  # PE
        if merging:
            files = [
                f"{prefix}_R1.unmerged.fastq.gz",
                f"{prefix}_R2.unmerged.fastq.gz",
                f"{prefix}.merged.fastq.gz",
            ]
            if not skip_qc:
                # QC also yields an unpaired-reads file to keep.
                files.append(f"{prefix}_unpaired.filtered.fastq.gz")
            return files
        return [
            f"{prefix}_R1.filtered.fastq.gz",
            f"{prefix}_R2.filtered.fastq.gz",
            f"{prefix}_unpaired.filtered.fastq.gz",
        ]
    if len(reads) == 1:  # SE: single-end runs are never merged
        if skip_qc:
            # Raw single-end file, read back from the symlink made by initreads.
            return [os.readlink(os.path.join(reads_SE_dir, sample, run))]
        return [f"{prefix}_SE.filtered.fastq.gz"]
    raise ValueError(
        f"Sample {sample}, run {run}: expected 1 (SE) or 2 (PE) read files, "
        f"got {len(reads)}")


def resolveReads(config,samples):
    '''
    Resolve, for every sample and run, which read files downstream rules
    (assembly, motus profiling, ...) should use.

    Pre-processing (QC) and merging are both optional, so the input files of
    those rules vary. Rather than duplicating each rule for every situation,
    this small parser computes the right files once.
    WIP — TODO: allow other sample file organizations.

    The raw-read symlinks are created first (see initreads).

    config  : dict providing the "skip_QC" and "merging" flags
    samples : dict {sample: {run: [read file(s)]}} from parse_samples;
              one file = single-end (SE), two files = paired-end (PE)

    Returns
        skip_QC and not merging: `samples` itself (raw reads).
        Otherwise a dict {sample: {run: [files]}}; PE files and filtered SE
        files live under <reads_PE_dir>/<sample>/<run>/<run>, with .fastq.gz:
          skip_QC + merging : PE _R1.unmerged, _R2.unmerged, .merged
                              SE raw file (target of its symlink)
          QC, no merging    : PE _R1.filtered, _R2.filtered, _unpaired.filtered
                              SE _SE.filtered
          QC + merging      : PE _R1.unmerged, _R2.unmerged, .merged,
                                 _unpaired.filtered
                              SE _SE.filtered

    Raises ValueError if a run lists neither 1 nor 2 read files.
    '''
    initreads(samples)
    skip_qc = config["skip_QC"]
    merging = config["merging"]

    if skip_qc and not merging:
        # use raw samples as they are
        return samples

    return {
        sample: {
            run: _resolve_run(sample, run, reads, skip_qc, merging)
            for run, reads in runs.items()
        }
        for sample, runs in samples.items()
    }

########################################################################################################

def index_contamination(config, deconta):
    '''
    For contamination removal --
    build a dictionary with contaminant species as keys and the list of
    bowtie2 index files derived from each species' index prefix as values.

    config  : pipeline configuration (not used; kept for interface compatibility)
    deconta : dict {species: index prefix}; None or an empty dict is accepted

    Returns {species: [prefix.1.bt2, prefix.2.bt2, prefix.3.bt2,
                       prefix.4.bt2, prefix.rev.1.bt2, prefix.rev.2.bt2]},
    or an empty dict when there is no contaminant.
    '''
    suffixes = (".1.bt2", ".2.bt2", ".3.bt2", ".4.bt2", ".rev.1.bt2", ".rev.2.bt2")
    # Always return a dict (never an implicit None) so callers can iterate
    # the result even when decontamination is not configured.
    return {
        species: [f"{prefix}{suffix}" for suffix in suffixes]
        for species, prefix in (deconta or {}).items()
    }

########################################################################################################

def SE_symlink(wildcards):
    '''
    Input function (single-end case): return the raw read file that the
    symlink <reads_SE_dir>/<wildcards.SE>/<wildcards.run> points to.

    The link target is returned exactly as stored (os.readlink); it is not
    resolved relative to the link location.
    '''
    link_path = os.path.join(reads_SE_dir, wildcards.SE, wildcards.run)
    raw_read = os.readlink(link_path)
    return raw_read

########################################################################################################

def PE_symlink(wildcards):
    '''
    Input function (paired-end case): return, as a 2-tuple, the raw R1 and
    R2 read files that the symlinks <reads_PE_dir>/<wildcards.PE>/<wildcards.run>_R1
    and ..._R2 point to.

    Link targets are returned exactly as stored (os.readlink); they are not
    resolved relative to the link location.
    '''
    # NOTE(review): initreads creates these links as
    # <reads_PE_dir>/<sample>/<run>/<run>_R1, so wildcards.PE presumably
    # spans "<sample>/<run>" — confirm against the rules using this function.
    run_base = os.path.join(reads_PE_dir, wildcards.PE)
    return tuple(
        os.readlink(os.path.join(run_base, wildcards.run + "_" + mate))
        for mate in ("R1", "R2")
    )

########################################################################################################

def parse(config):
    '''
    "Main" parser: load the samples, resolve which read files each rule
    should use, and log the QC / decontamination setup.

    config : pipeline configuration dictionary

    Returns a 3-tuple:
        reads2use : result of resolveReads (sample -> run -> read files)
        samples   : dictionary loaded by parse_samples
        deconta   : config["contaminant"] (presumably species -> index
                    prefix, as consumed by index_contamination)
    '''
    deconta = config["contaminant"]
    samples = parse_samples(config)
    # resolveReads already calls initreads(samples), so the raw-read
    # symlinks exist after this call; they are not created a second time.
    reads2use = resolveReads(config, samples)
    logger.info("Reads QC {}".format("enable" if not config["skip_QC"] else "disable"))
    species = list(deconta.keys())
    logger.info("Found {} species for decontamination".format(len(species)))
    logger.info("species : {}".format(species))
    return reads2use, samples, deconta



def mapping_cmd(config, step):
    '''
    Parse mapping parameters for bowtie2 from config file.

    config : config loaded in dictionary
    step   : config dictionary key whose value provides
               "preset"     : preset name, or "None"/null to use custom_set
               "custom_set" : {option: value} rendered as "-option value"
                              (used only when there is no preset)
               "extra"      : {option: value} rendered as "--option" for
                              True, omitted for False, "--option value" otherwise;
                              with a preset, extra["end-to-end"] being falsy
                              selects the "<preset>-local" variant

    Return bowtie2 command-line arguments for mapping rules, space-separated.
    '''
    params = config[step]
    extra = params["extra"]
    cmd = []

    preset = params["preset"]
    # Accept both the string "None" and a YAML null as "no preset".
    if preset is not None and preset != "None":
        setting = f"--{preset}"
        if not extra["end-to-end"]:
            setting = f"{setting}-local"
        cmd.append(setting)
    else:
        # Keep every custom parameter (previously each one overwrote the last).
        custom = params.get("custom_set") or {}
        cmd.extend(f"-{param} {value}" for param, value in custom.items())

    for param, value in extra.items():
        if isinstance(value, bool):
            if value:
                # add flag
                cmd.append(f"--{param}")
        else:
            # add flag and value
            cmd.append(f"--{param} {value}")

    return " ".join(cmd)
