Implement mask for some primitives + jit. (#2922)

* Implement mask for slice, conv, pad, transpose, where

* Remove tentative mask(jit)

* Add explanatory comment to dot_general masking rule

* Remove reshape from select masking rule

* Remove unnecessary check from lax slice abstract_eval rule

* Revert to standard indentation in masking_test.py

* Begin simplifying masking tests

* Finish drafting masking check function

* More progress simplifying tests

* Add conv masking in batch dim

* Finish fixing up tests

* Revert to old API, making out_shape compulsory again

* More efficient conv masking rule

* Tidy up masking_test imports

* Check that out tree is preserved by masking

* Fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
Julius Kunze, 2020-06-03 22:40:48 +02:00 (committed by GitHub)
parent 0db57cb541
commit d1dbf7c7d8
4 changed files with 592 additions and 360 deletions
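For context, a minimal standalone sketch (not part of the patch) of how the `mask` API can be exercised with one of the newly supported primitives, `lax.pad`. The function name `padded_pad` and all shapes and values are illustrative:

from functools import partial
import jax.numpy as jnp
from jax import jit, lax, mask

# Illustrative: mask a pad over a logically length-n vector.
@partial(mask, in_shapes=['n'], out_shape='2*n + 1')
def padded_pad(x):
  return lax.pad(x, jnp.array(0., x.dtype), [(1, 1, 1)])

# Physical (padded) input of length 4, logical length n = 3.
out = padded_pad([jnp.arange(4.)], dict(n=3))
# The first 2*3 + 1 = 7 entries equal lax.pad applied to the logical prefix;
# the new tests also check that jit(padded_pad) produces the same padded output.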

jax/api.py

@@ -54,7 +54,6 @@ from .lib import xla_client as xc
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
@@ -1289,67 +1288,50 @@ def _parallelize(fun):
def mask(fun: Callable, in_shapes, out_shape) -> Callable:
_check_callable(fun)
unique_ids = masking.UniqueIds()
in_specs, in_shapes_tree = tree_flatten(in_shapes)
out_specs, out_shapes_tree = tree_flatten(out_shape)
in_specs = map(masking.parse_spec, in_specs)
out_specs = map(masking.parse_spec, out_specs)
in_specs = map(partial(masking.remap_ids, unique_ids), in_specs)
unique_ids: Dict[Any, Any] = collections.defaultdict(object)
in_specs = map(partial(_remap_ids, unique_ids), in_specs)
out_specs = map(partial(_remap_ids, unique_ids), out_specs)
out_specs, out_spec_tree = tree_flatten(out_shape)
out_specs = map(masking.parse_spec, out_specs)
out_specs = map(partial(masking.remap_ids, unique_ids), out_specs)
def wrapped_fun(args, logical_env):
args_flat, in_tree = tree_flatten(args)
if in_tree != in_shapes_tree: raise TypeError("pytree mismatch")
if in_tree != in_shapes_tree:
raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat))
padded_env = _bind_shapes(in_shapes, [x.shape for x in args_flat])
padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
f = lu.wrap_init(fun)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
outs, out_shapes_ = masking.mask_fun(
flat_fun, logical_env, padded_env, args_flat, in_shapes)
if not out_tree() == out_shapes_tree: raise TypeError("pytree mismatch")
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
if not out_shapes == list(out_shapes_):
raise masking.ShapeError
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
for out, shape in zip(outs, out_shapes)):
raise masking.ShapeError
return tree_unflatten(out_tree(), outs)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
outs, out_shapes = masking.mask_fun(
flat_fun, logical_env, padded_env, args_flat, in_shapes)
out_tree = out_tree_thunk()
masking.check_shapes(out_specs, out_spec_tree, list(out_shapes), out_tree)
def padded_spec(shape_spec):
return tuple(dim if dim is masking._monomorphic_dim else
masking.eval_poly(dim, padded_env) for dim in shape_spec)
masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
map(onp.shape, outs), out_tree, "Padded output")
return tree_unflatten(out_tree, outs)
return wrapped_fun
def _remap_ids(names, shape_spec):
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not masking._monomorphic_dim else
masking._monomorphic_dim for poly in shape_spec)
def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if type(poly) is not Poly or poly.is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
if env.setdefault(binder, d) != d: raise masking.ShapeError
return env
@curry
def shapecheck(in_shapes, out_shape, fun: Callable):
_check_callable(fun)
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
out_shapes, out_tree = tree_flatten(out_shape)
out_shapes = map(masking.parse_spec, out_shapes)
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
out_specs, out_spec_tree = tree_flatten(out_shape)
out_specs = map(masking.parse_spec, out_specs)
flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
raise masking.ShapeError
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
masking.check_shapes(map(tuple, out_specs), out_spec_tree,
map(tuple, out_shapes), out_tree_thunk())
return fun
def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:

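As a standalone illustration of the `shapecheck` path above (the decorated function and its shape specs are hypothetical):

from jax import lax, shapecheck

# Verifies polymorphic output shapes by abstract evaluation, without running the function.
@shapecheck(['m', 'n'], 'm + n')
def concat(x, y):
  return lax.concatenate([x, y], 0)

# A spec that disagrees with the traced shape, e.g. out_shape 'm' here,
# would raise ShapeError via masking.check_shapes.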
jax/interpreters/masking.py

@@ -18,12 +18,13 @@ from functools import partial
from itertools import chain, product
import operator as op
import string
from typing import Callable, Dict
from typing import Callable, Dict, Sequence, Union
import numpy as onp
from .. import abstract_arrays
from .. import core
from ..tree_util import tree_unflatten
from ..core import Trace, Tracer
from ..util import safe_map, safe_zip, unzip2, prod
from ..abstract_arrays import ShapedArray
@@ -53,42 +54,52 @@ def naryop_masking_rule(prim, padded_vals, logical_shapes):
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
shape_envs = ShapeEnvs({}, {}) # TODO(mattjj): make this a stack for efficiency
def is_tracing():
return bool(shape_envs.padded)
@contextmanager
def extend_shape_envs(logical_env, padded_env):
global shape_envs
new_logical = dict(chain(shape_envs.logical.items(), logical_env.items()))
new_padded = dict(chain(shape_envs.padded.items(), padded_env.items()))
shape_envs, prev = ShapeEnvs(new_logical, new_padded), shape_envs
yield
shape_envs = prev
try:
yield
finally:
shape_envs = prev
def shape_as_value(shape):
assert is_tracing() or not is_polymorphic(shape)
return eval_polymorphic_shape(shape, shape_envs.logical)
def padded_shape_as_value(shape):
assert is_tracing() or not is_polymorphic(shape)
return eval_polymorphic_shape(shape, shape_envs.padded)
def mask_fun(fun, logical_env, padded_env, in_vals, shape_exprs):
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
with core.new_master(MaskTrace) as master:
fun, out_shapes = mask_subtrace(fun, master)
fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes)
with extend_shape_envs(logical_env, padded_env):
out_vals = fun.call_wrapped(in_vals, shape_exprs)
out_vals = fun.call_wrapped(*in_vals)
del master
return out_vals, out_shapes()
@lu.transformation_with_aux
def mask_subtrace(master, in_vals, shape_exprs):
def mask_subtrace(master, polymorphic_shapes, *in_vals):
trace = MaskTrace(master, core.cur_sublevel())
in_tracers = [MaskTracer(trace, x, s).full_lower()
for x, s in zip(in_vals, shape_exprs)]
for x, s in zip(in_vals, polymorphic_shapes)]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals, out_shapes = unzip2((t.val, t.shape_expr) for t in out_tracers)
out_vals, out_shapes = unzip2((t.val, t.polymorphic_shape)
for t in out_tracers)
yield out_vals, out_shapes
def eval_polymorphic_shape(shape, values_dict):
return tuple(dim.evaluate(values_dict) if type(dim) is Poly else dim
for dim in shape)
return tuple(eval_poly(dim, values_dict) for dim in shape)
def eval_poly(poly, values_dict):
return poly.evaluate(values_dict) if type(poly) is Poly else poly
def _ensure_poly(p):
if type(p) is Poly:
@@ -96,12 +107,16 @@ def _ensure_poly(p):
return Poly({Mon(): p})
def is_polymorphic(shape: Sequence[Union[int, 'Poly']]):
return any(map(lambda d: type(d) is Poly, shape))
class Poly(dict):
"""Polynomial with nonnegative integer coefficients for polymorphic shapes."""
def __init__(self, coeffs):
# Makes sure Polynomials are always in canonical form
coeffs = {mon: op.index(coeff) for mon, coeff in coeffs.items() if coeff != 0}
coeffs = {mon: op.index(coeff)
for mon, coeff in coeffs.items() if coeff != 0}
coeffs = coeffs or {Mon(): 0}
super().__init__(coeffs)
@@ -149,13 +164,14 @@ class Poly(dict):
def divided(count):
q, r = divmod(count, divisor)
if r != 0:
raise ValueError('shapecheck currently only supports strides '
'that exactly divide the strided axis length.')
raise ValueError('shapecheck and masking currently only support '
'strides that exactly divide the strided axis '
'length.')
return q
return Poly(
{k: coeff // divisor if k.degree == 0 else divided(coeff)
for k, coeff in self.items()}), self.get(Mon(), 0) % divisor
for k, coeff in self.items()}), self.get(Mon(), 0) % divisor
def __hash__(self):
return hash(tuple(sorted(self.items())))
@@ -196,6 +212,9 @@ class Poly(dict):
if (v != 1 or k.degree == 0) else str(k)
for k, v in sorted(self.items())).strip()
def __repr__(self):
return str(self)
def __int__(self):
assert self.is_constant
return op.index(next(iter(self.values())))
@@ -213,7 +232,7 @@ abstract_arrays._DIMENSION_TYPES.add(Poly)
class Mon(dict):
def __hash__(self):
return hash(tuple(self.items()))
return hash(frozenset(self.items()))
def __str__(self):
return ' '.join('{}**{}'.format(k, v) if v != 1 else str(k)
@@ -221,8 +240,8 @@ class Mon(dict):
def __lt__(self, other):
# sort by total degree, then lexicographically on indets
self_key = self.degree, tuple(sorted(self))
other_key = other.degree, tuple(sorted(other))
self_key = -self.degree, tuple(sorted(self))
other_key = -other.degree, tuple(sorted(other))
return self_key < other_key
def __mul__(self, other):
@@ -258,9 +277,9 @@ class ShapeSpec(tuple):
def __str__(self):
return 'ShapeSpec({})'.format(', '.join(map(str, self)))
def finalize_spec(spec, shape):
def finalize_spec(polymorphic_shape, padded_shape):
return tuple(_parse_lit(d) if e is _monomorphic_dim else e
for e, d in zip(spec, shape))
for e, d in zip(polymorphic_shape, padded_shape))
def parse_spec(spec=''):
if not spec:
@@ -311,19 +330,24 @@ def _shape_spec_consistent(spec, expr):
return all(a == b for a, b in zip(spec, expr) if a is not _monomorphic_dim)
class MaskTracer(Tracer):
__slots__ = ["val", "shape_expr"]
__slots__ = ["val", "polymorphic_shape"]
def __init__(self, trace, val, shape_expr):
self._trace = trace
def __init__(self, trace, val, polymorphic_shape):
super().__init__(trace)
self.val = val
self.shape_expr = shape_expr
self.polymorphic_shape = polymorphic_shape
@property
def aval(self):
return ShapedArray(self.shape_expr, self.val.dtype)
return ShapedArray(self.polymorphic_shape, self.dtype)
@property
def dtype(self):
return self.val.dtype
def is_pure(self):
return all(type(poly) is not Poly or poly.is_constant for poly in self.shape_expr)
return all(type(poly) is not Poly or poly.is_constant
for poly in self.polymorphic_shape)
def full_lower(self):
if self.is_pure():
@@ -340,23 +364,74 @@ class MaskTrace(Trace):
return MaskTracer(self, val, onp.shape(val))
def sublift(self, val):
return MaskTracer(self, val.val, val.shape_expr)
return MaskTracer(self, val.val, val.polymorphic_shape)
def process_primitive(self, primitive, tracers, params):
vals, shape_exprs = unzip2((t.val, t.shape_expr) for t in tracers)
vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
if primitive in shape_parameterized_primitive_rules:
rule = shape_parameterized_primitive_rules[primitive]
out, out_shape = rule(shape_envs, vals, shape_exprs, **params)
out, out_shape = rule(shape_envs, vals, polymorphic_shapes, **params)
else:
avals = [t.aval for t in tracers]
out = primitive.abstract_eval(*avals, **params)
out_shape = [o.shape for o in out] if primitive.multiple_results else out.shape
logical_shapes = map(partial(eval_polymorphic_shape, values_dict=shape_envs.logical), shape_exprs)
out = masking_rules[primitive](vals, logical_shapes, **params)
logical_shapes = map(shape_as_value, polymorphic_shapes)
masking_rule = masking_rules.get(primitive)
if masking_rule is None:
raise NotImplementedError('Masking rule for {} not implemented yet.'.format(primitive))
out = masking_rule(vals, logical_shapes, **params)
if not primitive.multiple_results:
return MaskTracer(self, out, out_shape)
else:
return map(partial(MaskTracer, self), out, out_shape)
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
raise NotImplementedError # TODO mask-of-jit
def process_call(self, call_primitive, f, tracers, params):
raise NotImplementedError
def post_process_call(self, call_primitive, out_tracers, params):
raise NotImplementedError
class UniqueId:
def __init__(self, name):
self.name = name
def __repr__(self):
return self.name
def __lt__(self, other):
return self.name < other.name
class UniqueIds(dict):
def __missing__(self, key):
unique_id = UniqueId(key)
self[key] = unique_id
return unique_id
def remap_ids(names, shape_spec):
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not _monomorphic_dim else
_monomorphic_dim for poly in shape_spec)
def bind_shapes(polymorphic_shapes, padded_shapes):
env = {}
for polymorphic_shape, padded_shape in zip(polymorphic_shapes, padded_shapes):
for poly, d in zip(polymorphic_shape, padded_shape):
if type(poly) is not Poly or poly.is_constant:
if int(poly) != d: raise ShapeError
else:
poly = poly.copy()
const_coeff = poly.pop(Mon({}), 0)
(mon, linear_coeff), = poly.items()
(id, index), = mon.items()
if index != 1: raise ShapeError
d, r = divmod(d - const_coeff, linear_coeff)
assert r == 0
if env.setdefault(id, d) != d: raise ShapeError
return env
def check_shapes(specs, spec_tree, shapes, tree, message_prefix="Output"):
if spec_tree != tree or not all(map(_shape_spec_consistent, specs, shapes)):
specs = tree_unflatten(spec_tree, specs)
shapes = tree_unflatten(tree, shapes)
raise ShapeError(f"{message_prefix} shapes should be {specs} but are {shapes}.")

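For intuition about `bind_shapes` above: a linear polymorphic dimension such as 2*h + 1 is solved against the concrete padded size by peeling off the constant term and dividing by the linear coefficient. A plain-Python sketch with illustrative numbers:

# Mirrors the divmod step in bind_shapes: solve 2*h + 1 == 9 for h.
const_coeff, linear_coeff, padded_size = 1, 2, 9
h, remainder = divmod(padded_size - const_coeff, linear_coeff)
assert remainder == 0 and h == 4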
jax/lax/lax.py

@@ -36,7 +36,7 @@ from .. import dtypes
from .. import lazy
from .. import lib
from ..config import flags
from ..core import Primitive
from ..core import Primitive, _canonicalize_dimension
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
AbstractToken, array_types, make_shaped_array,
raise_to_shaped, abstract_token, canonicalize_shape)
@@ -1534,8 +1534,8 @@ def slice_in_dim(operand: Array, start_index: Optional[int],
# translate `None`
len_axis = operand.shape[axis]
start_index_int = int(start_index) if start_index is not None else 0
limit_index_int = int(limit_index) if limit_index is not None else len_axis
start_index_int = _canonicalize_dimension(start_index) if start_index is not None else 0
limit_index_int = _canonicalize_dimension(limit_index) if limit_index is not None else len_axis
# translate negative indices
if start_index_int < 0:
@@ -1700,7 +1700,7 @@ def _iter(tracer):
if tracer.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
n = tracer.shape[0]
n = int(tracer.shape[0])
# return (index_in_dim(tracer, i, keepdims=False) for i in range(n))
return iter([index_in_dim(tracer, i, keepdims=False) for i in range(n)])
ShapedArray._iter = staticmethod(_iter)
@@ -2523,6 +2523,45 @@ def _conv_general_dilated_batch_rule(
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
return out, out_spec[1]
def _masked(padded_value, logical_shape, dimensions, value=0):
"""
Sets all padding to the given value (default is 0) in the given dimensions.
All values outside the logical shape are considered padding.
"""
if len(dimensions) == 0:
return padded_value
masks = [broadcasted_iota(onp.int32, padded_value.shape, d) < logical_shape[d]
for d in dimensions]
mask_intersection = masks[0]
for mask in masks[1:]:
mask_intersection &= mask
return select(mask_intersection, padded_value, full_like(padded_value, value))
def _conv_general_dilated_masking_rule(
padded_vals, logical_shapes, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
lhs_shape, rhs_shape, precision):
lhs, rhs = padded_vals
logical_lhs_shape, logical_rhs_shape = logical_shapes
o, i, *window_dimensions = dimension_numbers.rhs_spec
assert (onp.all(onp.take(rhs.shape, window_dimensions)
== onp.take(logical_rhs_shape, window_dimensions))), \
"Conv filter masking not yet implemented."
n, c, *padded_dimensions = dimension_numbers.lhs_spec
return conv_general_dilated(
_masked(lhs, logical_lhs_shape, padded_dimensions),
_masked(rhs, logical_rhs_shape, (i,)),
window_strides=window_strides, padding=padding,
lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision)
conv_general_dilated_p = standard_primitive(
_conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
'conv_general_dilated', _conv_general_dilated_translation_rule)
@@ -2531,7 +2570,8 @@ ad.defbilinear(conv_general_dilated_p,
_conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = \
_conv_general_dilated_batch_rule
masking.masking_rules[conv_general_dilated_p] = \
_conv_general_dilated_masking_rule
def _reshape_axis_into(src, dst, x):
perm = [i for i in range(x.ndim) if i != src]
@@ -2683,21 +2723,13 @@ def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision):
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
precision):
lhs, rhs = padded_vals
lhs_shape, rhs_shape = logical_shapes
lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
# we need only mask the lhs contraction dimensions
if len(lhs_contract) == 0:
return dot_general(lhs, rhs, dimension_numbers, precision=precision)
else:
masks = [broadcasted_iota(onp.int32, lhs.shape, d) < lhs_shape[d]
for d in lhs_contract]
mask_intersection = masks[0]
for mask in masks[1:]:
mask_intersection &= mask
masked_lhs = select(mask_intersection, lhs, zeros_like_array(lhs))
return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision)
# Only need to mask off contraction dims of one side - we mask the lhs here
# but this is arbitrary. Could check the sizes of lhs and rhs and mask
# whichever is smallest.
lhs_shape, _ = logical_shapes
(lhs_contract, _), _ = dimension_numbers
return dot_general(_masked(lhs, lhs_shape, lhs_contract),
rhs, dimension_numbers, precision=precision)
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general',
@@ -2925,11 +2957,23 @@ def _pad_translation_rule(c, operand, padding_value, *, padding_config):
return xops.Pad(operand, padding_value,
xc.make_padding_config(padding_config))
def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
operand, padding_value = padded_vals
shape, _ = logical_shapes
out = pad(operand, padding_value, padding_config)
out_shape = [lo + shape[i] * (interior + 1)
for i, (lo, hi, interior) in enumerate(padding_config)]
padded_dims = [i for i, config in enumerate(padding_config)
if config != (0, 0, 0)]
return _masked(out, out_shape, padded_dims, padding_value)
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
translation_rule=_pad_translation_rule)
ad.deflinear(pad_p, _pad_transpose)
ad.primitive_transposes[pad_p] = _pad_transpose
batching.primitive_batchers[pad_p] = _pad_batch_rule
masking.masking_rules[pad_p] = _pad_masking_rule
# The squeeze primitive exists for the benefit of masking and other
@@ -3183,12 +3227,16 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
return transpose(operand, perm), 0
def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
return transpose(*padded_vals, permutation=permutation)
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
transpose_p.def_impl(_transpose_impl)
ad.deflinear(transpose_p,
lambda t, permutation: [transpose(t, onp.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
masking.masking_rules[transpose_p] = _transpose_masking_rule
def _select_shape_rule(pred, on_true, on_false):
@@ -3257,6 +3305,13 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
on_false = broadcast(on_false, pred.shape)
return select(pred, on_true, on_false), 0
def _select_masking_rule(padded_vals, logical_shapes):
pred_shape, true_shape, false_shape = [
masking.padded_shape_as_value(val.shape) for val in padded_vals]
assert onp.array_equal(pred_shape, true_shape)
assert onp.array_equal(pred_shape, false_shape)
return select(*padded_vals)
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
ad.defjvp(select_p,
None,
@@ -3264,6 +3319,7 @@ ad.defjvp(select_p,
lambda g, b, x, y: select(b, _zeros(g), g))
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
masking.masking_rules[select_p] = _select_masking_rule
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
@@ -3277,7 +3333,9 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
msg = ("slice limit_indices must have the same length as start_indices, "
"got start_inidices {} and limit_indices {}.")
raise TypeError(msg.format(start_indices, limit_indices))
if not onp.all(onp.less_equal(limit_indices, operand.shape)):
if (not masking.is_polymorphic(limit_indices) and
not masking.is_polymorphic(operand.shape) and
not onp.all(onp.less_equal(limit_indices, operand.shape))):
msg = ("slice limit_indices must be less than or equal to operand shape, "
"got limit_indices {} for operand shape {}.")
raise TypeError(msg.format(limit_indices, operand.shape))
@@ -3285,7 +3343,8 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
msg = ("slice start_indices must be greater than or equal to zero, "
"got start_indices of {}.")
raise TypeError(msg.format(start_indices))
if not onp.all(onp.greater_equal(limit_indices, start_indices)):
if (not masking.is_polymorphic(limit_indices) and
not onp.all(onp.greater_equal(limit_indices, start_indices))):
msg = ("slice limit_indices must be greater than or equal to start_indices,"
" got start_indices {} and limit_indices {}.")
raise TypeError(msg.format(start_indices, limit_indices))
@@ -3345,10 +3404,19 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
return out, bdim
def _slice_masking_rule(
padded_vals, logical_shapes, start_indices, limit_indices, strides):
operand, = padded_vals
return slice(operand,
start_indices=masking.padded_shape_as_value(start_indices),
limit_indices=masking.padded_shape_as_value(limit_indices),
strides=strides)
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
_slice_translation_rule)
ad.deflinear2(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
masking.masking_rules[slice_p] = _slice_masking_rule
def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
@@ -4056,14 +4124,15 @@ def _masking_defreducer(prim, identity):
masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
axes):
axes, input_shape=None):
(padded_val,), (logical_shape,) = padded_vals, logical_shapes
padded_shape = masking.padded_shape_as_value(padded_val.shape)
masks = [broadcasted_iota(onp.int32, padded_shape, i) < d
for i, d in enumerate(logical_shape) if i in axes]
mask = _reduce(operator.and_, masks)
masked_val = select(mask, padded_val, identity(padded_shape, padded_val.dtype))
return prim.bind(masked_val, axes=axes)
bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
return bind(masked_val, axes=axes)
reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
_reduce_translation_rule)
@@ -4103,7 +4172,8 @@ _masking_defreducer(reduce_sum_p,
lambda shape, dtype: onp.broadcast_to(onp.array(0, dtype), shape))
def _reduce_op_shape_rule(operand, *, axes):
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
del input_shape # unused.
return tuple(onp.delete(operand.shape, axes))
def _reduce_prod_translation_rule(c, operand, *, axes):
@@ -4151,6 +4221,8 @@ reduce_prod_p = standard_primitive(
'reduce_prod', _reduce_prod_translation_rule)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p)
_masking_defreducer(reduce_prod_p,
lambda shape, dtype: onp.broadcast_to(onp.array(1, dtype), shape))
def _reduce_chooser_shape_rule(operand, *, axes):
@@ -4178,6 +4250,8 @@ reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_max', _reduce_max_translation_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)
_masking_defreducer(reduce_max_p,
lambda shape, dtype: onp.broadcast_to(onp.array(-onp.inf, dtype), shape))
_reduce_min_translation_rule = partial(
@@ -4186,6 +4260,8 @@ reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_min', _reduce_min_translation_rule)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
_masking_defreducer(reduce_min_p,
lambda shape, dtype: onp.broadcast_to(onp.array(onp.inf, dtype), shape))
def _reduce_logical_shape_rule(operand, *, axes):

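For intuition about the reducer masking rules above: padding is replaced by the reduction's identity before reducing, so it cannot affect the result. A standalone NumPy sketch with illustrative values:

import numpy as onp

padded = onp.array([3., 1., 4., 7., 7.])            # physical length 5
logical_len = 3                                      # logical length n = 3
keep = onp.arange(padded.shape[0]) < logical_len     # iota-style mask
masked = onp.where(keep, padded, 0.)                 # 0. is the identity for sum
assert masked.sum() == padded[:logical_len].sum()    # padding no longer contributes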
tests/masking_test.py

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools as it
from unittest import SkipTest
@@ -20,23 +19,31 @@ from unittest import SkipTest
import numpy as np
from absl.testing import absltest, parameterized
from jax.interpreters.masking import shape_as_value, ShapeError, \
parse_spec, Poly, Mon
parse_spec, Poly, Mon, finalize_spec, eval_polymorphic_shape, remap_ids, \
UniqueIds
from jax import numpy as jnp, test_util as jtu, mask, vmap, jit, grad, lax, \
shapecheck, api
shapecheck, core
from jax.config import config
from jax.numpy.lax_numpy import _polymorphic_slice_indices
from jax.scipy.special import expit
from jax.util import safe_map, safe_zip
from jax.test_util import rand_default, rand_int
from jax.tree_util import tree_flatten
config.parse_flags_with_absl()
map = safe_map
zip = safe_zip
# These are 'manual' tests for masking. The more exhaustive,
# more systematic tests should live in lax_test.py.
# TODO:
# These should be only the 'manual' tests for masking.
# Move the more exhaustive, systematic tests into lax_test.py.
def constant_poly(c):
return Poly({Mon(): c})
class ShapesTest(jtu.JaxTestCase):
class PolyTest(jtu.JaxTestCase):
@parameterized.parameters([
['(m, n)', 'ShapeSpec(m, n)'],
@@ -49,14 +56,17 @@ class ShapesTest(jtu.JaxTestCase):
['(3 * m)', 'ShapeSpec(3 m)'],
['m', 'ShapeSpec(m)'],
['', 'ShapeSpec()'],
['n + -1*n', 'ShapeSpec(0)'],
['m + n', 'ShapeSpec(m + n)'],
['m + n * k', 'ShapeSpec(m + k n)'],
['m + n * k', 'ShapeSpec(k n + m)'],
['m + 3 * k', 'ShapeSpec(3 k + m)'],
['-3 + k + k * k', 'ShapeSpec(k**2 + k + -3)'],
['', 'ShapeSpec()'],
['_', 'ShapeSpec(_)'],
])
def test_parse_spec(self, spec, ans):
self.assertEqual(str(parse_spec(spec)), ans)
self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)
def test_Poly_equal(self):
assert constant_poly(3) == 3
@@ -74,7 +84,12 @@ class ShapesTest(jtu.JaxTestCase):
def test_Poly_hash(self):
assert not len(set(hash(Poly({Mon(): i})) for i in range(10))) == 1
assert hash(Poly({Mon(): 3, Mon({'n': 1}): 4})) == hash(Poly({Mon({'n': 1}): 4, Mon(): 3}))
assert (hash(Poly({Mon(): 3, Mon({'n': 1}): 4}))
== hash(Poly({Mon({'n': 1}): 4, Mon(): 3})))
def test_Mon_hash(self):
assert not len(set(hash(Mon({'a': i})) for i in range(10))) == 1
assert hash(Mon({'a': 1, 'b': 1})) == hash(Mon({'b': 1, 'a': 1}))
def test_Poly_compare(self):
poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
@@ -100,208 +115,7 @@ class ShapesTest(jtu.JaxTestCase):
n = Poly({Mon({'n': 1}): 1})
assert -1 - n == -n - 1
def test_add_broadcast(self):
@shapecheck(['n', '(m, n)'], '(m, n)')
@shapecheck(['(m, n)', 'n'], '(m, n)')
@shapecheck(['n', ''], 'n')
def add(a, b):
return a + b
def test_sum(self):
@shapecheck(['(m, n)'], '')
def sum(x):
return jnp.sum(x)
def test_prod(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return jnp.prod(x)
def test_max(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return jnp.max(x)
def test_min(self):
@shapecheck(['(m, n)'], '')
def prod(x):
return jnp.min(x)
def test_dot(self):
@shapecheck(['(m, n)', 'n'], 'm')
def matvec(A, b):
return jnp.dot(A, b)
def thunk():
@shapecheck(['(m, n)', 'n'], 'm')
def matvec(A, b):
return lax.dot_general(A, b, [((0,), (0,)), ((), ())])
self.assertRaisesRegex(TypeError, "", thunk)
def test_flatten(self):
@shapecheck(['(m, n)'], 'm * n')
def flatten(x):
return lax.reshape(x, (x.shape[0] * x.shape[1],))
def test_concatenate(self):
@shapecheck(['m', 'n', 'm'], '3*m + n')
def cat(x, y, z):
return lax.concatenate([x, y, x, z], 0)
def thunk():
@shapecheck(['m', 'n', 'm'], '3*m + n')
def cat(x, y, z):
return lax.concatenate([x, y, x], 0)
self.assertRaisesRegex(ShapeError, "", thunk)
def test_device_put(self):
@shapecheck(['n'], 'n')
def d_put(x):
return api.device_put(x)
def test_broadcast_in_dim(self):
@shapecheck(['(n,)'], '(3, n, 4)')
def broadcast_in_dim(x):
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1,))
@shapecheck(['(n, 1)'], '(3, n, 4, 1)')
def broadcast_in_dim2(x):
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4, x.shape[1]), broadcast_dimensions=(1, 3))
def test_jit(self):
@shapecheck(['n'], '2*n')
@jit
def concat(x):
return lax.concatenate([x, x], 0)
# TODO:
# @shapecheck(['n'], 'n')
# @jit
# @grad
# def sum_square(x):
# return jnp.sum(x ** 2)
def test_pad(self):
@shapecheck(['n'], '2*n+1')
def p(x):
return lax.pad(x, jnp.array(0., x.dtype), [(1, 1, 1)])
def test_numpy_pad(self):
@shapecheck(['n'], 'n+1')
def p(x):
return jnp.pad(x, (0, 1))
@parameterized.named_parameters(jtu.cases_from_list(
{
'testcase_name': "strides={}_padding={}_lhs_dilation={}_dimension_numbers"
"={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
strides, padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm),
'strides': strides, 'padding': padding, 'lhs_dilation': lhs_dilation,
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
'rhs_perm': rhs_perm, 'out_perm': out_perm}
for strides in [(1, 1), (2, 1)]
for padding in ['SAME', 'VALID', ((1, 0), (2, 0))]
for lhs_dilation in (None, (1, 2))
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
)
# String padding is not implemented for transposed convolution, see conv_general_dilated implementation:
if (lhs_dilation is None or not isinstance(padding, str)) and
# only test strides with same padding:
(strides[0] == 1 or padding == 'SAME')))
def test_conv(self, strides, padding, lhs_dilation,
dimension_numbers, lhs_perm, rhs_perm, out_perm):
valid = padding == 'VALID'
is_strided = strides[0] != 1
lhs_shape = '({}, {}, {}, {})'.format(*np.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
rhs_shape = '({}, {}, {}, {})'.format(*np.take(['o', 'i', '2', '3'], rhs_perm))
out_shape = '({}, {}, {}, {})'.format(*np.take([
'n', 'o', 'h+-1' if valid and not is_strided else 'h',
('w+-2' if valid else 'w') if lhs_dilation is None else '2*w+-1'], out_perm))
@shapecheck([lhs_shape, rhs_shape], out_shape)
def conv(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, strides, padding,
lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers)
def test_indexing(self):
@shapecheck(['n'], '')
def first(x):
return x[0]
@shapecheck(['n'], '')
def last(x):
return x[-1]
@shapecheck(['(n,m,a)'], 'n,m')
@vmap
@shapecheck(['(n,a)'], 'n')
def last_column(x):
return x[..., -1]
def test_slicing(self):
@shapecheck(['n'], 'n+-1')
def slice(x):
return x[1:]
@shapecheck(['n'], 'n+-1')
def slice2(x):
return x[:-1]
@shapecheck(['n'], 'n+-1')
def inverse(x):
return x[:0:-1]
@shapecheck(['n'], 'n+-1')
def inverse2(x):
return x[-2::-1]
def test_poly_slicing(self):
@shapecheck(['n'], 'n+-1')
def slice_poly_stop(x):
return x[:x.shape[0] - 1]
# TODO: @shapecheck(['n'], '1')
def slice_poly_start(x):
return x[x.shape[0] - 1:]
def test_iota(self):
raise SkipTest("not yet implemented")
# https://travis-ci.org/github/google/jax/jobs/682086351
@shapecheck(['n'], 'n')
def range_like(x):
return lax.iota(jnp.int32, x.shape[0])
def test_arange(self):
raise SkipTest("not yet implemented")
# https://travis-ci.org/github/google/jax/jobs/682086351
@shapecheck(['n'], 'n')
def arange_like(x):
return jnp.arange(x.shape[0], dtype=jnp.int32)
def test_expit(self):
@shapecheck(['n'], 'n')
def expit_(x):
return expit(x)
def test_reshape(self):
@shapecheck(['n, a, b'], 'n, a*b')
def flatten(x):
return jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
def test_ravel(self):
a = jnp.array(1)
@shapecheck(['n'], '')
def thunk(n):
return -(a + n.ravel()[0] * 0)
class MaskingTest(jtu.JaxTestCase):
def test_sum(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
@@ -324,10 +138,43 @@ class MaskingTest(jtu.JaxTestCase):
expected = np.array([0, 1, 2, 3, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
dtypes, rng, rtol=None, atol=None):
shapecheck(in_shapes, out_shape)(fun)
masked_fun = mask(fun, in_shapes, out_shape)
padded_args = [rng(shape, dtype)
for shape, dtype in zip(padded_in_shapes, dtypes)]
padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))
out_specs, _ = tree_flatten(out_shape)
out_specs = map(parse_spec, out_specs)
out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs))
logical_out_shapes = [eval_polymorphic_shape(s, logical_env)
for s in out_specs]
logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes]
logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)]
in_specs = map(parse_spec, in_shapes)
in_specs = map(finalize_spec, in_specs, padded_in_shapes)
logical_in_shapes = [eval_polymorphic_shape(s, logical_env)
for s in in_specs]
logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes]
logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)]
logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
assert outs_tree == logical_outs_tree
self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
atol=atol, rtol=rtol)
# Check that abstract evaluation works
padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
atol=atol, rtol=rtol)
def test_add(self):
@partial(mask, in_shapes=['n', 'n'], out_shape='n')
def addvecs(x, y):
return x + y
self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'],
rand_default(self.rng()))
addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')
x = jnp.array([3, 1, 4, 1, 5, 9])
y = jnp.array([2, 6, 5, 3, 5, 8])
@@ -384,35 +231,12 @@ class MaskingTest(jtu.JaxTestCase):
expected = 5
self.assertAllClose(ans, expected, check_dtypes=False)
def test_concatenate(self):
@partial(mask, in_shapes=['n', 'm', 'n'], out_shape='m + 2 * n')
def cat(x, y, z):
return lax.concatenate([x, y, z], 0)
ans = cat([jnp.array([1, 9]), jnp.array([2, 4, 9]), jnp.array([3, 9])],
dict(n=1, m=2))
expected = np.array([1, 2, 4, 3])
self.assertAllClose(ans[:4], expected, check_dtypes=False)
def test_dot(self):
@partial(mask, in_shapes=['(m, k)', '(k, n)'], out_shape='(m, n)')
def dot(x, y):
return lax.dot(x, y)
x = np.arange(6, dtype=np.float32).reshape((2, 3))
y = np.arange(12, dtype=np.float32).reshape((3, 4))
ans = dot([x, y], dict(m=2, k=2, n=2))
expected = np.dot(x[:2, :2], y[:2, :2])
self.assertAllClose(ans[:2, :2], expected, check_dtypes=False)
def test_mean(self):
@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return jnp.sum(x) / shape_as_value(x.shape)[0]
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = 8 / 3
self.assertAllClose(ans, expected, check_dtypes=False)
# TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
raise SkipTest
self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
{'n': 3}, [(4,)], ['float_'],
rand_default(self.rng()))
def test_arithmetic(self):
@partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
@@ -420,9 +244,11 @@ class MaskingTest(jtu.JaxTestCase):
return x * y
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
_ = times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
dict(n=4, m=5))
with self.assertRaisesRegex(
NotImplementedError,
'Masking rule for broadcast_in_dim not implemented yet.'):
times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
dict(n=4, m=5))
# expected = np.array([[1, 2, 3], [8, 10, 12]])
# self.assertAllClose(ans, expected, check_dtypes=False)
@@ -432,8 +258,10 @@ class MaskingTest(jtu.JaxTestCase):
return jnp.stack([x, y], 0)
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
_ = stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
with self.assertRaisesRegex(
NotImplementedError,
'Masking rule for broadcast_in_dim not implemented yet.'):
stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
# expected = np.array([[1, 2, 3], [4, 5, 6]])
# self.assertAllClose(ans, expected, check_dtypes=False)
@@ -464,6 +292,10 @@ class MaskingTest(jtu.JaxTestCase):
expected = jnp.array([3, 5])
self.assertAllClose(ans, expected, check_dtypes=False)
@shapecheck(['(2*n, n)'], '_, n')
def identity(x):
return x
def test_rnn(self):
n = 3
@@ -544,8 +376,285 @@ class MaskingTest(jtu.JaxTestCase):
expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)
self.assertAllClose(
ans, expected, check_dtypes=False,
rtol=2e-2 if jtu.device_under_test() == "tpu" else 1e-5)
ans, expected, check_dtypes=False,
rtol=2e-2 if jtu.device_under_test() == "tpu" else 1e-5)
def test_concatenate(self):
self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
[(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
rand_default(self.rng()))
def test_dot(self):
self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)',
dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'],
rand_default(self.rng()))
self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)],
['float_', 'float_'], rand_default(self.rng()))
def test_jit(self):
raise SkipTest
@partial(mask, in_shapes=['n'], out_shape='2*n')
@jit
def duplicate(x):
assert python_should_be_executing
return lax.concatenate([x, x], 0)
python_should_be_executing = True
out = duplicate([jnp.arange(3)], dict(n=2))
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
python_should_be_executing = False
out = duplicate([jnp.arange(3)], dict(n=2))
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
@parameterized.named_parameters({
'testcase_name': "padding_config={}_shapes={}".format(padding_config,
shape),
'padding_config': padding_config,
'shape': shape} for padding_config, shape in (
(((1, 2, 0),), (2,)),
(((1, 2, 0), (3, 4, 0)), (1, 2)),
(((0, 0, 0), (0, 0, 0)), (1, 2)),
(((1, 2, 3),), (2,)),
(((1, 2, 1), (3, 4, 2)), (3, 2)),
(((-1, 2, 0),), (2,)),
(((-1, -2, 0), (1, 2, 0)), (4, 2)),
(((-1, 2, 0), (1, 2, 2)), (4, 2)),
(((-1, -2, 2),), (5,)),
(((-1, -2, 1), (1, 2, 2)), (4, 2))))
def test_pad(self, padding_config, shape):
def pad(x):
return lax.pad(x, jnp.array(1., x.dtype), padding_config)
if len(shape) == 1:
padding_config_, = padding_config
linear_coeff = padding_config_[2] + 1
const_coeff = sum(padding_config_[:2]) - padding_config_[2]
out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff)
self.check(pad, ['h'], out_shape, dict(h=shape[0]),
[tuple(np.add(shape, 1))], ['float_'],
rand_default(self.rng()))
def test_numpy_pad(self):
# TODO (j-towns) requires mask(jit)
raise SkipTest
def numpy_pad(x):
return jnp.pad(x, (0, 1), constant_values=5.)
self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'],
rand_default(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': "padding={}_lhs_dilation={}_"
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm),
'padding': padding, 'lhs_dilation': lhs_dilation,
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
'rhs_perm': rhs_perm, 'out_perm': out_perm}
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
for lhs_dilation in (None, (1, 2))
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
)
# String padding is not implemented for transposed convolution, see
# conv_general_dilated implementation:
if (lhs_dilation is None or not isinstance(padding, str))))
def test_conv(
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm):
def conv(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation,
dimension_numbers=dimension_numbers)
template = '({}, {}, {}, {})'
lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm))
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
if padding == 'VALID':
out_shape = template.format(
*np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm))
elif lhs_dilation:
out_shape = template.format(
*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
else:
out_shape = template.format(
*np.take(['n', 'o', 'h', 'w'], out_perm))
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
self.check(conv, [lhs_shape, rhs_shape], out_shape,
logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)),
tuple(np.take([7, 3, 2, 3], rhs_perm))],
['float_', 'float_'], rand_default(self.rng()), rtol=1e-4,
atol=1e-4)
@parameterized.named_parameters(jtu.cases_from_list(
{'testcase_name': "padding={}_lhs_dilation={}_"
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm),
'padding': padding, 'lhs_dilation': lhs_dilation,
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
'rhs_perm': rhs_perm, 'out_perm': out_perm}
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
for lhs_dilation in (None, (1, 2))
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
)
# String padding is not implemented for transposed convolution, see
# conv_general_dilated implementation:
if (lhs_dilation is None or not isinstance(padding, str))))
def test_conv_strided(
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
rhs_perm, out_perm):
def conv(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation,
dimension_numbers=dimension_numbers)
template = '({}, {}, {}, {})'
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
if padding == 'VALID':
lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm))
elif lhs_dilation:
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
else:
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm))
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
self.check(conv, [lhs_shape, rhs_shape], out_shape,
logical_env, [lhs_shape_padded,
tuple(np.take([7, 3, 2, 3], rhs_perm))],
['float_', 'float_'], rand_default(self.rng()), rtol=1e-4,
atol=1e-4)
def test_indexing(self):
# Requires gather support
raise SkipTest
self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'],
rand_default(self.rng()))
self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'],
rand_default(self.rng()))
def test_slicing(self):
raise SkipTest
# Requires gather support
self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_'])
def test_rev(self):
@shapecheck(['n'], 'n+-1')
def rev(x):
return x[:0:-1]
@shapecheck(['n'], 'n+-1')
def rev2(x):
return x[-2::-1]
# TODO implement masking for rev_p:
# self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
# self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
def test_lax_slice(self):
self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
{'n': 2}, [(3,)], ['float_'], rand_default(self.rng()))
# TODO: self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)), ['2*n'], dict(n=jnp.array([2, 3])), 'n')
def test_reshape(self):
raise SkipTest
def test_transpose(self):
self.check(lambda x: lax.transpose(x, (1, 0, 2)),
['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)],
['float_'], rand_default(self.rng()))
def test_sum_2d(self):
self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
rand_default(self.rng()))
def test_expit(self):
raise SkipTest("custom_jvp doesn't work with masking yet")
self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
rand_default(self.rng()))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
for dtype in [np.float32, np.float64]))
def test_uniform(self, dtype):
raise SkipTest("not yet implemented")
# TODO needs fix for https://github.com/google/jax/issues/2155
def test_broadcast_in_dim(self):
raise SkipTest
def test_destructure(self):
def d(key):
key1, key2 = key
return key1
self.check(d, ['2'], '', {}, [(2,)], ['int_'], rand_int(self.rng(), 0, 10))
def test_where(self):
# Requires mask(jit)
raise SkipTest
self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n',
{'n': 2}, [(3,)], ['float_'], rand_default(self.rng()))
def test_split(self):
raise SkipTest
@parameterized.named_parameters(jtu.cases_from_list([{
'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
def test_reduce(self, operator):
self.check(operator, ['(m, n)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
rand_default(self.rng()))
def test_output_shape_error(self):
def thunk():
shapecheck(['n'], 'n+-1')(lambda x: x)
message = "Output shapes should be (n + -1,) but are (n,)."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def thunk():
shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x))
message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def test_output_tree_error(self):
def thunk():
shapecheck(['n'], ('n', 'n'))(lambda x: [x, x])
message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]."
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
def test_unsupported_op(self):
p = core.Primitive('unsupported_op')
p.def_abstract_eval(lambda x: x)
p.def_impl(lambda x: x)
def thunk():
mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})
message = "Masking rule for unsupported_op not implemented yet."
self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)
def test_nesting(self):
raise SkipTest("not yet implemented")
@@ -568,16 +677,6 @@ class MaskingTest(jtu.JaxTestCase):
expected = 3+1 + 5+9+2
self.assertAllClose(ans, expected, check_dtypes=False)
def test_arange(self):
raise SkipTest("not yet implemented")
@partial(mask, in_shapes=['n'], out_shape='n')
def padded_add(x):
return x + lax.iota(x.shape[0])
ans = padded_add([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
expected = np.array([3, 2, 6])
self.assertAllClose(ans[:3], expected, check_dtypes=False)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_start={}_stop={}_step={}_length={}"