#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import argparse
import pandas as pd
from Bio import SeqIO


def filterbycoverage(
        ref,
        cov,
        minr=0,
        minl=1000,
        minc=1,
        minp=20,
        properly_paired=False,
        sep="\t"):
    """
    Split the contigs of an assembly into passing and failing sets using
    per-contig coverage statistics.

    A contig passes only if it is present in the coverage table AND meets
    every threshold below (all comparisons are ``>=``). Contigs missing from
    the coverage table are reported as failed.

    Parameters
    ----------
    ref : str (required)
        path to the reference assembly to filter, in fasta format
    cov : str (required)
        path to a delimited table of coverage statistics with a header row
        and the contig name in the first column. The columns ``length``,
        ``reads``, ``depth_avg_coverage`` and ``breath_coverage`` are read
        (or the ``*_properly_paired`` variants of the last two, see
        ``properly_paired``). Note the column names use the spelling
        "breath" for breadth of coverage.
    minr : int (default = 0)
        minimal number of reads mapped on a contig
    minl : int (default = 1000)
        contig length threshold
    minc : int (default = 1)
        depth of coverage threshold
    minp : int (default = 20)
        breadth of coverage threshold
    properly_paired : boolean (default = False)
        if True, the depth and breadth thresholds are checked against the
        ``*_properly_paired`` columns instead of the plain ones
    sep : str (default = tab)
        field separator of the coverage table

    Returns
    -------
    (filtered_sequences, failed_sequences) : tuple of two lists
        the fasta records that passed and the records that failed, each in
        the order they appear in ``ref``.

    Raises
    ------
    FileNotFoundError
        if ``cov`` or ``ref`` does not exist.
    KeyError
        if a required column is missing from the coverage table for a contig
        that is looked up.
    """
    # The former `squeeze=True` keyword was removed in pandas 2.0; it was
    # dropped because a squeezed Series could not be indexed by column name
    # below anyway. Errors are no longer masked behind a bare
    # FileNotFoundError: a missing file still raises FileNotFoundError, and
    # anything else surfaces with its real cause.
    coverage_df = pd.read_csv(cov, header=0, sep=sep, index_col=0)
    # {contig name: {statistic name: value}}
    coverage_dict = coverage_df.transpose().to_dict()

    suffix = "_properly_paired" if properly_paired else ""
    breadth_key = "breath_coverage" + suffix
    depth_key = "depth_avg_coverage" + suffix

    filtered_sequences = []
    failed_sequences = []
    # SeqIO.parse is lazy, so the records are consumed while the handle is
    # still open; the context manager guarantees the file is closed.
    with open(ref) as handle:
        for record in SeqIO.parse(handle, 'fasta'):
            stats = coverage_dict.get(record.name)
            if (stats is not None
                    and stats["length"] >= minl
                    and stats["reads"] >= minr
                    and stats[depth_key] >= minc
                    and stats[breadth_key] >= minp):
                filtered_sequences.append(record)
            else:
                # Either a threshold was missed or the contig is absent
                # from the coverage table.
                failed_sequences.append(record)

    print("Found %i high quality sequences" % len(filtered_sequences))
    print("Found %i bad quality sequences" % len(failed_sequences))

    return filtered_sequences, failed_sequences


if __name__ == "__main__":
    # Two entry modes: when a global `snakemake` object exists (as injected
    # by Snakemake's `script:` directive) inputs come from it; otherwise the
    # command line is parsed. Only NameError means "not under Snakemake" --
    # any other failure (e.g. a missing param) must not be silently turned
    # into a command-line parse.
    try:
        smk = snakemake  # noqa: F821 -- defined only when run by Snakemake
    except NameError:
        smk = None

    if smk is not None:
        ref = smk.input.assembly
        cov = smk.input.covstats
        outfilter = smk.output.filtered
        outfailed = smk.output.failed

        min_reads = smk.params.min_reads
        min_length = smk.params.min_length
        min_breath = smk.params.min_breath
        min_depth = smk.params.min_depth
        properly_paired = smk.params.properly_paired
        sep = "\t"
    else:
        # filterbycoverage arguments #
        parser = argparse.ArgumentParser(
            prog='filterbycoverage',
            description='Filter fasta file depending on coverage depth and breadth retrieve from indexed '
            'BAM.'
        )
        # Both input files are mandatory: filterbycoverage() cannot run
        # without them, so fail at argument parsing with a clear message.
        parser.add_argument(
            '-r', '--ref', type=str,
            required=True,
            help='reference assembly to filter in fasta format'
        )
        parser.add_argument(
            '-c', '--covstats', type=str,
            required=True,
            help='coverage statistics file produced by metacovest'
        )
        # NOTE: both outputs default to stdout, so without --filtered /
        # --failed the two record sets are written to the same stream.
        parser.add_argument(
            '--filtered', type=str,
            default=sys.stdout,
            help='(Optional) outfile to store quality condigs, fasta format'
        )
        parser.add_argument(
            '--failed', type=str,
            default=sys.stdout,
            help='(Optional) outfile to store bad quality condigs, fasta format'
        )
        # Thresholds take exactly one value; a bare flag (formerly allowed
        # via nargs='?') would yield None and break the comparisons.
        parser.add_argument(
            '--minr', type=int,
            default=0,
            help='minimum number of reads mapped to a contig (default: 0)'
        )
        parser.add_argument(
            '--minl', type=int,
            default=2000,
            help='contig length threshold (default: 2000)'
        )
        parser.add_argument(
            '--minc', type=int,
            default=1,
            help='depth of coverage threshold (default: 1)'
        )
        parser.add_argument(
            '--minp', type=int,
            default=20,
            help='breath of coverage threshold ([0-100]) (default: 20)'
        )
        parser.add_argument(
            '-p', '--properly-paired', action='store_true',
            help='properly paired reads statistics are used for filtering '
        )
        # Default to tab so the CLI matches the Snakemake path and the
        # default of filterbycoverage() instead of passing sep=None.
        parser.add_argument(
            '--sep', type=str,
            default="\t",
            help='coverage file separator (default: tab)'
        )
        args = parser.parse_args()

        ref = args.ref
        cov = args.covstats
        outfilter = args.filtered
        outfailed = args.failed

        min_reads = args.minr
        min_length = args.minl
        min_breath = args.minp
        min_depth = args.minc
        properly_paired = args.properly_paired
        sep = args.sep

    filtered_sequences, failed_sequences = filterbycoverage(
        ref,
        cov,
        minr=min_reads,
        minl=min_length,
        minc=min_depth,
        minp=min_breath,
        properly_paired=properly_paired,
        sep=sep
    )

    # "fasta-2line" writes each record as a header line plus one unwrapped
    # sequence line.
    SeqIO.write(filtered_sequences, outfilter, "fasta-2line")
    SeqIO.write(failed_sequences, outfailed, "fasta-2line")