#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import yaml
import os, sys
import subprocess

def _magneto_install_location():
    """Return the ``Location:`` field reported by ``pip show magneto``.

    Raises:
        subprocess.CalledProcessError: if ``pip show magneto`` exits non-zero
            (e.g. the package is not installed).
        RuntimeError: if the pip output contains no ``Location:`` field.
    """
    completed = subprocess.run(
        ('pip', 'show', 'magneto'), stdout=subprocess.PIPE, check=True
    )
    for line in completed.stdout.decode('utf-8').splitlines():
        field, _, value = line.partition(':')
        if field.strip() == 'Location':
            # Keep the whole value: splitting on whitespace would truncate
            # install paths that contain spaces.
            return value.strip()
    raise RuntimeError("'pip show magneto' did not report a Location field")


# Make magneto's helper modules (cmdparser, parserconfig) importable.
location = _magneto_install_location()
scripts_dir = os.path.join(location, 'magneto/scripts')

sys.path.insert(0, scripts_dir)

import cmdparser as cmdparser
import parserconfig as conf
from snakemake.exceptions import MissingInputException

'''
    This rules file is dedicated to metagenomic assembly (see `genosysmics run --help`).
    Depending on the values registered in the config file
    or the arguments passed to `genosysmics run`,
    single assembly and/or co-assembly may be run.

    Co-assembly involves extra steps:
        - rule simka_input
        - rule simka (or simkaMin)
        - checkpoint cluster_simka

    The main workflow is composed of the sequential rules listed below:
        - megahit
        - bowtie2_index
        - bowtie2_map
        - sort_and_index
        - filter_assembly
        - coverage_statistics
        - assembly_statistics

    In addition, taxonomic annotation may be performed using CAT, and misassemblies may be detected using DeepMAsED.

'''

################################################################################
#
#    Assign config values
#
################################################################################

CONDAENV = "../envs"  # Also defined in the main Snakefile; repeated here for readability.
working_dir = config["project"]
tmp = os.path.join(working_dir, "tmp")  # Also defined in the main Snakefile; repeated here for readability.
logdir = os.path.join(working_dir, "logs")

# Store all pre-processing reports.
reports_dir = os.path.join(working_dir, "reports/pre_processing")
# Store results from the different programs.
intermediate_results_dir = os.path.join(working_dir, "intermediate_results")

# Sample sheet; its top-level keys are used as sample names throughout this file.
# NOTE(review): FullLoader can still construct some Python objects; prefer
# yaml.safe_load if the sample sheet may come from an untrusted source.
with open(config['samples'], 'r') as samplefile:
    samples = yaml.load(samplefile, Loader=yaml.FullLoader)

# Contig file / BAM suffix consumed downstream, and the filtering stages for
# which statistics are produced.
if config["Assembly"]["filter_assembly"]:
    assembly = "post_filtering.contigs.fa"
    bam_suffix = "_post_filtering.sorted.bam"
    filtering = ["pre", "post"]
else:
    assembly = "pre_filtering.contigs.fa"
    bam_suffix = "_pre_filtering.sorted.bam"
    filtering = ["pre"]

# Report format: PDF when there are more than 200 samples, HTML otherwise.
ext = "pdf" if len(samples) > 200 else "html"

assembler = config["Assembly"]["Assembler"]
simka_type = config['Simka']['type']
# The string 'None' disables sample clustering. Also accept a YAML null or any
# other casing ("none", "NONE") and normalise it, because the rest of this file
# compares against the exact string 'None'.
if simka_type is None or str(simka_type).lower() == 'none':
    simka_type = 'None'

assembly_request = []
strategies = config["binning_strategies"]

# SA* binning strategies need single assemblies, CA* strategies need co-assemblies.
if 'SASB' in strategies or 'SACB' in strategies:
    assembly_request.append('single_assembly')
if 'CASB' in strategies or 'CACB' in strategies:
    assembly_request.append('co_assembly')
assembly_dict = {}
if "single_assembly" in assembly_request:
    assembly_dict = {"single_assembly": samples.keys()}

