#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import yaml
from snakemake.logging import logger
import os
import gzip

#############
# Set paths #
#############
## TODO: make all paths configurable through the config file

#env_dir = "../envs" # Defined in the main Snakefile; repeated here for readability

# Pipeline configuration file, opened relative to the current working
# directory (presumably the directory snakemake is launched from — confirm).
config_file = 'config/config.yaml'

# Load the configuration at import time; every path below is derived from it.
# NOTE(review): FullLoader is used here; yaml.safe_load would be enough for
# plain configuration data.
with open(config_file, 'r') as ftr:
    config = yaml.load(ftr, Loader=yaml.FullLoader)

# Project root and database root, both taken from the config.
working_dir = config["project"]
db_dir = config["database"]
tmpdir = f"{working_dir}/tmp" # Also defined in the main Snakefile; repeated here for readability
logdir = f"{working_dir}/logs"

# Stores all pre-processing reports
reports_dir = f"{working_dir}/reports/pre_processing"

# Paired-end (PE) read directories: metaT reads live under a "metaT"
# sub-tree of the project, metaG reads directly under the project root.
# These hold the raw-read symlinks and the reads produced by QC.
reads_PE_dir_metaT = f"{working_dir}/metaT/intermediate_results/reads/PE"
reads_PE_dir_metaG = f"{working_dir}/intermediate_results/reads/PE"

# Single-end (SE) read directories, same layout as the PE ones above.
reads_SE_dir_metaT = f"{working_dir}/metaT/intermediate_results/reads/SE"
reads_SE_dir_metaG = f"{working_dir}/intermediate_results/reads/SE"

########################################################################################################

def fastq_is_empty(fastq):
    '''
    Tell whether a gzip-compressed fastq file holds no data.

    Only the first decompressed byte is read, so the check is cheap even for
    large files. Handy for unpaired-read files, whose size can range from
    nothing up to a large share of a paired-end dataset.

    fastq : path to a .fastq.gz file
    Return True when the decompressed stream is empty, False otherwise.
    '''
    with gzip.open(fastq, 'rb') as handle:
        first_byte = handle.read(1)
    # An empty bytes object is falsy, so "not" maps "no data" to True.
    return not first_byte

########################################################################################################

def parse_samples(config):
    '''
    Read the YAML file referenced by config["samples"] and return its
    parsed content (the samples dictionary).
    '''
    samples_path = config["samples"]
    with open(samples_path, "r") as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)

def parse_metaT_samples(config):
    '''
    Read the YAML file referenced by config["metaT_samples"] and return its
    parsed content (the metaT samples dictionary).
    '''
    metaT_path = config["metaT_samples"]
    with open(metaT_path, "r") as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)

########################################################################################################

def _symlink_if_missing(target, link_name):
    '''
    Create link_name pointing to the absolute path of target; keep an
    already-existing link_name untouched so that reruns are no-ops.
    '''
    try:
        os.symlink(os.path.abspath(target), link_name)
    except FileExistsError:
        # Link already created by a previous run: nothing to do.
        pass


def initreads(samples, reads_PE_dir, reads_SE_dir):
    '''
    Create symlinks to the raw read files under the project's "reads" folders,
    so that raw and pre-processed data of a sample sit in a single directory.
    It is an alternative to the samples file for retrieving read files.

    samples      : dict sample -> run -> list of read file paths
                   (one path = single-end, otherwise the first two are R1/R2)
    reads_PE_dir : receives paired-end links
                   {reads_PE_dir}/{sample}/{sample}_{run}_R1 and ..._R2
    reads_SE_dir : receives single-end links {reads_SE_dir}/{sample}/{run}

    Existing links are kept. Each link is created independently, so a rerun
    after an interrupted one still creates a missing R2 link even when the
    R1 link is already present.
    '''
    for sample, runs in samples.items():
        for run, files in runs.items():
            if len(files) == 1:  # SE
                os.makedirs(os.path.join(reads_SE_dir, sample), exist_ok=True)
                _symlink_if_missing(files[0], f"{reads_SE_dir}/{sample}/{run}")
            else:  # PE
                os.makedirs(f"{reads_PE_dir}/{sample}/", exist_ok=True)
                _symlink_if_missing(
                    files[0], f"{reads_PE_dir}/{sample}/{sample}_{run}_R1")
                _symlink_if_missing(
                    files[1], f"{reads_PE_dir}/{sample}/{sample}_{run}_R2")

