# vim: ft=python
"""Snakefile to re-sort POD5 files by channel.
   Useful for duplex calling.
"""
from collections import defaultdict
import json

INPUT_DIR     = config.get('input', "input")
PORES_PER_POD = int(config.get('ppp', "50"))

# Utility functions
from common_functions import n_sorted, get_common_prefix, pore_to_batch

def list_pod5_channels(wc=None):
    """List all the .pod5.channels files (ie. one per pod5)
    """
    all_pod5 = glob_wildcards(INPUT_DIR + "/{p5}.pod5")

    assert all_pod5.p5, f"No .pod5 files seen in {INPUT_DIR}/"

    c_extn = "bchannels" if PORES_PER_POD > 1 else "channels"
    return [ f"results/tmp/{p5}.pod5.{c_extn}" for p5 in sorted(all_pod5.p5) ]

def load_channels_in_use():
    """Loads channels_in_use.json, if it's available.
    """
    with open(checkpoints.get_channels_in_use.get().output[0]) as jfh:
        return json.load(jfh)

rule main:
    input: "results/tmp/all_channels_done.touch"

localrules: get_channels_in_use, output_for_all_channels, batched_channels_in_pod

wildcard_constraints:
    channel = r"\d+(-\d+)?",  # Channel may be single or a range
    pref    = r"[^/]*",       # Prefix may be empty

rule get_channels_in_pod:
    output: temporary("results/tmp/{foo}.pod5.channels")
    input:  INPUT_DIR + "/{foo}.pod5"
    conda:  "envs/pod5.yaml"
    shell:
        "pod5 view -H {input} -i read_id,channel > {output}"

rule batched_channels_in_pod:
    output: temporary("{foo}.pod5.bchannels")
    input:  "{foo}.pod5.channels"
    run:
        with open(str(output), "x") as ofh:
            with open(str(input)) as ifh:
                if PORES_PER_POD <= 1:
                    # Just write the lines to the file
                    ofh.writelines(ifh)
                else:
                    for r, c in (line.split() for line in ifh):
                        c_batch = pore_to_batch(c, PORES_PER_POD)
                        # TODO - check this is actually supposed to be a tab
                        ofh.write(f"{r}\t{c_batch}\n")

checkpoint get_channels_in_use:
    output: "results/tmp/channels_in_use.json"
    input:  list_pod5_channels
    run:
        # "cut -d, -f 1 {input} | sort -u > {output}"
        # Save channel -> pod5_files dict as json
        res = defaultdict(set)
        for f in input:
            f_base = re.match( r"results/tmp/((.+/|)\w+\.pod5)\.b?channels$",
                               f ).group(1)
            with open(f) as fh:
                for read_id, channel in (l.split() for l in fh):
                    res[channel].add(f_base)

        # Sets to lists
        with open(str(output), "x") as ofh:
            json.dump( {k: n_sorted(v) for k, v in res.items()},
                       fp = ofh,
                       sort_keys = True,
                       indent = 4 )

def i_pod5_for_channel(wc):
    """See which pod5 files have reads on a given channel, to save us scanning
       every file every time.
    """
    channel = wc.channel
    ciu = load_channels_in_use()

    c_extn = "bchannels" if PORES_PER_POD > 1 else "channels"

    return dict( pod5 = [ f"{INPUT_DIR}/{f}" for f in ciu[channel] ],
                 chan = [ f"results/tmp/{f}.{c_extn}" for f in ciu[channel] ] )

# FIXME - this could blow the command line limit if there are too many pod5 files,
# so an alternative could be to make a load of symlinks (in tmpfs) then use the
# recursive mode. Icky.
# Or filter in batches and merge. Which is also icky.
rule pod5_for_channel:
    output:
        pod5 =   "results/{pref}channel_{channel}.pod5",
        idlist = temporary("results/tmp/{pref}channel_{channel}.idlist"),
    log: "results/tmp/{pref}channel_{channel}.pod5-filter.log"
    input:
        unpack(i_pod5_for_channel)
    params:
        awk_filter = '($2=="{channel}"){{print$1}}'
    conda:  "envs/pod5.yaml"
    shell:
       r"""awk {params.awk_filter:q} {input.chan} > {output.idlist}
           pod5 filter -o {output.pod5} -i {output.idlist} {input.pod5} >& {log}
        """

def i_output_for_all_channels(wc=None):
    """Reads the list of all the channels for which there are records in any POD5.
       Returns a list of POD5 files to make in ./results
    """
    ciu = load_channels_in_use()

    # Keep the common prefix from the original POD5 filenames
    common_prefix = get_common_prefix( list_pod5_channels(),
                                       base_only = True,
                                       extn = r"_\d+\.pod5\.b?channels" )
    pref = f"{common_prefix}_" if common_prefix else ""

    return [ f"results/{pref}channel_{c}.pod5" for c in ciu ]

rule output_for_all_channels:
    output: touch("results/tmp/all_channels_done.touch")
    input:  i_output_for_all_channels