if simka_type != 'None':
    output_of_rule_simka_input = f'tmp/{simka_type}_input.txt'
    # Rules `simka` and `simkaMin` each write into a directory named after the
    # tool, so the checkpoint must read the matrix of the selected one.
    input_of_checkpoint_simka = f"intermediate_results/assembly/co_assembly/{simka_type}/mat_abundance_braycurtis.csv.gz"
    input_of_target_simka = directory("intermediate_results/assembly/co_assembly/clusters")
else:
    output_of_rule_simka_input = 'tmp/samples.txt'
    input_of_checkpoint_simka = 'tmp/samples.txt'  # skip simka/simkaMin rules
    input_of_target_simka = 'tmp/simka_skip.txt'  # skip cluster_simka checkpoint

################################################################################
#
#    Target definition.
#    Assembly step: requires reads (QC-filtered, merged or raw) and produces an assembly per sample or per cluster
#
####################################################################################################

def aggregate_targets(wildcards):
    """Collect every file the assembly step must produce (input of ``target_assembly``).

    Always requests, for each requested assembly strategy, the assembly
    statistics table and the contig statistics table of each filtering stage.
    DeepMAsED and CAT outputs are appended when enabled in ``config["Assembly"]``.

    NOTE(review): the rules that create the DeepMAsED and CAT outputs are
    currently wrapped in a string literal further down this file, so enabling
    either option requests files that no active rule produces.
    """
    wanted = expand("tables/{assembly}_statistics.tsv",
                    assembly=assembly_request)
    wanted += expand("tables/contigs_{assembly}_statistics_{request_filtering}_filtering.csv",
                     assembly=assembly_request,
                     request_filtering=filtering)

    assembly_cfg = config["Assembly"]
    if assembly_cfg["DeepMased"]:
        wanted += expand("tmp/DeepMased_{assembly}.txt",
                         assembly=assembly_request)
    if assembly_cfg["CAT"]:
        # CAT results
        wanted += [
            "data/assembly/all_contigs_and_assembly.summary.txt",
            "data/assembly/all_contigs_lineages.txt",
        ]

    return wanted


rule target_assembly:
    """
    Touch the assembly checkpoint file once every target returned by
    aggregate_targets exists.
    """
    input:
        aggregate_targets,
    output:
        os.path.join(tmp, "assembly.checkpoint"),
    shell:
        "touch {output}"