########################################################################################################

def _reads_for_run(skip_qc, merging, sample, run, files, reads_PE_dir, reads_SE_dir):
    '''
    Return the list of read files downstream rules should use for one run,
    according to the QC (skip_qc) and read-merging (merging) settings.
    Never used when QC is skipped and merging is off (raw reads are used).
    '''
    if len(files) == 2:  # PE
        prefix = f"{reads_PE_dir}/{sample}/{sample}_{run}"
        if merging:
            paths = [
                f"{prefix}_R1.unmerged.fastq.gz",
                f"{prefix}_R2.unmerged.fastq.gz",
                f"{prefix}.merged.fastq.gz",
            ]
            if not skip_qc:
                # QC also produces reads whose mate was discarded.
                paths.append(f"{prefix}_unpaired.filtered.fastq.gz")
            return paths
        return [
            f"{prefix}_R1.filtered.fastq.gz",
            f"{prefix}_R2.filtered.fastq.gz",
            f"{prefix}_unpaired.filtered.fastq.gz",
        ]
    if len(files) == 1:  # SE
        if skip_qc:
            # No QC: follow the raw-read symlink created by initreads.
            return [os.readlink(os.path.join(reads_SE_dir, sample, run))]
        return [f"{reads_SE_dir}/{sample}/{sample}_{run}_SE.filtered.fastq.gz"]
    raise ValueError(
        f"Run {run!r} of sample {sample!r}: expected 1 (SE) or 2 (PE) "
        f"read files, got {len(files)}")


def resolveReads(config, samples, reads_PE_dir, reads_SE_dir):
    '''
    Create the raw-read symlinks (initreads), then map every sample/run to
    the read files downstream rules should consume.

    config : pipeline config; its "skip_QC" and "merging" entries select
             which read files are used
    samples : dict sample -> run -> list of raw read paths (1 = SE, 2 = PE)
    reads_PE_dir, reads_SE_dir : PE / SE read directories

    Return samples unchanged when QC is skipped and merging is off; otherwise
    a new dict sample -> run -> list of read file paths.
    Raise ValueError when a run lists neither 1 nor 2 read files.
    '''
    initreads(samples, reads_PE_dir, reads_SE_dir)

    skip_qc = config["skip_QC"]
    merging = config["merging"]

    if skip_qc and not merging:
        # use raw samples
        return samples

    # Every branch keeps the sample -> run -> files shape (the SE runs of
    # the skip_QC + merging case used to overwrite the whole sample entry).
    return {
        sample: {
            run: _reads_for_run(skip_qc, merging, sample, run, files,
                                reads_PE_dir, reads_SE_dir)
            for run, files in runs.items()
        }
        for sample, runs in samples.items()
    }

########################################################################################################

def index_contamination(config, deconta, database_dir=None):
    '''
    For contamination removal: map each contaminant species to the list of
    bowtie2 index files expected for its index prefix.

    config : pipeline config (not used; kept for the existing call signature)
    deconta : dict species -> index prefix, relative to the database
              directory. An empty dict or None yields an empty result.
    database_dir : directory holding the indexes; defaults to the module-level
                   db_dir (config["database"] of config/config.yaml).

    Return a dict species -> [<dir>/<prefix>.1.bt2, ..., <dir>/<prefix>.rev.2.bt2].
    '''
    if not deconta:
        # Return an empty mapping instead of falling through to None, so
        # callers can always iterate the result.
        return {}
    base_dir = db_dir if database_dir is None else database_dir
    suffixes = ("1", "2", "3", "4", "rev.1", "rev.2")
    return {
        species: [f"{base_dir}/{prefix}.{suffix}.bt2" for suffix in suffixes]
        for species, prefix in deconta.items()
    }

