#!/usr/bin/env python3
# -*- coding: utf-8 -*-

####################################################################################################
#
#    Declaration
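#        mags : directory where MAGs are stored.
#        assembly : file name suffix of the contigs FASTA used for binning (pre- or post-filtering).
#        bam_suffix : file name suffix of the sorted BAM files (pre- or post-filtering).
#        co_assembly : assembly mode requested by the CASB and CACB binning strategies.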
##
#
####################################################################################################

# Relative path to the conda environment definitions referenced by the rules below.
CONDAENV = "../envs" # Also defined in the main Snakefile; repeated here to improve readability
working_dir = config["project"]
tmpdir = os.path.join(working_dir, "tmp") # Also defined in the main Snakefile; repeated here to improve readability
logdir = os.path.join(working_dir, "logs")

# Store all pre-processing reports
reports_dir = os.path.join(working_dir, "reports/pre_processing")
# Store results from the different programs
intermediate_results_dir = os.path.join(working_dir, "intermediate_results")

# Load the sample description file named in the configuration.
# NOTE(review): yaml.safe_load would be the more conservative choice if the
# samples file can come from an untrusted source -- confirm before changing.
with open(config['samples'], 'r') as samplefile:
    samples = yaml.load(samplefile, Loader=yaml.FullLoader)

# Pick the contigs file suffix, the BAM suffix, the list of filtering stages
# and the index label according to whether assembly filtering is enabled.
if config["Assembly"]["filter_assembly"]:
    assembly =  "post_filtering.contigs.fa"
    bam_suffix = "_post_filtering.sorted.bam"
    filtering = ["pre","post"]
    index = "post"
else:
    assembly = "pre_filtering.contigs.fa"
    bam_suffix = "_pre_filtering.sorted.bam"
    filtering = ["pre"]
    index = "pre"

# Name of the assembler; used as a directory component of the assembly paths.
assembler = config["Assembly"]["Assembler"]

# First entry of Binning.filter_mags; compared against "das_tool_mags" and
# "all_mags" further down to decide which bins are collected.
bins_collection = config["Binning"]["filter_mags"][0]

binning_request = []
assembly_request = []
strategies = config["binning_strategies"]

# Derive the assembly modes and the binning strategies that were requested.
if 'SASB' in strategies or 'SACB' in strategies:
    assembly_request.append('single_assembly')
if 'CASB' in strategies or 'CACB' in strategies:
    assembly_request.append('co_assembly')
for s in ['SASB', 'SACB', 'CASB', 'CACB']:
    if s in strategies:
        binning_request.append(s)
# NOTE(review): assembly_dict is read and written by the input functions below
# but is only "defined" in the commented-out lines that follow; presumably it
# is created by another included Snakefile -- confirm.
#assembly_dict={}
#if "single_assembly" in assembly_request :
#    assembly_dict = {"single_assembly": samples.keys()}
#reads2use, samples, deconta = conf.parse(config)

# Sources (sample ids or cluster ids) to bin for each strategy; filled in
# lazily by the input functions below.
binning_dict={"SASB":[],
                "SACB":[],
                "CASB":[],
                "CACB":[],
}
# NOTE(review): these tests compare the raw config value for equality with a
# single strategy name, whereas the code above uses membership tests
# (`in strategies`). If binning_strategies is a list, neither condition can
# hold -- confirm the expected config type. sort_and_index_binning_CB is not
# defined in this file; presumably it lives in another included Snakefile.
if config["binning_strategies"] == "SASB" or config["binning_strategies"] == "CASB" :
    ruleorder: sort_and_index_binning_SB > concoct_cut_up
if config["binning_strategies"] == "SACB" or config["binning_strategies"] == "CACB" :
    ruleorder: sort_and_index_binning_CB > concoct_cut_up
####################################################################################################
#
#    target(s) definition
#
####################################################################################################

def aggregate_final_output(wildcards):
    '''
        Return the "done" flag files of every binner for every requested
        binning strategy and source.

        Side effects: refreshes assembly_dict and binning_dict with the
        sources (cluster ids for co-assembly strategies, sample ids for
        single-assembly strategies) that belong to the requested strategies.

        wildcards: wildcards of the calling rule, forwarded to the
        cluster_simka checkpoint lookup.
    '''
    if "CASB" in binning_request or "CACB" in binning_request:
        # Co-assembly sources are the "<clusterid>.txt" files found in the
        # output directory of the cluster_simka checkpoint. Scan it once and
        # hand an independent copy to every consumer.
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        cluster_ids = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
        assembly_dict["co_assembly"] = list(cluster_ids)
        binning_dict["CASB"] = list(cluster_ids)
        binning_dict["CACB"] = list(cluster_ids)
    if "SASB" in binning_request or "SACB" in binning_request:
        # Single-assembly sources are the sample ids.
        assembly_dict["single_assembly"] = list(samples.keys())
        binning_dict["SASB"] = list(samples.keys())
        binning_dict["SACB"] = list(samples.keys())

    # Flag-file prefixes, in the order the flags are emitted for each strategy.
    binners = ["metabat", "concoct", "semibin2"]
    if bins_collection == "das_tool_mags":
        binners.append("DAS_tool")

    inputs = []
    for request in binning_request:
        for binner in binners:
            flag_pattern = os.path.join(tmpdir, "binning/" + binner + "_{binning_strategy}.{src}.done")
            inputs.extend(expand(flag_pattern, binning_strategy=request, src=binning_dict.get(request)))
    return inputs

def aggregate_bin_collection(wildcards):
    '''
        List one FASTA path per bin found in the directory produced by the
        `binning` checkpoint.

        The bin ids are discovered by globbing "<binsid>.fa" in the checkpoint
        output and are re-expanded under
        <intermediate_results_dir>/binning/<bins_collection>/.
    '''
    bins_dir = checkpoints.binning.get(**wildcards).output[0]
    found_bins = glob_wildcards(os.path.join(bins_dir, "{binsid}.fa")).binsid
    bin_pattern = os.path.join(intermediate_results_dir, "binning", bins_collection, "{binsid}.fa")
    return expand(bin_pattern, binsid=found_bins)


rule target_pooling_MAGs:
    '''
        Final target of the binning step: once every collected bin and the
        collection flag exist, write the "binning finished" marker and remove
        the binning scratch directory.
    '''
    output:
        os.path.join(tmpdir, "binning_finished.checkpoint")
    input:
        aggregate_bin_collection,
        os.path.join(tmpdir, f"binning/{bins_collection}.done"),
    params:
        # Scratch directory the binning rules write to; derived from tmpdir
        # (like every other path in this file) instead of a hard-coded
        # relative "tmp/binning".
        binning_tmpdir = os.path.join(tmpdir, "binning"),
    shell:
        "touch {output} && "
        "rm -r {params.binning_tmpdir} "

checkpoint binning:
    '''
        Publish the collected bins: rename the temporary collection directory
        (tmp_<bins_collection>) to its final name once the collection flag
        file exists.
    '''
    input:
        os.path.join(tmpdir, "binning", bins_collection + ".done"),
    output:
        directory(os.path.join(intermediate_results_dir, "binning", bins_collection))
    params:
        bins = os.path.join(intermediate_results_dir, "binning", "tmp_" + bins_collection)
    shell:
        "mv {params.bins} {output}"

rule bin_filtered:
    '''
        Write the "all_mags_filtered" flag file once the directory of
        filtered bins is available.
    '''
    input:
        os.path.join(intermediate_results_dir, "binning/tmp_all_mags_filtered/"),
    output:
        os.path.join(tmpdir, "binning/all_mags_filtered.done"),
    shell:
        "touch {output}"

# Filter the pooled bins with the external filtermags.py script once all bins
# have been gathered and the contig lineage table is available.
rule bin_filtering:
    output:
        directory(os.path.join(intermediate_results_dir, "binning/tmp_all_mags_filtered")),
    input:
        os.path.join(tmpdir, "binning/tmp_all_mags.done"),
        os.path.join(working_dir, "data/assembly/all_contigs_lineages.txt"),
    conda:
        os.path.join(CONDAENV, "bamutils.yaml")
    params:
        bins  = os.path.join(intermediate_results_dir, "binning/tmp_all_mags/"),
        stats = os.path.join(intermediate_results_dir, "binning/data"),
        # NOTE(review): this points to "all_mags_filtered" while the declared
        # output directory is "tmp_all_mags_filtered" -- confirm which of the
        # two the script actually writes to.
        out   = os.path.join(intermediate_results_dir, "binning/all_mags_filtered")
    script:
        "../scripts/filtermags.py"

rule aggregate_bins:
    '''
        Write the (temporary) collection flag file once every binner flag
        returned by aggregate_final_output exists.
    '''
    input:
        aggregate_final_output
    output:
        temp(os.path.join(tmpdir, "binning", bins_collection + ".done")),
    shell:
        "touch {output} "

def resolve_assembly(wildcards):
    '''
        Return the contigs FASTA that matches the binning strategy and source
        of the calling rule.

        SASB/SACB resolve to the single-assembly tree and CASB/CACB to the
        co-assembly tree; the file name is "<src>_<assembly>".

        Raises FileNotFoundError when the binning strategy is not one of the
        four known strategies.
    '''
    assembly_kinds = {
        "SASB": "single_assembly",
        "SACB": "single_assembly",
        "CASB": "co_assembly",
        "CACB": "co_assembly",
    }
    assembly_kind = assembly_kinds.get(wildcards.binning_strategy)
    if assembly_kind is None:
        # Same exception type as before, now with a message naming the culprit.
        raise FileNotFoundError(
            f"no assembly is defined for binning strategy '{wildcards.binning_strategy}'"
        )
    return os.path.join(intermediate_results_dir, "assembly/" + assembly_kind, assembler, wildcards.src, "contigs", wildcards.src + "_" + assembly)

if bins_collection == "das_tool_mags":
    rule DAS_tool:
        '''
            DAS Tool to calculate an optimized, non-redundant set of bins
            from the concoct, metabat and semibin2 contigs2bin tables, then
            move the selected bins into the shared tmp_das_tool_mags directory.
        '''
        output:
            os.path.join(tmpdir, "binning/DAS_tool_{binning_strategy}.{src}.done"),
        input:
            concoct_tsv = os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/concoct_contigs2bin.tsv"),
            metabat_tsv = os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/metabat_contigs2bin.tsv"),
            semibin2_tsv = os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/semibin2_contigs2bin.tsv"),
            fasta = resolve_assembly,
        threads: 5
        conda:
            os.path.join(CONDAENV, "das_tool.yaml")
        params:
            # Output prefix handed to DAS_Tool; the bins end up in "<prefix>_DASTool_bins".
            out_dir = os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/{src}_{binning_strategy}"),
            das_tool_aggregate_bins_dir = os.path.join(intermediate_results_dir, "binning/tmp_das_tool_mags")
        shell:
            "DAS_Tool -i {input.concoct_tsv},{input.metabat_tsv},{input.semibin2_tsv} "
            "-l concoct,metabat,semibin2 "
            "-c {input.fasta} "
            "-o {params.out_dir} "
            "--threads {threads} "      # use the threads reserved by this rule
            "--write_bin_evals --write_bins && "
            "mkdir -p {params.das_tool_aggregate_bins_dir} && "
            "mv {params.out_dir}_DASTool_bins/* {params.das_tool_aggregate_bins_dir} && "
            "touch {output} "

    rule DAS_tool_metabat_contig2bins:
        '''
            Build the tab separated contigs2bin table from the metabat bin
            FASTA files, then move those bins into the shared tmp_all_mags
            directory.
        '''
        input:
            os.path.join(tmpdir, "binning/metabat_{binning_strategy}.{src}.done"),
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/metabat_contigs2bin.tsv"),
        params:
            script = os.path.join(working_dir, "scripts/Fasta_to_Contig2Bin.sh"),
            bin_dir = os.path.join(tmpdir, "binning/{binning_strategy}/metabat/{src}"),
            output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
        threads: 5
        shell:
            "{params.script} -i {params.bin_dir} -e fa > {output} "
            "&& mkdir -p {params.output_tmp_mags} "
            "&& mv {params.bin_dir}/* {params.output_tmp_mags} "

    def _semibin2_bins_dir(wildcards):
        '''
            Locate the semibin2 output_bins directory of a source: per-source
            output for single binning (SASB/CASB), per-sample sub-directory of
            the shared output otherwise.
        '''
        strategy_dir = os.path.join(tmpdir, f"binning/{wildcards.binning_strategy}/semibin2")
        if wildcards.binning_strategy in ["SASB", "CASB"]:
            return os.path.join(strategy_dir, f"{wildcards.src}/semibin2_output/output_bins")
        return os.path.join(strategy_dir, f"semibin2_output/samples/{wildcards.src}_post_filtering/output_bins")

    rule DAS_tool_semibin_contig2bins:
        '''
            Build the tab separated contigs2bin table from the semibin2 bin
            FASTA files, then move those bins into the shared tmp_all_mags
            directory.
        '''
        input:
            os.path.join(tmpdir, "binning/semibin2_{binning_strategy}.{src}.done"),
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/semibin2_contigs2bin.tsv"),
        params:
            script = os.path.join(working_dir, "scripts/Fasta_to_Contig2Bin.sh"),
            bin_dir = _semibin2_bins_dir,
            output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
        threads: 5
        shell:
            "{params.script} -i {params.bin_dir} -e fa > {output} "
            "&& mkdir -p {params.output_tmp_mags} "
            "&& mv {params.bin_dir}/* {params.output_tmp_mags} "

    rule DAS_tool_concoct_contig2bins:
        '''
            Converting CONCOCT csv output into tab separated contigs2bin file
        '''
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/DAS_tool/{src}/concoct_contigs2bin.tsv"),
        input:
            os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/clustering_merged.csv"),
        threads: 5
        # perl replaces every comma with a tab followed by the bin name prefix
        # "bin_concoct_<strategy>_<src>." (the "\t" below is turned into a
        # literal tab by Python before the command reaches the shell); sed then
        # drops the first line of the result in place.
        shell:
            "perl -pe 's/,/\tbin_concoct_{wildcards.binning_strategy}_{wildcards.src}./g;' {input} > {output} && "
            "sed -i '1d' {output} "


rule metabat:
    '''
        Performs binning using an assembly fasta files and the depth file
        associated.
    '''
    output:
        temp(os.path.join(tmpdir, "binning/metabat_{binning_strategy}.{src}.done")),
    input:
        depth = os.path.join(intermediate_results_dir, "binning/{binning_strategy}/{src}/depth.txt"),
        fasta = resolve_assembly,
    conda:
        os.path.join(CONDAENV, "metabat.yaml")
    threads: 5
    benchmark:
        "benchmark/binning/{binning_strategy}/metabat_{src}_benchmark.txt"
    # NOTE(review): this log file is declared but the shell command below never
    # writes to {log}.
    log:
        "logs/binning/{binning_strategy}/metabat_{src}.log"
    params:
        # Bins are written as "<output_prefix>.<n>.fa" -- presumably; confirm
        # against the metabat2 documentation.
        output_prefix = os.path.join(tmpdir, "binning/{binning_strategy}/metabat/{src}/bin_metabat_{binning_strategy}_{src}"),
        output_dir = os.path.join(tmpdir, "binning/{binning_strategy}/metabat/{src}"),
        output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
        min_contigs_length = 1500,
        collection = bins_collection
    # The bins are moved to the shared tmp_all_mags directory only when the
    # requested collection is "all_mags"; otherwise they stay in output_dir.
    shell:
        """
        mkdir -p {params.output_dir}
        metabat2 -i {input.fasta} -m {params.min_contigs_length} -a {input.depth} -o {params.output_prefix} -t {threads} --noAdd
        if [[ '{params.collection}' == 'all_mags' ]]; then 
            mkdir -p {params.output_tmp_mags} && mv {params.output_dir}/* {params.output_tmp_mags} 
        fi
        touch {output} 
        """

rule concoct_extract_fasta_bins:
    '''
        Extract one FASTA file per CONCOCT bin from the merged clustering,
        prefix every file with "bin_concoct_<strategy>_<src>." and move the
        bins into the shared tmp_all_mags directory.
    '''
    output:
        temp(os.path.join(tmpdir, "binning/concoct_{binning_strategy}.{src}.done")),
    input:
        fasta = resolve_assembly,
        clustering_merged = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/clustering_merged.csv"),
    conda:
        os.path.join(CONDAENV, "concoct.yaml")
    threads: 5
    params:
        concoct_fasta_bin_outdir = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/fasta_bins"),
        output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
    shell:
        'mkdir -p {params.concoct_fasta_bin_outdir} && '
        'extract_fasta_bins.py {input.fasta} {input.clustering_merged} --output_path {params.concoct_fasta_bin_outdir} && '
        # Rename in place so bins from different sources cannot collide once pooled.
        'for file in {params.concoct_fasta_bin_outdir}/*; do '
        '    mv "$file" "{params.concoct_fasta_bin_outdir}/bin_concoct_{wildcards.binning_strategy}_{wildcards.src}.$(basename "$file")"; '
        'done && '
        'mkdir -p {params.output_tmp_mags} && '
        'mv {params.concoct_fasta_bin_outdir}/* {params.output_tmp_mags} && '
        'touch {output}'

rule concoct_cutup_clustering:
    '''
        Fold the clustering computed on the cut-up sub-contigs back onto the
        original contigs.
    '''
    input:
        clustering = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/clustering_gt.csv"),
    output:
        clustering_merged = temp(os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/clustering_merged.csv")),
    threads: 5
    conda:
        os.path.join(CONDAENV, "concoct.yaml")
    shell:
        "merge_cutup_clustering.py {input.clustering} "
        "> {output.clustering_merged} "

rule concoct:
    '''
        Cluster the cut-up contigs with CONCOCT from their composition and
        coverage, then give the clustering file a threshold-independent name.
    '''
    input:
        cut_contigs = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/cut_contigs_{src}.fa"),
        coverage = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/coverage_table_{src}.tsv"),
    output:
        clustering = temp(os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/clustering_gt.csv")),
    params:
        length_threshold = config["Binning"]["concoct_length_threshold"],
        # Trailing slash matters: the value is used as a file name prefix below.
        concoct_outdir = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/concoct_output/"),
    threads: 5
    conda:
        os.path.join(CONDAENV, "concoct.yaml")
    shell:
        "concoct "
        "--composition_file {input.cut_contigs} "
        "--coverage_file {input.coverage} "
        "-l {params.length_threshold} "
        "--threads {threads} "
        "-b {params.concoct_outdir} "
        "&& mv {params.concoct_outdir}clustering_gt*.csv {params.concoct_outdir}clustering_gt.csv "

def concoct_bam_input(wildcards):
    '''
        Return the sorted BAM files, followed by their .bai indexes, needed to
        build the CONCOCT coverage table of one source.

        The assembly tree (single_assembly / co_assembly) is chosen from the
        binning strategy of the requesting job, so the paths stay correct when
        single- and co-assembly strategies are requested in the same run.
        Single binning (SASB/CASB) uses the source mapped onto itself;
        co-binning (SACB/CACB) uses every known source mapped onto this one.

        Side effect: refreshes assembly_dict for the requested strategies.
        Returns None for an unknown binning strategy.
    '''
    if "CASB" in strategies or "CACB" in strategies:
        # Co-assembly sources come from the cluster_simka checkpoint output.
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
    if "SASB" in strategies or "SACB" in strategies:
        assembly_dict["single_assembly"] = list(samples.keys())

    strategy = wildcards.binning_strategy
    if strategy in ("SASB", "SACB"):
        assembly_kind = "single_assembly"
    elif strategy in ("CASB", "CACB"):
        assembly_kind = "co_assembly"
    else:
        return None

    if strategy in ("SASB", "CASB"):
        # Single binning: only the reads of the source itself.
        bam_name = wildcards.src + "_to_" + wildcards.src + ".sorted.bam"
    else:
        # Co-binning: "{src}" is expanded over every source of this assembly kind.
        bam_name = "{src}_to_" + wildcards.src + ".sorted.bam"

    bam_pattern = os.path.join(intermediate_results_dir,
                               "assembly",
                               assembly_kind,
                               assembler,
                               wildcards.src,
                               "compute_depth",
                               bam_name)
    sources = assembly_dict.get(assembly_kind)
    return expand(bam_pattern, src=sources) + expand(bam_pattern + ".bai", src=sources)

def path_to_bam(wildcards):
    '''
        Return the BAM path (single binning) or the "*.sorted.bam" shell glob
        (co-binning) passed on the command line of concoct_coverage_table.py.

        The directory is built with the configured assembler, like the input
        functions of this file, instead of a hard-coded "megahit".

        Returns None when the binning strategy is unknown or when its family
        (single- or co-assembly) is not among the requested strategies.
    '''
    strategy = wildcards.binning_strategy
    if strategy in ("CASB", "CACB"):
        assembly_kind = "co_assembly"
        family_requested = "CASB" in strategies or "CACB" in strategies
    elif strategy in ("SASB", "SACB"):
        assembly_kind = "single_assembly"
        family_requested = "SASB" in strategies or "SACB" in strategies
    else:
        return None
    if not family_requested:
        return None

    if strategy in ("SASB", "CASB"):
        bam_name = wildcards.src + "_to_" + wildcards.src + ".sorted.bam"
    else:
        # Left unexpanded on purpose: the shell expands the glob at run time.
        bam_name = "*.sorted.bam"
    return os.path.join(intermediate_results_dir, "assembly", assembly_kind, assembler, wildcards.src, "compute_depth", bam_name)

rule concoct_coverage_table:
    '''
        Build the per-sample, per-sub-contig coverage depth table from the
        BED file of the cut-up contigs and the mapped reads.
    '''
    input:
        bed = os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/contigs_{src}.bed"),
        bams = concoct_bam_input,
    output:
        coverage = temp(os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/coverage_table_{src}.tsv")),
    params:
        bam_path = path_to_bam
    threads: 5
    conda:
        os.path.join(CONDAENV, "concoct.yaml")
    shell:
        "concoct_coverage_table.py {input.bed} {params.bam_path} "
        "> {output.coverage} "

rule concoct_cut_up:
    '''
        Split the contigs into fixed-size chunks and record the chunk
        coordinates in a BED file.
    '''
    input:
        fasta = resolve_assembly,
    output:
        bed_output = temp(os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/contigs_{src}.bed")),
        cut_contigs_output = temp(os.path.join(tmpdir, "binning/{binning_strategy}/CONCOCT/{src}/cut_contigs_{src}.fa")),
    params:
        chunk_size = 10000,
        overlap_size = 0,
    threads: 5
    conda:
        os.path.join(CONDAENV, "concoct.yaml")
    shell:
        "cut_up_fasta.py {input.fasta} "
        "-c {params.chunk_size} "
        "-o {params.overlap_size} "
        "--merge_last "
        "-b {output.bed_output} "
        "> {output.cut_contigs_output} "

def aggregate_bam_input(wildcards):
    '''
        Return the sorted BAM files used to compute the contig depth of one
        source.

        The assembly tree (single_assembly / co_assembly) is chosen from the
        binning strategy of the requesting job, so the paths stay correct when
        single- and co-assembly strategies are requested in the same run.
        Single binning (SASB/CASB) uses the source mapped onto itself;
        co-binning (SACB/CACB) uses every known source mapped onto this one.

        Side effect: refreshes assembly_dict for the requested strategies.
        Returns None for an unknown binning strategy.
    '''
    if "CASB" in strategies or "CACB" in strategies:
        # Co-assembly sources come from the cluster_simka checkpoint output.
        checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
        assembly_dict["co_assembly"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
    if "SASB" in strategies or "SACB" in strategies:
        assembly_dict["single_assembly"] = list(samples.keys())

    strategy = wildcards.binning_strategy
    if strategy in ("SASB", "SACB"):
        assembly_kind = "single_assembly"
    elif strategy in ("CASB", "CACB"):
        assembly_kind = "co_assembly"
    else:
        return None

    if strategy in ("SASB", "CASB"):
        # Single binning: only the reads of the source itself.
        bam_name = wildcards.src + "_to_" + wildcards.src + ".sorted.bam"
    else:
        # Co-binning: "{src}" is expanded over every source of this assembly kind.
        bam_name = "{src}_to_" + wildcards.src + ".sorted.bam"

    bam_pattern = os.path.join(intermediate_results_dir,
                               "assembly",
                               assembly_kind,
                               assembler,
                               wildcards.src,
                               "compute_depth",
                               bam_name)
    return expand(bam_pattern, src=assembly_dict.get(assembly_kind))

rule summarize_contig_depth:
    '''
        Summarize the read coverage depth of every contig from the sorted BAM
        files; the resulting table feeds the metabat rule.
    '''
    input:
        aggregate_bam_input,
    output:
        os.path.join(intermediate_results_dir, "binning/{binning_strategy}/{src}/depth.txt"),
    params:
        bams = aggregate_bam_input,
    threads: 5
    priority: 4
    conda:
        os.path.join(CONDAENV, "metabat.yaml")
    shell:
        "jgi_summarize_bam_contig_depths "
        "--outputDepth {output} "
        "{params.bams}"

if "SASB" in binning_request or "CASB" in binning_request:
    # Single-binning SemiBin2 run (SASB/CASB): bin the contigs of one source
    # with its own mapping, decompress and rename the bins, and pool them when
    # the requested collection is "all_mags".
    rule semibin2_SB:
        wildcard_constraints:
            binning_strategy="SASB|CASB"
        output:
            os.path.join(tmpdir, "binning/semibin2_{binning_strategy}.{src}.done"),
        input:
            mapping = aggregate_bam_input,
            contigs = resolve_assembly,
        params:
            semibin_fasta_bin_outdir = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/{src}/semibin2_output"),
            # Same function as input.mapping; the shell command reads the BAM
            # list from params.
            mapping = aggregate_bam_input,
            output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
            collection = bins_collection
        conda:
            os.path.join(CONDAENV, "semibin2.yaml")
        # In the loop below every "*.gz" bin is gunzipped, the "SemiBin_" part
        # of its name is stripped and the file is renamed to
        # "bin_semibin_<strategy>_<src>.<rest>". Doubled braces are literal
        # braces for the shell.
        shell:
            """
            SemiBin2 single_easy_bin --environment global -i {input.contigs} -b {params.mapping} -o {params.semibin_fasta_bin_outdir}
            for file in {params.semibin_fasta_bin_outdir}/output_bins/*.gz; do 
                gunzip "$file" && 
                base_name=$(basename "${{file%.gz}}" | sed 's/SemiBin_//')
                mv "${{file%.gz}}" "{params.semibin_fasta_bin_outdir}/output_bins/bin_semibin_{wildcards.binning_strategy}_{wildcards.src}.${{base_name}}"; 
            done
            if [[ '{params.collection}' == 'all_mags' ]]; then 
                mkdir -p {params.output_tmp_mags} && mv {params.semibin_fasta_bin_outdir}/output_bins/* {params.output_tmp_mags} 
            fi
            touch {output}
            """

    if "SASB" in binning_request and "SACB" in binning_request or "CASB" in binning_request and "CACB" in binning_request:
        print("No need of single binning mapping rules because files are already created with co-binning")
    else :
        rule sort_and_index_binning_SB:
            '''
                Convert the filtered SAM of a source mapped onto its own
                assembly into a sorted BAM and index it.
            '''
            wildcard_constraints:
                binning_strategy="SASB|CASB"
            input:
                os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src1}.filtered.sam"),
            output:
                temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/compute_depth/{src1}_to_{src1}.sorted.bam")),
                temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/compute_depth/{src1}_to_{src1}.sorted.bam.bai")),
            priority: 3
            threads: 1
            conda:
                os.path.join(CONDAENV, "samtools.yaml")
            shell:
                "samtools view -u {input} "
                "| samtools sort -@ {threads} -o {output[0]} "    # -@ : number of threads used
                "&& samtools index {output[0]} "

        def input_cmd_SB(wildcards):
            '''
                Return the reads to map for single binning.

                single_assembly: every read file of every run of the sample.
                co_assembly: the samples list when simka clustering is
                disabled, otherwise the cluster file of the source.

                Raises ValueError for any other assembly mode.
            '''
            if wildcards.assembly == "single_assembly":
                runs = reads2use[wildcards.src1]
                list_reads = []
                for run in runs:
                    list_reads.extend(runs[run])
                return list_reads
            if wildcards.assembly == "co_assembly":
                if simka_type == "None":
                    return os.path.join(tmpdir, 'samples.txt')
                # The source id is substituted here so that the returned path
                # is concrete instead of carrying a "{src1}" placeholder.
                return os.path.join(intermediate_results_dir, "assembly/co_assembly/clusters", wildcards.src1 + ".txt")
            raise ValueError(f"unknown assembly mode '{wildcards.assembly}'")

        rule filter_bam_SB:
            """
            Filter reads based on mapping quality and identity.
            Output is temporary because it will be sorted.
            """
            # NOTE(review): no path of this rule contains a binning_strategy
            # wildcard, so this constraint looks inert -- confirm.
            wildcard_constraints:
                binning_strategy="SASB|CASB"
            output:
                temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src1}.filtered.sam")),
            input:
                os.path.join(tmpdir, "binning", "{assembly,.*_assembly}/{src1}_to_{src1}_" + f"{index}_filtering.sam")
            conda:
                os.path.join(CONDAENV, "bamutils.yaml")
            priority: 2
            # Filtering thresholds are read from the
            # "bam_filtering_before_binning" section of the configuration and
            # handed to the external bamprocess.py script.
            params:
                min_mapq = config["bam_filtering_before_binning"]["min_quality"],
                min_idt = config["bam_filtering_before_binning"]["min_identity"],
                min_len = config["bam_filtering_before_binning"]["min_len"],
                pp = config["bam_filtering_before_binning"]["properly_paired"],
            script:
                "../scripts/bamprocess.py"

        rule binning_mapping_SB:
            '''
                Align the source reads files against the assembled contigs file to assess contigs' abundance.
            '''
            wildcard_constraints:
                binning_strategy="SASB|CASB"
            output:
                temp(os.path.join(tmpdir, "binning", "{assembly,.*_assembly}/{src}_to_{src1}_" + f"{index}_filtering.sam"))
            input:
                assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/contigs/{src1}_" + f"{assembly}"),
                index1 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1, 4)),
                index2 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.rev.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1,2)),
                reads = input_cmd_SB,
                # NOTE(review): `tmp` is not defined in this file (only
                # `tmpdir` is); presumably it comes from the main Snakefile --
                # confirm.
                finished_assembly = os.path.join(tmp, "assembly.checkpoint")
            params:
                # Index prefix built from intermediate_results_dir so that it
                # always designates the index files listed in index1/index2,
                # instead of a path relative to the current directory.
                prefix = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/index/{src1}_" + f"{index}_filtering"),
                input_reads = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
                cmd = lambda wildcards : conf.mapping_cmd(config, wildcards.assembly),
            threads: 5
            priority: 1
            conda:
                os.path.join(CONDAENV, "bowtie2.yaml")
            shell:
                "bowtie2 "
                "-p {threads} "             # number of parallel threads
                "--no-unal "                # remove unmapped reads (decrease size)
                "-x {params.prefix} "       # index for mapping
                "{params.input_reads} "
                "{params.cmd} "
                "-S {output} "

if "SACB" in binning_request or "CACB" in binning_request:
    # Per-source post-processing of the co-binning SemiBin2 run (SACB/CACB):
    # decompress and rename the bins of one source, and pool them when the
    # requested collection is "all_mags".
    rule semibin2_CB_by_src:
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        output:
            os.path.join(tmpdir, "binning/semibin2_{binning_strategy}.{src}.done"),
        input:
            semibin2_CB_done = os.path.join(tmpdir, "binning/semibin2_CB_{binning_strategy}.done"),
        params:
            # NOTE(review): "_post_filtering" is hard-coded here although the
            # contigs may be the pre-filtering ones when
            # Assembly.filter_assembly is disabled -- confirm.
            semibin_fasta_bin_outdir = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/semibin2_output/samples/{src}_post_filtering/output_bins"),
            output_tmp_mags = os.path.join(intermediate_results_dir, "binning/tmp_all_mags"),
            collection = bins_collection
        conda:
            os.path.join(CONDAENV, "semibin2.yaml")
        # In the loop below every "*.gz" bin is gunzipped, the "SemiBin_" part
        # of its name is stripped and the file is renamed to
        # "bin_semibin_<strategy>_<src>.<rest>". Doubled braces are literal
        # braces for the shell.
        shell:
            """
            for file in {params.semibin_fasta_bin_outdir}/*.gz; do 
                gunzip "$file" && 
                base_name=$(basename "${{file%.gz}}" | sed 's/SemiBin_//')
                mv "${{file%.gz}}" "{params.semibin_fasta_bin_outdir}/bin_semibin_{wildcards.binning_strategy}_{wildcards.src}.${{base_name}}"; 
            done
            if [[ '{params.collection}' == 'all_mags' ]]; then 
                mkdir -p {params.output_tmp_mags} && mv {params.semibin_fasta_bin_outdir}/* {params.output_tmp_mags} 
            fi
            touch {output}
            """

    def exp_bam_semibin2_CB(wildcards):
        '''
            Return the sorted BAM of every source mapped onto the concatenated
            contigs, for the co-binning strategy of the requesting job.

            The strategy is taken from wildcards.binning_strategy so that SACB
            and CACB each get their own files when both are requested.

            Side effect: refreshes binning_dict for that strategy.
            Returns None when the strategy is not a requested co-binning
            strategy.
        '''
        strategy = wildcards.binning_strategy
        if strategy == "SACB" and "SACB" in binning_request:
            binning_dict["SACB"] = list(samples.keys())
        elif strategy == "CACB" and "CACB" in binning_request:
            # Co-assembly sources come from the cluster_simka checkpoint output.
            checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
            binning_dict["CACB"] = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
        else:
            return None
        bam_pattern = os.path.join(tmpdir, "binning/" + strategy + "/semibin2/index/{src}/{src}.sorted.bam")
        return expand(bam_pattern, src=binning_dict.get(strategy))

    rule semibin2_CB:
        '''
            Co-binning SemiBin2 run (SACB/CACB): bin the concatenated contigs
            with the mapping of every source, then write the shared flag file.
        '''
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        input:
            mapping_CB = exp_bam_semibin2_CB,
            concatenated_contigs = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/concatenated.fa"),
        output:
            os.path.join(tmpdir, "binning/semibin2_CB_{binning_strategy}.done"),
        params:
            mapping_CB = exp_bam_semibin2_CB,
            semibin_fasta_bin_outdir = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/semibin2_output"),
        conda:
            os.path.join(CONDAENV, "semibin2.yaml")
        shell:
            'SemiBin2 multi_easy_bin '
            '-i {input.concatenated_contigs} '
            '-b {params.mapping_CB} '
            '-o {params.semibin_fasta_bin_outdir} '
            '&& touch {output}'

    rule semibin2_sort_and_index:
        '''
            Convert the SAM of a source mapped onto the concatenated contigs
            into a sorted BAM and index it.
        '''
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        input:
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{src}/{src}_to_concatenated_fasta_filtering.sam")
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{src}/{src}.sorted.bam"),
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{src}/{src}.sorted.bam.bai")
        priority: 3
        threads: 1
        conda:
            os.path.join(CONDAENV, "samtools.yaml")
        shell:
            "samtools view -u {input} "
            "| samtools sort -@ {threads} -o {output[0]} "    # -@ : number of threads used
            "&& samtools index {output[0]} "

    def input_cmd_semibin2(wildcards):
        '''
            Return the reads to map onto the concatenated contigs for
            co-binning.

            SACB: every read file of every run of the sample.
            CACB: the samples list when simka clustering is disabled,
            otherwise the cluster file of the source.

            Raises ValueError for any other binning strategy.
        '''
        if wildcards.binning_strategy == "SACB":
            runs = reads2use[wildcards.src]
            list_reads = []
            for run in runs:
                list_reads.extend(runs[run])
            return list_reads
        if wildcards.binning_strategy == "CACB":
            if simka_type == "None":
                return os.path.join(tmpdir, 'samples.txt')
            # The source id is substituted here so that the returned path is
            # concrete instead of carrying a "{src}" placeholder.
            return os.path.join(intermediate_results_dir, "assembly/co_assembly/clusters", wildcards.src + ".txt")
        raise ValueError(f"unknown co-binning strategy '{wildcards.binning_strategy}'")

    rule semibin2_bowtie2_map:
        """
        Align reads against the concatenated assembly, using bowtie2.

        The reads are selected by input_cmd_semibin2 and turned into bowtie2
        arguments by cmdparser.cmd; extra mapping options come from
        conf.mapping_cmd. bowtie2 messages (stderr) are written to the log file.
        """
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{src}/{src}_to_concatenated_fasta_filtering.sam")
        input:
            # index files built by semibin2_bowtie2_index (large index, .bt2l)
            index1 = expand(os.path.join(tmpdir, "binning/{{binning_strategy}}/semibin2/index/{{binning_strategy}}_concatenated_fasta.{id}.bt2l"), id=range(1,4)),
            index2 = expand(os.path.join(tmpdir, "binning/{{binning_strategy}}/semibin2/index/{{binning_strategy}}_concatenated_fasta.rev.{id}.bt2l"), id=range(1, 2)),
            reads  = input_cmd_semibin2,
        log:
            "logs/index_concatenated_fasta/{binning_strategy}_{src}_mapping_concatenated_fasta.log"
        threads: 10
        priority: 80
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        params:
            # basename shared by all the index files listed in input
            prefix = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{binning_strategy}_concatenated_fasta"),
            input_reads = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
            # SACB maps with the single_assembly settings, CACB with the co_assembly ones
            cmd = lambda wildcards : conf.mapping_cmd(config, "single_assembly" if wildcards.binning_strategy == "SACB" else "co_assembly"),
        shell:
            "bowtie2 "
            "-p {threads} "             # number of parallel threads
            "--no-unal "                # remove unmapped reads (decrease size)
            "-x {params.prefix} "       # index for mapping
            "{params.input_reads} "
            "{params.cmd} "
            "-S {output} "
            "2> {log}"                  # the declared log was never written before

    rule semibin2_bowtie2_index:
        """
            Build a large Bowtie2 index (.bt2l) from the concatenated fasta file
            of a co-binning strategy.
        """
        input:
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/concatenated.fa"),
        output:
            # only files .1 to .3 and .rev.1 of the index are declared as outputs
            expand(os.path.join(tmpdir, "binning/{{binning_strategy}}/semibin2/index/{{binning_strategy}}_concatenated_fasta.{id}.bt2l"), id=range(1,4)),
            expand(os.path.join(tmpdir, "binning/{{binning_strategy}}/semibin2/index/{{binning_strategy}}_concatenated_fasta.rev.{id}.bt2l"), id=range(1, 2)),
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        params:
            prefix = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/index/{binning_strategy}_concatenated_fasta")
        log:
            "logs/index_concatenated_fasta/{binning_strategy}_concatenated_fasta.log"
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        threads: 10
        priority: 5
        shell:
            "bowtie2-build "
            "--large-index "
            "--threads {threads} "          # Number of parallel threads
            "{input} "                      # Concatenated fasta to index
            "{params.prefix} "              # Basename of the index files
            "&> {log}"                      # stdout and stderr both go to the log

    def expand_contigs(wildcards):
        """
        Input function listing the contig files to concatenate for SemiBin2.

        The strategy is read from wildcards.binning_strategy, so that when both
        SACB and CACB are requested each concatenated fasta is built from its
        own assemblies (previously SACB was always returned as soon as it was
        requested, whatever the wildcard). When the wildcard is missing or is
        not a co-binning strategy, the first of SACB / CACB found in
        binning_request is used, as before.

        Side effects: updates the matching entries of assembly_dict and
        binning_dict.

        Returns the list of contig paths, or None when no co-binning strategy
        applies.
        """
        strategy = getattr(wildcards, "binning_strategy", None)
        if strategy not in ("SACB", "CACB"):
            # legacy behaviour: SACB has precedence over CACB
            strategy = next((s for s in ("SACB", "CACB") if s in binning_request), None)

        if strategy == "SACB":
            assembly_dict["single_assembly"] = list(samples.keys())
            binning_dict["SACB"] = list(samples.keys())
            return expand(
                os.path.join(intermediate_results_dir, "assembly/single_assembly", assembler, "{src}/contigs/{src}_" + assembly),
                src=binning_dict.get("SACB"),
            )

        if strategy == "CACB":
            # one cluster file per co-assembly in the checkpoint output directory
            checkpoint_output_simka = checkpoints.cluster_simka.get(**wildcards).output[0]
            cluster_ids = glob_wildcards(os.path.join(checkpoint_output_simka, "{clusterid}.txt")).clusterid
            assembly_dict["co_assembly"] = cluster_ids
            binning_dict["CACB"] = list(cluster_ids)
            return expand(
                os.path.join(intermediate_results_dir, "assembly/co_assembly", assembler, "{src}/contigs/{src}_" + assembly),
                src=binning_dict.get("CACB"),
            )

        return None

    rule semibin2_concatenate_fasta:
        """
        Concatenate the contigs returned by expand_contigs with
        `SemiBin2 concatenate_fasta`, then gunzip the result to get
        concatenated.fa.
        """
        input:
            # unnamed first input: assembly checkpoint flag
            os.path.join(tmp, "assembly.checkpoint"),
            contigs = expand_contigs
        output:
            os.path.join(tmpdir, "binning/{binning_strategy}/semibin2/concatenated.fa"),
        wildcard_constraints:
            binning_strategy="SACB|CACB"
        params:
            # directory given to --output
            output = os.path.join(tmpdir, "binning/{binning_strategy}/semibin2"),
            # same function as input.contigs; this is the value used on the command line
            contigs = expand_contigs
        conda:
            os.path.join(CONDAENV, "semibin2.yaml")
        shell:
            "SemiBin2 concatenate_fasta "
            "--input-fasta {params.contigs} "
            "--output {params.output} "
            "&& gunzip {output} "

    rule sort_and_index_binning_CB:
        """
        Sort the filtered SAM of {src} mapped on {src1} into a BAM and index it.
        Both outputs are temporary.
        """
        input:
            os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src}.filtered.sam"),
        output:
            # output[0] is the sorted BAM, output[1] its index
            temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/compute_depth/{src}_to_{src1}.sorted.bam")),
            temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/compute_depth/{src}_to_{src1}.sorted.bam.bai")),
        wildcard_constraints:
            # NOTE: no binning_strategy wildcard appears in the paths of this rule
            binning_strategy="SACB|CACB"
        conda:
            os.path.join(CONDAENV, "samtools.yaml")
        priority: 3
        threads: 1
        shell:
            "samtools view -u {input} "     # uncompressed BAM sent to the pipe
            "| samtools sort "
            "-@ {threads} "                 # number of threads used
            "-o {output[0]} "
            "&& samtools index {output[0]} "

    def input_cmd_CB(wildcards):
        """
        Input function giving the reads of `src` to map for co-binning.

        wildcards.assembly selects the behaviour:
            single_assembly : flat list of every entry of reads2use[src], all
                              runs merged.
            co_assembly     : `samples.txt` in tmpdir when simka_type is the
                              string "None", otherwise the cluster file
                              `<src>.txt` of the co-assembly.

        Raises ValueError for any other assembly type.
        """
        assembly_type = wildcards.assembly
        src = wildcards.src

        if assembly_type == "single_assembly":
            runs = reads2use[src]
            return [read for run in runs for read in runs[run]]

        if assembly_type == "co_assembly":
            if simka_type == "None":
                return os.path.join(tmpdir, 'samples.txt')
            # A string returned by an input function is used as-is by Snakemake
            # (wildcards are not formatted in it), so `src` has to be
            # substituted here instead of leaving a literal "{src}" in the path.
            return os.path.join(intermediate_results_dir, "assembly/co_assembly/clusters", f"{src}.txt")

        raise ValueError(f"Unsupported assembly type: {assembly_type!r}")

    rule filter_bam_CB:
        """
        Filter the mapping of {src} on {src1} with ../scripts/bamprocess.py.
        The thresholds are read from the "bam_filtering_before_binning" section
        of the config. Output is temporary because it will be sorted.
        """
        input:
            os.path.join(tmpdir, "binning", "{assembly,.*_assembly}/{src}_to_{src1}_" + f"{index}_filtering.sam")
        output:
            temp(os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/mapped_reads/{src}.filtered.sam")),
        wildcard_constraints:
            # NOTE: no binning_strategy wildcard appears in the paths of this rule
            binning_strategy="SACB|CACB"
        params:
            # values made available to the script
            min_mapq = config["bam_filtering_before_binning"]["min_quality"],
            min_idt = config["bam_filtering_before_binning"]["min_identity"],
            min_len = config["bam_filtering_before_binning"]["min_len"],
            pp = config["bam_filtering_before_binning"]["properly_paired"],
        priority: 2
        conda:
            os.path.join(CONDAENV, "bamutils.yaml")
        script:
            "../scripts/bamprocess.py"

    rule binning_mapping_CB:
        '''
            Align the source reads files against the assembled contigs file to assess contigs' abundance.

            Reads of {src} (selected by input_cmd_CB) are mapped with bowtie2 on
            the index of the {src1} assembly. The SAM output is temporary.
        '''
        wildcard_constraints:
            # NOTE: no binning_strategy wildcard appears in the paths of this rule
            binning_strategy="SACB|CACB"
        output:
            temp(os.path.join(tmpdir, "binning", "{assembly,.*_assembly}/{src}_to_{src1}_" + f"{index}_filtering.sam"))
        input:
            assembly = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/contigs/{src1}_" + f"{assembly}"),
            # index extension depends on config["large_index"] (.bt2l or .bt2)
            index1 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1, 4)),
            index2 = expand(os.path.join(intermediate_results_dir, "assembly/{{assembly}}", assembler, "{{src1}}/index/{{src1}}_" + index + "_filtering.rev.{id}" + (".bt2l" if config["large_index"] else ".bt2")), id=range(1,2)),
            reads = input_cmd_CB,
            finished_assembly = os.path.join(tmp, "assembly.checkpoint")
        params:
            # Basename of the index files declared in input.index1 / input.index2.
            # Built from intermediate_results_dir (it used a hard-coded relative
            # "intermediate_results" directory) so that it always points to the
            # same location as those inputs.
            prefix = os.path.join(intermediate_results_dir, "assembly/{assembly}", assembler, "{src1}/index/{src1}_" + f"{index}_filtering"),
            input_reads = lambda wildcards, input : cmdparser.cmd(wildcards.src, input.reads, reads2use, "bowtie2").cmd,
            cmd = lambda wildcards : conf.mapping_cmd(config, wildcards.assembly),

        threads: 5
        priority: 1
        conda:
            os.path.join(CONDAENV, "bowtie2.yaml")
        shell:
            "bowtie2 "
            "-p {threads} "             # number of parallel threads
            "--no-unal "                # remove unmapped reads (decrease size)
            "-x {params.prefix} "       # index for mapping
            "{params.input_reads} "
            "{params.cmd} "
            "-S {output} "