################################################################################
#
#    CAT annotation
#
################################################################################
# NOTE(review): this whole section, CAT annotation and DeepMAsED included, is
# disabled. Everything between the triple quotes below is a plain string literal,
# so none of these functions or rules are registered. aggregate_targets() still
# requests their outputs when config["Assembly"]["CAT"] or
# config["Assembly"]["DeepMased"] is true. Remove the surrounding quotes to
# re-enable them.
"""
def aggregate_cat_named_classification(wildcards):
    assembly_dict["single_assembly"]=samples

    if  "co_assembly" in assembly_request:
        # get simka output from checkpoint
        # it's a list of file path
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]

        # infere cluster id wildcard from the list of file path
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
    inputs = []
    for request in assembly_request:
        # expand all files required as input
        inputs.extend(expand(
            os.path.join(intermediate_results_dir, "assembly/{assembly}",
                assembler,
                "{src}/catbat_annotation/{src}_named.contig2classification.txt"),
            assembly=request,
            src=assembly_dict.get(request))
        )
    return inputs

rule cat_contigs_lineages:
    output:
        "data/assembly/all_contigs_lineages.txt",
    input:
        "data/assembly/all_contigs_and_assembly.summary.txt",
    params:
        input = aggregate_cat_named_classification,
    shell:
        "head -n 1 {params.input[0]} > {output} && "
        " tail -q -n +2 {params.input} >> {output} "

def aggregate_cat_summary(wildcards):
    if wildcards.assembly == "co_assembly":
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid

    return expand(
            os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{src}/catbat_annotation/{src}.summary.txt"),
            src=assembly_dict.get(wildcards.assembly)
        )

rule cat_full_results:
    output:
        "data/assembly/all_contigs_and_assembly.summary.txt",
    input:
        expand("data/assembly/{assembly}/all_contigs.summary.txt",
        assembly=assembly_request),
    shell:
        "head -n 1 {input[0]} > {output} && "
        " tail -q -n +2 {input} >> {output} "

rule cat_assembly_summary:
    output:
        "data/assembly/{assembly}/all_contigs.summary.txt",
    input:
        aggregate_cat_summary,
    params:
        first  = lambda wildcards, input: input[0],
        others = lambda wildcards, input: input[1:],
    shell:
        "cat {params.first} <(tail -q -n +2 {params.others}) > {output}"

rule cat_summary:
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}.summary.txt"),
    input:
        assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_" + assembly),
        cat = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}_named.contig2classification.txt"),
    conda:
        os.path.join(CONDAENV, "catbat.yaml")
    shell:
        "CAT summarise "
        "-c {input.assembly} "
        "-i {input.cat} "
        "-o {output}"

rule cat_name:
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}_named.contig2classification.txt"),
    input:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}.contig2classification.txt"),
    conda:
        os.path.join(CONDAENV,"catbat.yaml")
    params:
        db   = os.path.join(working_dir, "Database/catbat_db"),
        taxo = os.path.join(working_dir, "Database/catbat_taxo"),
    shell:
        "CAT add_names "
        "-i {input} "
        "-o {output} "
        "-t {params.taxo} "
        "--only_official"

rule cat_assembly:
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}.contig2classification.txt"),
    input:
        assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_" + assembly),
        db       = os.path.join(working_dir, "Database/catbat_db"),
        taxo     = os.path.join(working_dir, "Database/catbat_taxo"),
    conda:
        os.path.join(CONDAENV,"catbat.yaml")
    params:
        db   = os.path.join(working_dir, "Database/catbat_db"),
        taxo = os.path.join(working_dir, "Database/catbat_taxo"),
        out  = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/catbat_annotation/{src}"),
    log:
        os.path.join("logs/{assembly}/contigs_annotation/{src}_catbat.log")
    threads: 10
    shell:
        "CAT contigs "
        "-c {input.assembly} "
        "-d {params.db} "
        "-t {params.taxo} "
        "-o {params.out} "
        "--nproc {threads} "
        "--force"

################################################################################
#
#    Test DeepMased on single assembly:
#    Misassemblies detection using DeepMAsED (without ref approach)
#
################################################################################

def aggregate_deepmased(wildcards):
    if wildcards.assembly == "co_assembly":
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
    return expand(
            os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{src}/misassemblies/deepmased_predictions.tsv"),
            src=assembly_dict.get(wildcards.assembly)
        )

rule aggregate_deepmased:
    output:
        temp("tmp/DeepMased_{assembly}.txt"),
    input:
        aggregate_deepmased,
    shell:
        "touch {output}"

rule DeepMAsED_predict:
    output:
         os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/misassemblies/deepmased_predictions.tsv"),
    input:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/misassemblies/feature_file_table.tsv")
    conda:
        os.path.join(CONDAENV, "deepmased.yaml")
    params:
        out = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/misassemblies"),
    log:
        "logs/DeepMAsED/{assembly}/predict_{src}.log"
    shell:
        "DeepMAsED predict {input} --cpu-only --save-path {params.out} &> {log} || touch {output}"

rule DeepMAsED_features:
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/misassemblies/feature_file_table.tsv")
    input:
        "tmp/DeepMased/{assembly}/{src}_bam_fasta_file.tsv"
    conda:
        os.path.join(CONDAENV, "deepmased.yaml")
    params:
        outname = "tmp/DeepMased/{assembly}/{src}_feature_file_table.tsv"
    threads: 10
    log: "logs/DeepMAsED/{assembly}/features_{src}.log"
    shell:
        "DeepMAsED features {input} "
        "-p {threads} "
        #"-o {params.out} "
        "-n {params.outname} "
        "-d &> {log}; "
        "mv {params.outname} {output} "

rule bam_fasta_file:
    output:
        temp("tmp/DeepMased/{assembly}/{src}_bam_fasta_file.tsv"),
    input:
        bam   = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/mapped_reads/{src}" + bam_suffix),
        fasta = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_" + assembly),
    run:
        import csv
        with open(output[0], 'wt') as out_file:
            writer = csv.writer(out_file, delimiter='\t')
            writer.writerow(['bam', 'fasta'])
            writer.writerow([input.bam, input.fasta])
"""
################################################################################
#
#    Filtering and statistics
#    Assembly
#
################################################################################