########################################################################################################

def SE_symlink(wildcards, reads_SE_dir):
    '''
    Input function for single-end data: return the target of the symlink
    {reads_SE_dir}/{wildcards.SE}/{wildcards.run}, i.e. the raw read file.
    '''
    return os.readlink(os.path.join(reads_SE_dir, wildcards.SE, wildcards.run))

########################################################################################################

def PE_symlink(wildcards, reads_PE_dir):
    '''
    Input function for paired-end data: return the (R1, R2) targets of the
    symlinks {reads_PE_dir}/{wildcards.PE}/{wildcards.run}_R1 and ..._R2.
    '''
    sample_dir = os.path.join(reads_PE_dir, wildcards.PE)
    return tuple(
        os.readlink(os.path.join(sample_dir, wildcards.run + "_" + mate))
        for mate in ("R1", "R2")
    )

########################################################################################################

def parse(config):
    '''
    "Main" parser for the metaG samples (config["samples"]).

    Return a 3-tuple:
      reads2use : sample -> run -> read files to use (see resolveReads)
      samples   : the samples dictionary loaded from config["samples"]
      deconta   : config["contaminant"], the species -> index prefix mapping
                  consumed by index_contamination
    '''
    deconta = config["contaminant"]
    samples = parse_samples(config)
    # resolveReads already creates the raw-read symlinks through initreads,
    # so no second initreads call is needed here.
    reads2use = resolveReads(config, samples, reads_PE_dir_metaG, reads_SE_dir_metaG)
    return reads2use, samples, deconta

def parse_metaT(config):
    '''
    "Main" parser for the metaT samples (config["metaT_samples"]).

    Return a 3-tuple:
      metaTreads2use : sample -> run -> read files to use (see resolveReads)
      metaTsamples   : the samples dictionary loaded from config["metaT_samples"]
      deconta        : config["contaminant"], the species -> index prefix
                       mapping consumed by index_contamination
    '''
    deconta = config["contaminant"]
    metaTsamples = parse_metaT_samples(config)
    # resolveReads already creates the raw-read symlinks through initreads,
    # so no second initreads call is needed here.
    metaTreads2use = resolveReads(config, metaTsamples, reads_PE_dir_metaT, reads_SE_dir_metaT)
    return metaTreads2use, metaTsamples, deconta



def mapping_cmd(config, step):
    '''
    Build the bowtie2 option string for a mapping step from the config.

    config : config loaded in dictionary
    step : config dictionary key; config[step] provides "preset", "extra"
           (with an "end-to-end" entry when a preset is used) and, when the
           preset is "None" or null, "custom_set".

    With a preset, emit "--<preset>" ("--<preset>-local" when end-to-end is
    false). Without one, emit "-<param> <value>" for EVERY custom_set entry.
    Then, for every "extra" entry: True -> "--<param>", False -> omitted,
    anything else -> "--<param> <value>".

    Return the bowtie2 options as a single space-separated string.
    '''
    params = config[step]
    extra = params["extra"]
    cmd = []

    preset = params["preset"]
    # "None" as a string is the documented sentinel; a YAML null is
    # accepted too instead of producing "--None".
    if preset not in (None, "None"):
        setting = f"--{preset}"
        if not extra["end-to-end"]:
            setting = f"{setting}-local"
        cmd.append(setting)
    else:
        # Accumulate every custom setting (previously only the last one kept).
        for param, value in (params["custom_set"] or {}).items():
            cmd.append(f"-{param} {value}")

    for param, value in extra.items():
        if isinstance(value, bool):
            if value:
                cmd.append(f"--{param}")  # boolean flag switched on
        else:
            cmd.append(f"--{param} {value}")  # flag with a value
    return " ".join(cmd)
