Source code for kenjutsu.kenjutsu

# -*- coding: utf-8 -*-

"""
The module ``kenjutsu`` provides support for working with ``slice``\ s.

===============================================================================
Overview
===============================================================================
The module ``kenjutsu`` provides several functions that are useful for working
with a Python ``slice`` or ``tuple`` of ``slice``\ s. This is of particular
value when working with NumPy_.

.. _NumPy: http://www.numpy.org/

===============================================================================
API
===============================================================================
"""


__author__ = "John Kirkham <kirkhamj@janelia.hhmi.org>"
__date__ = "$Sep 08, 2016 15:46:46 EDT$"


import itertools
import numbers
import operator
import math
import warnings


[docs]def reformat_slice(a_slice, a_length=None): """ Takes a slice and reformats it to fill in as many undefined values as possible. Args: a_slice(slice): a slice to reformat. a_length(int): a length to fill for stopping if not provided. Returns: (slice): a new slice with as many values filled in as possible. Examples: >>> reformat_slice(slice(2, -1, None)) slice(2, -1, 1) >>> reformat_slice(slice(2, -1, None), 10) slice(2, 9, 1) """ new_slice = a_slice if (new_slice is Ellipsis) or (new_slice == tuple()): new_slice = slice(None) elif isinstance(a_slice, numbers.Integral): if a_slice < 0: new_slice = slice(a_slice, a_slice-1, -1) else: new_slice = slice(a_slice, a_slice+1, 1) elif not isinstance(a_slice, slice): raise ValueError( "Expected a `slice` type. Instead got `%s`." % str(a_slice) ) if new_slice.step == 0: raise ValueError("Slice cannot have a step size of `0`.") start = new_slice.start stop = new_slice.stop step = new_slice.step # Fill unknown values. if step is None: step = 1 if start is None: if step > 0: start = 0 elif step < 0: start = -1 if (stop is None) and (step > 0): stop = a_length stop_i = stop is not None # Make adjustments for length if a_length is not None: # Normalize out-of-bound step sizes. if step < -a_length: step = -a_length elif step > a_length: step = a_length # Normalize bounded negative values. if -a_length <= start < 0: start += a_length if stop_i and (-a_length <= stop < 0): stop += a_length # Handle out-of-bound limits. if step > 0: if (start > a_length) or (stop < -a_length): start = stop = 0 step = 1 else: if start < -a_length: start = 0 if stop > a_length: stop = a_length elif step < 0: if (start < -a_length) or (stop_i and stop >= (a_length - 1)): start = stop = 0 step = 1 else: if start >= a_length: start = a_length - 1 if stop_i and stop < -a_length: stop = None stop_i = False # Catch some known empty slices. if stop_i and (start == stop): start = stop = 0 step = 1 elif (step > 0) and (stop == 0): start = stop = 0 step = 1 elif (step < 0) and (stop == -1): start = stop = 0 step = 1 elif stop_i and (start >= 0) and (stop >= 0): if (step > 0) and (start > stop): start = stop = 0 step = 1 elif (step < 0) and (start < stop): start = stop = 0 step = 1 new_slice = slice(start, stop, step) if isinstance(a_slice, numbers.Integral): if new_slice.start == new_slice.stop == 0: raise IndexError("Index out of range.") new_slice = new_slice.start return(new_slice)
[docs]def reformat_slices(slices, lengths=None): """ Takes a tuple of slices and reformats them to fill in as many undefined values as possible. Args: slices(tuple(slice)): a tuple of slices to reformat. lengths(tuple(int)): a tuple of lengths to fill. Returns: (slice): a tuple of slices with all default values filled if possible. Examples: >>> reformat_slices( ... ( ... slice(None), ... slice(3, None), ... slice(None, 5), ... slice(None, None, 2) ... ), ... (10, 13, 15, 20) ... ) (slice(0, 10, 1), slice(3, 13, 1), slice(0, 5, 1), slice(0, 20, 2)) """ new_slices = slices if new_slices == tuple(): new_slices = Ellipsis try: len(new_slices) except TypeError: new_slices = (new_slices,) new_lengths = lengths try: if new_lengths is not None: len(new_lengths) except TypeError: new_lengths = (new_lengths,) el_idx = None try: el_idx = new_slices.index(Ellipsis) except ValueError: pass if new_lengths is not None and el_idx is None: if len(new_slices) != len(new_lengths): raise ValueError("Shape must be the same as the number of slices.") elif new_lengths is not None: if (len(new_slices) - 1) > len(new_lengths): raise ValueError( "Shape must be as large or larger than the number of slices" " without the Ellipsis." ) if el_idx is not None: # Break into three cases. # # 1. Before the Ellipsis # 2. The Ellipsis # 3. After the Ellipsis # # Cases 1 and 3 are trivially solved as before. # Case 2 is either a no-op or a bunch of `slice(None)`s. # # The result is a combination of all of these. slices_before = new_slices[:el_idx] slices_after = new_slices[el_idx+1:] if Ellipsis in slices_before or Ellipsis in slices_after: raise ValueError("Only one Ellipsis is permitted. Found multiple.") new_lengths_before = None new_lengths_after = None slice_el = (Ellipsis,) if new_lengths is not None: pos_before = len(slices_before) pos_after = len(new_lengths) - len(slices_after) new_lengths_before = new_lengths[:pos_before] new_lengths_after = new_lengths[pos_after:] new_lengths_el = new_lengths[pos_before:pos_after] slice_el = len(new_lengths_el) * (slice(None),) if slice_el: slice_el = reformat_slices( slice_el, new_lengths_el ) if slices_before: slices_before = reformat_slices(slices_before, new_lengths_before) if slices_after: slices_after = reformat_slices(slices_after, new_lengths_after) new_slices = slices_before + slice_el + slices_after else: if new_lengths is None: new_lengths = [None] * len(new_slices) new_slices = list(new_slices) for i, each_length in enumerate(new_lengths): new_slices[i] = reformat_slice(new_slices[i], each_length) new_slices = tuple(new_slices) return(new_slices)
[docs]class UnknownSliceLengthException(Exception): """ Raised if a slice does not have a known length. """ pass
[docs]def len_slice(a_slice, a_length=None): """ Determines how many elements a slice will contain. Raises: UnknownSliceLengthException: Will raise an exception if a_slice.stop and a_length is None. Args: a_slice(slice): a slice to reformat. a_length(int): a length to fill for stopping if not provided. Returns: (slice): a new slice with as many values filled in as possible. Examples: >>> len_slice(slice(2, None), 10) 8 >>> len_slice(slice(2, 6)) 4 """ if isinstance(a_slice, numbers.Integral): raise TypeError( "An integral index does not provide an object with a length." ) new_slice = reformat_slice(a_slice, a_length) if new_slice.stop is None: if new_slice.step > 0: raise UnknownSliceLengthException( "Cannot determine slice length without a defined end point. " + "The reformatted slice was " + repr(new_slice) + "." ) else: new_slice = slice(new_slice.start, -1, new_slice.step) new_slice_diff = float(new_slice.stop - new_slice.start) new_slice_size = int(math.ceil(new_slice_diff / new_slice.step)) return(new_slice_size)
[docs]def len_slices(slices, lengths=None): """ Takes a tuple of slices and reformats them to fill in as many undefined values as possible. Args: slices(tuple(slice)): a tuple of slices to reformat. lengths(tuple(int)): a tuple of lengths to fill. Returns: (slice): a tuple of slices with all default values filled if possible. Examples: >>> len_slices( ... ( ... slice(None), ... slice(3, None), ... slice(None, 5), ... slice(None, None, 2) ... ), ... (10, 13, 15, 20) ... ) (10, 10, 5, 10) """ new_slices = reformat_slices(slices, lengths) lens = [] for each_slice in new_slices: if not isinstance(each_slice, numbers.Integral): lens.append(len_slice(each_slice)) lens = tuple(lens) return(lens)
[docs]def split_blocks(space_shape, block_shape, block_halo=None): """ Return a list of slicings to cut each block out of an array or other. Takes an array with ``space_shape`` and ``block_shape`` for every dimension and a ``block_halo`` to extend each block on each side. From this, it can compute slicings to use for cutting each block out from the original array, HDF5 dataset or other. Note: Blocks on the boundary that cannot extend the full range will be truncated to the largest block that will fit. This will raise a warning, which can be converted to an exception, if needed. Args: space_shape(tuple): Shape of array to slice block_shape(tuple): Size of each block to take block_halo(tuple): Halo to tack on to each block Returns: collections.Sequence of \ tuples of slices: Provides tuples of slices for \ retrieving blocks. Examples: >>> split_blocks( ... (2, 3,), (1, 1,), (1, 1,) ... ) #doctest: +NORMALIZE_WHITESPACE ([(slice(0, 1, 1), slice(0, 1, 1)), (slice(0, 1, 1), slice(1, 2, 1)), (slice(0, 1, 1), slice(2, 3, 1)), (slice(1, 2, 1), slice(0, 1, 1)), (slice(1, 2, 1), slice(1, 2, 1)), (slice(1, 2, 1), slice(2, 3, 1))], <BLANKLINE> [(slice(0, 2, 1), slice(0, 2, 1)), (slice(0, 2, 1), slice(0, 3, 1)), (slice(0, 2, 1), slice(1, 3, 1)), (slice(0, 2, 1), slice(0, 2, 1)), (slice(0, 2, 1), slice(0, 3, 1)), (slice(0, 2, 1), slice(1, 3, 1))], <BLANKLINE> [(slice(0, 1, 1), slice(0, 1, 1)), (slice(0, 1, 1), slice(1, 2, 1)), (slice(0, 1, 1), slice(1, 2, 1)), (slice(1, 2, 1), slice(0, 1, 1)), (slice(1, 2, 1), slice(1, 2, 1)), (slice(1, 2, 1), slice(1, 2, 1))]) """ try: irange = xrange except NameError: irange = range try: from itertools import ifilter, imap except ImportError: ifilter, imap = filter, map if block_halo is not None: if not (len(space_shape) == len(block_shape) == len(block_halo)): raise ValueError( "The dimensions of `space_shape`, `block_shape`, and" " `block_halo` should be the same." ) else: if not (len(space_shape) == len(block_shape)): raise ValueError( "The dimensions of `space_shape` and `block_shape` should be" " the same." ) block_halo = tuple() for i in irange(len(space_shape)): block_halo += (0,) vec_add = lambda a, b: imap(operator.add, a, b) vec_sub = lambda a, b: imap(operator.sub, a, b) vec_mul = lambda a, b: imap(operator.mul, a, b) vec_mod = lambda a, b: imap(operator.mod, a, b) vec_nonzero = lambda a: \ imap(lambda _: _[0], ifilter(lambda _: _[1], enumerate(a))) vec_str = lambda a: imap(str, a) vec_clip_floor = lambda a, a_min: \ imap(lambda _: _ if _ >= a_min else a_min, a) vec_clip_ceil = lambda a, a_max: \ imap(lambda _: _ if _ <= a_max else a_max, a) vec_clip = lambda a, a_min, a_max: \ vec_clip_ceil(vec_clip_floor(a, a_min), a_max) uneven_block_division = tuple(vec_mod(space_shape, block_shape)) if any(uneven_block_division): uneven_block_division_str = vec_nonzero(uneven_block_division) uneven_block_division_str = vec_str(uneven_block_division_str) uneven_block_division_str = ", ".join(uneven_block_division_str) warnings.warn( "Blocks will not evenly divide the array." + " The following dimensions will be unevenly divided: %s." % uneven_block_division_str, RuntimeWarning ) ranges_per_dim = [] haloed_ranges_per_dim = [] trimmed_halos_per_dim = [] for each_dim in irange(len(space_shape)): # Construct each block using the block size given. Allow to spill over. if block_shape[each_dim] == -1: block_shape = (block_shape[:each_dim] + space_shape[each_dim:each_dim+1] + block_shape[each_dim+1:]) # Generate block ranges. a_range = [] for i in irange(2): offset = i * block_shape[each_dim] this_range = irange( offset, offset + space_shape[each_dim], block_shape[each_dim] ) a_range.append(list(this_range)) # Add the halo to each block on both sides. a_range_haloed = [] for i in irange(2): sign = 2 * i - 1 haloed = vec_mul( itertools.repeat(sign, len(a_range[i])), itertools.repeat(block_halo[each_dim], len(a_range[i])), ) haloed = vec_add(a_range[i], haloed) haloed = vec_clip(haloed, 0, space_shape[each_dim]) a_range_haloed.append(list(haloed)) # Compute how to trim the halo off of each block. # Clip each block to the boundaries. a_trimmed_halo = [] for i in irange(2): trimmed = vec_sub(a_range[i], a_range_haloed[0]) a_trimmed_halo.append(list(trimmed)) a_range[i] = list(vec_clip(a_range[i], 0, space_shape[each_dim])) # Convert all ranges to slices for easier use. a_range = tuple(imap(slice, *a_range)) a_range_haloed = tuple(imap(slice, *a_range_haloed)) a_trimmed_halo = tuple(imap(slice, *a_trimmed_halo)) # Format all slices. a_range = reformat_slices(a_range) a_range_haloed = reformat_slices(a_range_haloed) a_trimmed_halo = reformat_slices(a_trimmed_halo) # Collect all blocks ranges_per_dim.append(a_range) haloed_ranges_per_dim.append(a_range_haloed) trimmed_halos_per_dim.append(a_trimmed_halo) # Take all combinations of all ranges to get blocks. blocks = list(itertools.product(*ranges_per_dim)) haloed_blocks = list(itertools.product(*haloed_ranges_per_dim)) trimmed_halos = list(itertools.product(*trimmed_halos_per_dim)) return(blocks, haloed_blocks, trimmed_halos)