rule plot_assembly:
    """
    Render one report per assembly metric (N50, L50, lengths, contig counts)
    from tables/{assembly}_statistics.tsv, using ../scripts/plot_assembly.py.
    """
    input:
        stats = "tables/{assembly}_statistics.tsv",
    output:
        expand(
            os.path.join("reports", "{{assembly}}", "{meters}." + ext),
            meters=["N50", "L50", "total_length", "n_contigs", "n_longcontigs", "n_verylongcontigs"],
        )
    params:
        filtering = config["Assembly"]["filter_assembly"],
        out = "reports/{assembly}",
    conda:
        os.path.join(CONDAENV, "reports.yaml")
    script:
        "../scripts/plot_assembly.py"

def expand_assembly_stats(wildcards):
    """Return the per-source ``*_assembly_stats.tsv`` paths of ``wildcards.assembly``.

    For co-assembly, the cluster ids are only known once checkpoint
    ``cluster_simka`` has run. They are read from its output directory and
    cached in ``assembly_dict["co_assembly"]``.
    """
    strategy = wildcards.assembly
    if strategy == "co_assembly":
        clusters_dir = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(
            os.path.join(clusters_dir, "{clusterid}.txt")
        ).clusterid

    stats_pattern = os.path.join(
        intermediate_results_dir,
        "assembly/{{assembly}}",
        assembler,
        "{src}/contigs/{src}_assembly_stats.tsv",
    )
    return expand(stats_pattern, src=assembly_dict.get(strategy))

rule aggregate_assembly_statistics:
    """
    Concatenate the per-source assembly statistics of one assembly strategy
    into a single tab-separated table.
    """
    input:
        expand_assembly_stats,
    output:
        os.path.join("tables", "{assembly}_statistics.tsv"),
    run:
        import pandas as pd
        per_source = [
            pd.read_csv(path, sep="\t", header=0, index_col=0)
            for path in input
        ]
        pd.concat(per_source).to_csv(output[0], sep="\t")

rule assembly_statistics:
    """
    Compute the assembly statistics of one source with ../scripts/contigstats.py
    from its contig statistics: pre-filtering, plus post-filtering when
    assembly filtering is enabled.
    """
    input:
        contigstats = expand(
            os.path.join(
                intermediate_results_dir,
                "assembly/{{assembly}}",
                assembler,
                "{{src}}/contigs/contigstats_{request_filtering}_filtering.csv",
            ),
            request_filtering=filtering,
        ),
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_assembly_stats.tsv")
    params:
        unfiltered = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/contigstats_pre_filtering.csv"),
        filtered   = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/contigstats_post_filtering.csv") if "post" in filtering else None,
        src = "{src}"
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    script:
        "../scripts/contigstats.py"

def expand_contigs_stats(wildcards):
    """Return the per-source contig statistics CSVs of ``wildcards.assembly``.

    ``{request_filtering}`` stays a wildcard in the returned paths. For
    co-assembly, the cluster ids are first read from the output directory of
    checkpoint ``cluster_simka`` and cached in ``assembly_dict["co_assembly"]``.
    """
    strategy = wildcards.assembly
    if strategy == "co_assembly":
        clusters_dir = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(
            os.path.join(clusters_dir, "{clusterid}.txt")
        ).clusterid

    csv_pattern = os.path.join(
        intermediate_results_dir,
        "assembly/{{assembly}}",
        assembler,
        "{src}/contigs/contigstats_{{request_filtering}}_filtering.csv",
    )
    return expand(csv_pattern, src=assembly_dict.get(strategy))

rule aggregate_contigs_statistics:
    """
    Concatenate the per-source contig statistics of one assembly strategy and
    filtering stage, tagging each row with its strategy and source
    (sample or cluster id).
    """
    output:
        "tables/contigs_{assembly}_statistics_{request_filtering}_filtering.csv"
    input:
        expand_contigs_stats,
    run:
        import pandas as pd
        dfs = []
        for path in input:
            df = pd.read_csv(path, sep="\t", header=0, index_col=0)
            # Every input belongs to the requested strategy. The source is the
            # directory above "contigs/". Deriving both this way does not depend
            # on how many path components working_dir contributes (the previous
            # fixed split("/") indices only worked when working_dir was empty).
            df["strategy"] = wildcards.assembly
            df["src"] = os.path.basename(os.path.dirname(os.path.dirname(path)))
            dfs.append(df)
        pd.concat(dfs).to_csv(output[0], sep="\t")

rule filter_assembly:
    '''
    Filter assembled contigs by their coverage.

    Runs ../scripts/filterbycoverage.py on the pre-filtering contigs and their
    coverage statistics. Thresholds come from config[<assembly strategy>].
    '''
    input:
        assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_pre_filtering.contigs.fa"),
        covstats = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/contigstats_pre_filtering.csv"),
    output:
        filtered = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_post_filtering.contigs.fa"),
        failed   = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_failed_filtering.contigs.fa"),
    params:
        min_reads  = lambda wildcards: config[wildcards.assembly]["min_reads"],
        min_length = lambda wildcards: config[wildcards.assembly]["min_length"],
        min_breath = lambda wildcards: config[wildcards.assembly]["min_breath"],
        min_depth  = lambda wildcards: config[wildcards.assembly]["min_depth"],
        properly_paired = lambda wildcards: "-p" if config[wildcards.assembly]["properly_paired"] else "",
    log: "logs/{assembly}/filtering_contigs/{src}.log"
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    script:
        "../scripts/filterbycoverage.py"

rule coverage_statistics:
    '''
    Estimate contig coverage from the sorted BAM with ../scripts/metacovest.py,
    using the options set in config[<assembly strategy>].
    '''
    input:
        align = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/mapped_reads/{src}_{request_filtering}_filtering.sorted.bam"),
    output:
        stats = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/contigstats_{request_filtering}_filtering.csv"),
    params:
        only_covered = lambda wildcards: "--only-covered" if config[wildcards.assembly]["only_covered"] else "",
        min_baseq = lambda wildcards: config[wildcards.assembly]["min_baseq"],
        depth_per_base = lambda wildcards: config[wildcards.assembly]["depth_per_base"],
    log: "logs/{assembly}/{request_filtering}_statistics/{src}_{request_filtering}.log"
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    script:
        "../scripts/metacovest.py"


################################################################################
#
#    reads back mapped to their assembly
#
################################################################################

def input_cmd_assembly(wildcards):
    """Return the read input of an assembly (or read-mapping) job.

    * ``single_assembly``: every read file of every run of sample
      ``wildcards.src``, taken from ``reads2use`` (defined outside this file).
    * ``co_assembly``: the file listing the samples of cluster ``wildcards.src``,
      or ``tmp/samples.txt`` when sample clustering is disabled
      (``simka_type == "None"``).

    Raises:
        ValueError: if ``wildcards.assembly`` is neither strategy.
    """
    if wildcards.assembly == "single_assembly":
        runs = reads2use[wildcards.src]
        return [read for run in runs for read in runs[run]]
    if wildcards.assembly == "co_assembly":
        if simka_type == "None":
            return 'tmp/samples.txt'
        # Format explicitly instead of relying on Snakemake to expand {src}.
        return f"intermediate_results/assembly/co_assembly/clusters/{wildcards.src}.txt"
    raise ValueError(
        f"Unknown assembly strategy {wildcards.assembly!r}; "
        "expected 'single_assembly' or 'co_assembly'"
    )


rule sort_and_index:
    """
    Convert the SAM alignment to a coordinate-sorted, indexed BAM.
    """
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/mapped_reads/{src}_{request_filtering}_filtering.sorted.bam"),
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/mapped_reads/{src}_{request_filtering}_filtering.sorted.bam.bai"),
    input:
        os.path.join(tmp,"{assembly}/{src}_to_{src}_{request_filtering}_filtering.sam"),
    log: "logs/{assembly}/sort_and_index/{src}_{request_filtering}_sorting.log"
    threads: 10
    # NOTE(review): `priority` was declared twice (100, then 5). 100 is kept,
    # which ranks this rule above bowtie2_map (80).
    priority: 100
    conda:
        os.path.join(CONDAENV, "samtools.yaml")
    # Group the commands so view/sort messages also reach the log;
    # previously only `samtools index` was redirected.
    shell:
        "(samtools view -u {input} | "
        "samtools sort "
        "-@ {threads} "             # number of threads used
        "-o {output[0]} "
        "&& samtools index "
        "{output[0]}) "
        "&> {log} "

rule bowtie2_map:
    """
    Align reads against their assembly, using bowtie2.

    Index files and the index prefix are built from intermediate_results_dir
    and `assembler`, exactly like the outputs of rule bowtie2_index, so the two
    rules agree whatever the working directory or assembler name.
    """
    output:
        temp(os.path.join(tmp,"{assembly}/{src}_to_{src}_{request_filtering}_filtering.sam"))
    input:
        index1 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src}}/index/{{src}}_{{request_filtering}}_filtering.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1, 4)),
        index2 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src}}/index/{{src}}_{{request_filtering}}_filtering.rev.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1, 2)),
        reads  = input_cmd_assembly,
    threads: 10
    priority: 80
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    params:
        prefix = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/index/{src}_{request_filtering}_filtering"),
        input_reads = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
        cmd = lambda wildcards : conf.mapping_cmd(config, wildcards.assembly),
    shell:
        "bowtie2 "
        "-p {threads} "             # number of parallel threads
        "--no-unal "                # remove unmapped reads (decrease size)
        "-x {params.prefix} "       # index for mapping
        "{params.input_reads} "
        "{params.cmd} "
        "-S {output} "

# Index file extension: .bt2l when large_index is True in the config file
# (needed for big datasets; bowtie2-build is then run with --large-index),
# .bt2 otherwise.
BOWTIE2_INDEX_EXT = ".bt2l" if config["large_index"] else ".bt2"

rule bowtie2_index:
    """
    Build a Bowtie2 index from the assembled contigs of one source and
    filtering stage.
    """
    output:
        expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src}}/index/{{src}}_{{request_filtering}}_filtering.{id}" + BOWTIE2_INDEX_EXT), id=range(1, 4)),
        expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src}}/index/{{src}}_{{request_filtering}}_filtering.rev.{id}" + BOWTIE2_INDEX_EXT), id=range(1, 2)),
    input:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_{request_filtering}_filtering.contigs.fa"),
    log:
        "logs/index_{assembly}/{src}_{request_filtering}_filter_indexing.log"
    threads: 10
    conda:
        os.path.join(CONDAENV, "bowtie2.yaml")
    params:
        prefix = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/index/{src}_{request_filtering}_filtering"),
        large_index = "--large-index" if config["large_index"] else "",
    priority: 5
    shell:
        "bowtie2-build "
        "{params.large_index} "         # empty unless large_index is set
        "--threads {threads} "          # number of parallel threads
        "{input} "                      # assembled contigs (FASTA)
        "{params.prefix} "              # basename of the index
        "&> {log}"

################################################################################
#
#    Assembly
#
################################################################################

rule megahit:
    """
    Assemble reads into contigs, using Megahit.

    Contig headers are then prefixed with the sample/cluster id, and Megahit's
    intermediate contigs directory is deleted.
    """
    output:
        os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/{src}_pre_filtering.contigs.fa"),
    input:
        input_cmd_assembly,
    log: "logs/{assembly}/{src}_assembly.log"
    benchmark: "benchmark/megahit/{assembly}/{src}.megahit.performances.txt"
    params:
        # Built with `assembler`, like the output path, so Megahit writes where
        # Snakemake expects the output.
        output_dir = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs"),
        cmd = lambda wildcards, input: cmdparser.cmd(wildcards.src, input, reads2use, "megahit").cmd,
        intermediate_contigs = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src}/contigs/intermediate_contigs"),
    conda:
        os.path.join(CONDAENV, "megahit.yaml")
    threads: 10
    priority: 5
    # All commands are grouped so Megahit's own output also reaches the log;
    # previously only the perl step was redirected.
    shell:
        '(megahit -f '  # -f avoids an error because the output dir already exists
        '-t {threads} '
        '{params.cmd} '
        '--presets meta-large '
        '-o {params.output_dir} '
        '--out-prefix {wildcards.src}_pre_filtering && '
        'perl -i -p -e "s/>/>{wildcards.src}_/" {output[0]} && '  # perl: more portable than sed -i
        'rm -rf {params.intermediate_contigs}) '
        '&> {log}'

################################################################################
#
#    Sample clustering using simka and hierarchical clustering (CAH/HCA)
#    https://github.com/GATB/simka
#
################################################################################

rule target_simka:
    """
    Touch tmp/finished_sample_clustering once sample clustering is done
    (or skipped, when simka_type is 'None').
    """
    input:
        input_of_target_simka
    output:
        "tmp/finished_sample_clustering"
    shell:
        "touch {output}"


checkpoint cluster_simka:
    """
    Group samples into clusters with ../scripts/HCA.py, starting from the
    simka/simkaMin Bray-Curtis matrix.

    Only the method implemented in HCA.py (hierarchical clustering) is
    available. Choosing another one would need a new config parameter and an
    update of the script. `height` is passed to the script; presumably it is the
    dendrogram cut height (confirm in HCA.py).
    """
    input:
        input_of_checkpoint_simka,
    output:
        input_of_target_simka
    params:
        height = config["Simka"]["height"]
    log: "logs/co_assembly/clustering.log"
    priority: 5
    conda:
        os.path.join(CONDAENV, "sampleclustering.yaml")
    script:
        "../scripts/HCA.py"

rule simkaMin:
    '''
    Compute the Bray-Curtis and Jaccard sample matrices with simkaMin.

    simka and simkaMin do not accept the same options, hence two separate
    rules (see the simka documentation: https://github.com/GATB/simka).
    '''
    input:
        "tmp/simkaMin_input.txt",
    output:
        "intermediate_results/assembly/co_assembly/simkaMin/mat_abundance_braycurtis.csv.gz",
        "intermediate_results/assembly/co_assembly/simkaMin/mat_presenceAbsence_jaccard.csv.gz",
    params:
        outdir = "intermediate_results/assembly/co_assembly/simkaMin",
    log:
        "logs/co_assembly/simkaMin.log"
    threads: 20
    conda:
        os.path.join(CONDAENV, "simka.yaml")
    shell:
        "simkaMin.py "
        "-bin $(which simkaMinCore) "
        "-in {input} "
        "-out {params.outdir} "
        "-filter "
        "-nb-cores {threads} "
        "&>{log};"

rule simka:
    '''
    Compute the Bray-Curtis sample matrix with simka
    (see the simka documentation: https://github.com/GATB/simka).
    '''
    input:
        "tmp/simka_input.txt",
    output:
        "intermediate_results/assembly/co_assembly/simka/mat_abundance_braycurtis.csv.gz",
    params:
        outdir = "intermediate_results/assembly/co_assembly/simka",
        tmpdir = os.path.join(tmp, "simka"),
    log:
        "logs/co_assembly/simka.log"
    threads: 20
    conda:
        os.path.join(CONDAENV, "simka.yaml")
    shell:
        "$(which simka) "
        "-in {input} "
        "-out {params.outdir} "
        "-out-tmp {params.tmpdir} "
        "-simple-dist "
        "-nb-cores {threads} "
        "-abundance-min 2 "
        "-abundance-max 200 &>{log};"

def input_for_simka_input(reads2use=reads2use):
    """Return a flat list of every read file of every run of every sample.

    ``reads2use`` is indexed as ``reads2use[sample][run]`` and yields read
    files. It defaults to the global of the same name, which is defined
    outside this file.
    """
    return [
        read_file
        for sample in reads2use
        for run_id in reads2use[sample]
        for read_file in reads2use[sample][run_id]
    ]

rule simka_input:
    '''
    Write the sample listing read by simka/simkaMin: one line per sample run,
    "<sample> <run>: <reads>". Paired files are separated by " ; ".
    '''
    input:
        reads = input_for_simka_input(reads2use)
    output:
        output_of_rule_simka_input,
    params:
        reads = reads2use,
        simka_type = simka_type
    priority: 5
    run:
        with open(str(output), "w") as listing:
            for sample in params.reads:
                runs = params.reads[sample]
                for run in runs:
                    read_files = runs[run]
                    if len(list(read_files)) != 1:
                        first = os.path.abspath(read_files[0])
                        second = os.path.abspath(read_files[1])
                        listing.write(f"{sample} {run}: {first} ; {second}\n")
                    else:
                        single = os.path.abspath(read_files[0])
                        listing.write(f"{sample} {run}: {single}\n")
