Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Implement mask for some primitives + jit. (#2922)
* Implement mask for slice, conv, pad, transpose, where
* Remove tentative mask(jit)
* Add explanatory comment to dot_general masking rule
* Rm reshape from select masking rule
* Rm unnecessary check from lax slice abstract_eval rule
* Revert to standard indentation in masking_test.py
* Begin simplifying masking tests
* Finish drafting masking check function
* More progress simplifying tests
* Add conv masking in batch dim
* Finish fixing up tests
* Revert to old API, making out_shape compulsory again
* More efficient conv masking rule
* Tidy up masking_test imports
* Check that out tree is preserved by masking
* Fix flake errors

Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent 0db57cb541
commit d1dbf7c7d8

jax/api.py (72 lines changed)
@ -54,7 +54,6 @@ from .lib import xla_client as xc
|
||||
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
|
||||
host_id, host_ids, host_count)
|
||||
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
|
||||
from .interpreters import partial_eval as pe
|
||||
from .interpreters import xla
|
||||
from .interpreters import pxla
|
||||
@ -1289,67 +1288,50 @@ def _parallelize(fun):
|
||||
|
||||
def mask(fun: Callable, in_shapes, out_shape) -> Callable:
|
||||
_check_callable(fun)
|
||||
unique_ids = masking.UniqueIds()
|
||||
|
||||
in_specs, in_shapes_tree = tree_flatten(in_shapes)
|
||||
out_specs, out_shapes_tree = tree_flatten(out_shape)
|
||||
|
||||
in_specs = map(masking.parse_spec, in_specs)
|
||||
out_specs = map(masking.parse_spec, out_specs)
|
||||
in_specs = map(partial(masking.remap_ids, unique_ids), in_specs)
|
||||
|
||||
unique_ids: Dict[Any, Any] = collections.defaultdict(object)
|
||||
in_specs = map(partial(_remap_ids, unique_ids), in_specs)
|
||||
out_specs = map(partial(_remap_ids, unique_ids), out_specs)
|
||||
out_specs, out_spec_tree = tree_flatten(out_shape)
|
||||
out_specs = map(masking.parse_spec, out_specs)
|
||||
out_specs = map(partial(masking.remap_ids, unique_ids), out_specs)
|
||||
|
||||
def wrapped_fun(args, logical_env):
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
if in_tree != in_shapes_tree: raise TypeError("pytree mismatch")
|
||||
if in_tree != in_shapes_tree:
|
||||
raise TypeError(f"Tree mismatch: Input {in_tree} and shape spec {in_shapes_tree}.")
|
||||
logical_env = {unique_ids[name] : val for name, val in logical_env.items()}
|
||||
in_shapes = map(masking.finalize_spec, in_specs, map(onp.shape, args_flat))
|
||||
padded_env = _bind_shapes(in_shapes, [x.shape for x in args_flat])
|
||||
padded_env = masking.bind_shapes(in_shapes, [x.shape for x in args_flat])
|
||||
f = lu.wrap_init(fun)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
|
||||
outs, out_shapes_ = masking.mask_fun(
|
||||
flat_fun, logical_env, padded_env, args_flat, in_shapes)
|
||||
if not out_tree() == out_shapes_tree: raise TypeError("pytree mismatch")
|
||||
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
|
||||
if not out_shapes == list(out_shapes_):
|
||||
raise masking.ShapeError
|
||||
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
|
||||
for out, shape in zip(outs, out_shapes)):
|
||||
raise masking.ShapeError
|
||||
return tree_unflatten(out_tree(), outs)
|
||||
flat_fun, out_tree_thunk = flatten_fun_nokwargs(f, in_tree)
|
||||
outs, out_shapes = masking.mask_fun(
|
||||
flat_fun, logical_env, padded_env, args_flat, in_shapes)
|
||||
out_tree = out_tree_thunk()
|
||||
|
||||
masking.check_shapes(out_specs, out_spec_tree, list(out_shapes), out_tree)
|
||||
def padded_spec(shape_spec):
|
||||
return tuple(dim if dim is masking._monomorphic_dim else
|
||||
masking.eval_poly(dim, padded_env) for dim in shape_spec)
|
||||
masking.check_shapes(map(padded_spec, out_specs), out_spec_tree,
|
||||
map(onp.shape, outs), out_tree, "Padded output")
|
||||
return tree_unflatten(out_tree, outs)
|
||||
return wrapped_fun
|
||||
|
||||
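For reference, a minimal usage sketch of the mask transform defined above, mirroring the test_sum case in masking_test.py further down (the array values and the logical size n=3 are illustrative):

  import jax.numpy as jnp
  from jax import mask

  def padded_sum(x):
    return jnp.sum(x)

  masked_sum = mask(padded_sum, in_shapes=['n'], out_shape='')
  # The physical argument is padded to length 5; only the first n=3
  # entries are logically valid, as declared through the logical env.
  ans = masked_sum([jnp.array([3., 1., 4., 1., 5.])], dict(n=3))
  # ans == 3 + 1 + 4 == 8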
def _remap_ids(names, shape_spec):
|
||||
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
|
||||
: coeff for mon, coeff in poly.items()})
|
||||
if poly is not masking._monomorphic_dim else
|
||||
masking._monomorphic_dim for poly in shape_spec)
|
||||
|
||||
def _bind_shapes(shape_exprs, shapes):
|
||||
env = {}
|
||||
for shape_expr, shape in zip(shape_exprs, shapes):
|
||||
for poly, d in zip(shape_expr, shape):
|
||||
if type(poly) is not Poly or poly.is_constant:
|
||||
continue
|
||||
else:
|
||||
(binder,), = poly # TODO generalize to handle striding
|
||||
if env.setdefault(binder, d) != d: raise masking.ShapeError
|
||||
return env
|
||||
|
||||
|
||||
@curry
|
||||
def shapecheck(in_shapes, out_shape, fun: Callable):
|
||||
_check_callable(fun)
|
||||
in_shapes, in_tree = tree_flatten(in_shapes)
|
||||
in_shapes = map(masking.parse_spec, in_shapes)
|
||||
out_shapes, out_tree = tree_flatten(out_shape)
|
||||
out_shapes = map(masking.parse_spec, out_shapes)
|
||||
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
out_specs, out_spec_tree = tree_flatten(out_shape)
|
||||
out_specs = map(masking.parse_spec, out_specs)
|
||||
flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
|
||||
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
|
||||
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
|
||||
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
|
||||
raise masking.ShapeError
|
||||
out_shapes = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
|
||||
masking.check_shapes(map(tuple, out_specs), out_spec_tree,
|
||||
map(tuple, out_shapes), out_tree_thunk())
|
||||
return fun
|
||||
|
||||
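A short sketch of the shapecheck decorator in use, matching the test_dot case in masking_test.py (the shape strings are polymorphic specs; the function is only abstractly evaluated, so no concrete arrays are needed):

  import jax.numpy as jnp
  from jax import shapecheck

  @shapecheck(['(m, n)', 'n'], 'm')
  def matvec(A, b):
    return jnp.dot(A, b)
  # A mismatching out_shape annotation raises a ShapeError instead.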
def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:
|
||||
|
jax/interpreters/masking.py

@ -18,12 +18,13 @@ from functools import partial
|
||||
from itertools import chain, product
|
||||
import operator as op
|
||||
import string
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Sequence, Union
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from .. import abstract_arrays
|
||||
from .. import core
|
||||
from ..tree_util import tree_unflatten
|
||||
from ..core import Trace, Tracer
|
||||
from ..util import safe_map, safe_zip, unzip2, prod
|
||||
from ..abstract_arrays import ShapedArray
|
||||
@ -53,42 +54,52 @@ def naryop_masking_rule(prim, padded_vals, logical_shapes):
|
||||
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
|
||||
shape_envs = ShapeEnvs({}, {}) # TODO(mattjj): make this a stack for efficiency
|
||||
|
||||
def is_tracing():
|
||||
return bool(shape_envs.padded)
|
||||
|
||||
@contextmanager
|
||||
def extend_shape_envs(logical_env, padded_env):
|
||||
global shape_envs
|
||||
new_logical = dict(chain(shape_envs.logical.items(), logical_env.items()))
|
||||
new_padded = dict(chain(shape_envs.padded.items(), padded_env.items()))
|
||||
shape_envs, prev = ShapeEnvs(new_logical, new_padded), shape_envs
|
||||
yield
|
||||
shape_envs = prev
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
shape_envs = prev
|
||||
|
||||
def shape_as_value(shape):
|
||||
assert is_tracing() or not is_polymorphic(shape)
|
||||
return eval_polymorphic_shape(shape, shape_envs.logical)
|
||||
|
||||
def padded_shape_as_value(shape):
|
||||
assert is_tracing() or not is_polymorphic(shape)
|
||||
return eval_polymorphic_shape(shape, shape_envs.padded)
|
||||
|
||||
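shape_as_value and padded_shape_as_value evaluate a (possibly polymorphic) shape against the logical and padded environments respectively. A sketch of the intended use, following the test_mean case below (data and logical size are illustrative):

  import jax.numpy as jnp
  from jax import mask
  from jax.interpreters.masking import shape_as_value

  def padded_mean(x):
    # Divide by the logical length, not the padded length.
    return jnp.sum(x) / shape_as_value(x.shape)[0]

  mean = mask(padded_mean, in_shapes=['n'], out_shape='')
  ans = mean([jnp.array([3., 1., 4., 1., 5.])], dict(n=3))
  # ans == (3 + 1 + 4) / 3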
def mask_fun(fun, logical_env, padded_env, in_vals, shape_exprs):
|
||||
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
|
||||
with core.new_master(MaskTrace) as master:
|
||||
fun, out_shapes = mask_subtrace(fun, master)
|
||||
fun, out_shapes = mask_subtrace(fun, master, polymorphic_shapes)
|
||||
with extend_shape_envs(logical_env, padded_env):
|
||||
out_vals = fun.call_wrapped(in_vals, shape_exprs)
|
||||
out_vals = fun.call_wrapped(*in_vals)
|
||||
del master
|
||||
return out_vals, out_shapes()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def mask_subtrace(master, in_vals, shape_exprs):
|
||||
def mask_subtrace(master, polymorphic_shapes, *in_vals):
|
||||
trace = MaskTrace(master, core.cur_sublevel())
|
||||
in_tracers = [MaskTracer(trace, x, s).full_lower()
|
||||
for x, s in zip(in_vals, shape_exprs)]
|
||||
for x, s in zip(in_vals, polymorphic_shapes)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_shapes = unzip2((t.val, t.shape_expr) for t in out_tracers)
|
||||
out_vals, out_shapes = unzip2((t.val, t.polymorphic_shape)
|
||||
for t in out_tracers)
|
||||
yield out_vals, out_shapes
|
||||
|
||||
def eval_polymorphic_shape(shape, values_dict):
|
||||
return tuple(dim.evaluate(values_dict) if type(dim) is Poly else dim
|
||||
for dim in shape)
|
||||
return tuple(eval_poly(dim, values_dict) for dim in shape)
|
||||
|
||||
def eval_poly(poly, values_dict):
|
||||
return poly.evaluate(values_dict) if type(poly) is Poly else poly
|
||||
|
||||
def _ensure_poly(p):
|
||||
if type(p) is Poly:
|
||||
@ -96,12 +107,16 @@ def _ensure_poly(p):
|
||||
|
||||
return Poly({Mon(): p})
|
||||
|
||||
def is_polymorphic(shape: Sequence[Union[int, 'Poly']]):
|
||||
return any(map(lambda d: type(d) is Poly, shape))
|
||||
|
||||
class Poly(dict):
|
||||
"""Polynomial with nonnegative integer coefficients for polymorphic shapes."""
|
||||
|
||||
def __init__(self, coeffs):
|
||||
# Makes sure Polynomials are always in canonical form
|
||||
coeffs = {mon: op.index(coeff) for mon, coeff in coeffs.items() if coeff != 0}
|
||||
coeffs = {mon: op.index(coeff)
|
||||
for mon, coeff in coeffs.items() if coeff != 0}
|
||||
coeffs = coeffs or {Mon(): 0}
|
||||
super().__init__(coeffs)
|
||||
|
||||
@ -149,13 +164,14 @@ class Poly(dict):
|
||||
def divided(count):
|
||||
q, r = divmod(count, divisor)
|
||||
if r != 0:
|
||||
raise ValueError('shapecheck currently only supports strides '
|
||||
'that exactly divide the strided axis length.')
|
||||
raise ValueError('shapecheck and masking currently only support '
|
||||
'strides that exactly divide the strided axis '
|
||||
'length.')
|
||||
return q
|
||||
|
||||
return Poly(
|
||||
{k: coeff // divisor if k.degree == 0 else divided(coeff)
|
||||
for k, coeff in self.items()}), self.get(Mon(), 0) % divisor
|
||||
for k, coeff in self.items()}), self.get(Mon(), 0) % divisor
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(sorted(self.items())))
|
||||
@ -196,6 +212,9 @@ class Poly(dict):
|
||||
if (v != 1 or k.degree == 0) else str(k)
|
||||
for k, v in sorted(self.items())).strip()
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __int__(self):
|
||||
assert self.is_constant
|
||||
return op.index(next(iter(self.values())))
|
||||
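Poly and Mon implement the polynomial arithmetic behind polymorphic dimensions. A small illustration of the behavior exercised by the tests below (constructing Poly and Mon directly, as the tests do, is an internal detail):

  from jax.interpreters.masking import Poly, Mon

  n = Poly({Mon({'n': 1}): 1})   # the symbolic dimension "n"
  assert Poly({Mon(): 3}) == 3   # constant polynomials compare equal to ints
  assert -1 - n == -n - 1        # arithmetic normalizes to canonical form
  assert int(Poly({Mon(): 7})) == 7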
@ -213,7 +232,7 @@ abstract_arrays._DIMENSION_TYPES.add(Poly)
|
||||
|
||||
class Mon(dict):
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.items()))
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
def __str__(self):
|
||||
return ' '.join('{}**{}'.format(k, v) if v != 1 else str(k)
|
||||
@ -221,8 +240,8 @@ class Mon(dict):
|
||||
|
||||
def __lt__(self, other):
|
||||
# sort by total degree, then lexicographically on indets
|
||||
self_key = self.degree, tuple(sorted(self))
|
||||
other_key = other.degree, tuple(sorted(other))
|
||||
self_key = -self.degree, tuple(sorted(self))
|
||||
other_key = -other.degree, tuple(sorted(other))
|
||||
return self_key < other_key
|
||||
|
||||
def __mul__(self, other):
|
||||
@ -258,9 +277,9 @@ class ShapeSpec(tuple):
|
||||
def __str__(self):
|
||||
return 'ShapeSpec({})'.format(', '.join(map(str, self)))
|
||||
|
||||
def finalize_spec(spec, shape):
|
||||
def finalize_spec(polymorphic_shape, padded_shape):
|
||||
return tuple(_parse_lit(d) if e is _monomorphic_dim else e
|
||||
for e, d in zip(spec, shape))
|
||||
for e, d in zip(polymorphic_shape, padded_shape))
|
||||
|
||||
def parse_spec(spec=''):
|
||||
if not spec:
|
||||
@ -311,19 +330,24 @@ def _shape_spec_consistent(spec, expr):
|
||||
return all(a == b for a, b in zip(spec, expr) if a is not _monomorphic_dim)
|
||||
|
||||
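Shape specs are written as strings and parsed into ShapeSpec tuples of Poly dimensions. Examples taken from the test_parse_spec cases further down:

  from jax.interpreters.masking import parse_spec

  str(parse_spec('(m, n)'))     # 'ShapeSpec(m, n)'
  str(parse_spec('m + n * k'))  # 'ShapeSpec(k n + m)'
  str(parse_spec('_'))          # 'ShapeSpec(_)', a monomorphic dimension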
class MaskTracer(Tracer):
|
||||
__slots__ = ["val", "shape_expr"]
|
||||
__slots__ = ["val", "polymorphic_shape"]
|
||||
|
||||
def __init__(self, trace, val, shape_expr):
|
||||
self._trace = trace
|
||||
def __init__(self, trace, val, polymorphic_shape):
|
||||
super().__init__(trace)
|
||||
self.val = val
|
||||
self.shape_expr = shape_expr
|
||||
self.polymorphic_shape = polymorphic_shape
|
||||
|
||||
@property
|
||||
def aval(self):
|
||||
return ShapedArray(self.shape_expr, self.val.dtype)
|
||||
return ShapedArray(self.polymorphic_shape, self.dtype)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.val.dtype
|
||||
|
||||
def is_pure(self):
|
||||
return all(type(poly) is not Poly or poly.is_constant for poly in self.shape_expr)
|
||||
return all(type(poly) is not Poly or poly.is_constant
|
||||
for poly in self.polymorphic_shape)
|
||||
|
||||
def full_lower(self):
|
||||
if self.is_pure():
|
||||
@ -340,23 +364,74 @@ class MaskTrace(Trace):
|
||||
return MaskTracer(self, val, onp.shape(val))
|
||||
|
||||
def sublift(self, val):
|
||||
return MaskTracer(self, val.val, val.shape_expr)
|
||||
return MaskTracer(self, val.val, val.polymorphic_shape)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
vals, shape_exprs = unzip2((t.val, t.shape_expr) for t in tracers)
|
||||
vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
|
||||
if primitive in shape_parameterized_primitive_rules:
|
||||
rule = shape_parameterized_primitive_rules[primitive]
|
||||
out, out_shape = rule(shape_envs, vals, shape_exprs, **params)
|
||||
out, out_shape = rule(shape_envs, vals, polymorphic_shapes, **params)
|
||||
else:
|
||||
avals = [t.aval for t in tracers]
|
||||
out = primitive.abstract_eval(*avals, **params)
|
||||
out_shape = [o.shape for o in out] if primitive.multiple_results else out.shape
|
||||
logical_shapes = map(partial(eval_polymorphic_shape, values_dict=shape_envs.logical), shape_exprs)
|
||||
out = masking_rules[primitive](vals, logical_shapes, **params)
|
||||
logical_shapes = map(shape_as_value, polymorphic_shapes)
|
||||
masking_rule = masking_rules.get(primitive)
|
||||
if masking_rule is None:
|
||||
raise NotImplementedError('Masking rule for {} not implemented yet.'.format(primitive))
|
||||
out = masking_rule(vals, logical_shapes, **params)
|
||||
if not primitive.multiple_results:
|
||||
return MaskTracer(self, out, out_shape)
|
||||
else:
|
||||
return map(partial(MaskTracer, self), out, out_shape)
|
||||
|
||||
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
|
||||
raise NotImplementedError # TODO mask-of-jit
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
raise NotImplementedError
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
raise NotImplementedError
|
||||
|
||||
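Primitives opt in to masking by registering a rule in masking.masking_rules; otherwise process_primitive raises the NotImplementedError above. A sketch of the registration pattern, using a toy identity primitive defined the same way as in test_unsupported_op below (the primitive and its rule are hypothetical):

  import numpy as onp
  from jax import core, mask
  from jax.interpreters import masking

  identity_p = core.Primitive('toy_identity')
  identity_p.def_abstract_eval(lambda x: x)
  identity_p.def_impl(lambda x: x)

  # A masking rule receives the padded values and the logical shapes and
  # returns the padded output; the identity just passes data through.
  masking.masking_rules[identity_p] = lambda vals, logical_shapes: vals[0]

  out = mask(identity_p.bind, ['n'], 'n')([onp.arange(3)], {'n': 2})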
class UniqueId:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.name < other.name
|
||||
|
||||
class UniqueIds(dict):
|
||||
def __missing__(self, key):
|
||||
unique_id = UniqueId(key)
|
||||
self[key] = unique_id
|
||||
return unique_id
|
||||
|
||||
def remap_ids(names, shape_spec):
|
||||
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
|
||||
: coeff for mon, coeff in poly.items()})
|
||||
if poly is not _monomorphic_dim else
|
||||
_monomorphic_dim for poly in shape_spec)
|
||||
|
||||
def bind_shapes(polymorphic_shapes, padded_shapes):
|
||||
env = {}
|
||||
for polymorphic_shape, padded_shape in zip(polymorphic_shapes, padded_shapes):
|
||||
for poly, d in zip(polymorphic_shape, padded_shape):
|
||||
if type(poly) is not Poly or poly.is_constant:
|
||||
if int(poly) != d: raise ShapeError
|
||||
else:
|
||||
poly = poly.copy()
|
||||
const_coeff = poly.pop(Mon({}), 0)
|
||||
(mon, linear_coeff), = poly.items()
|
||||
(id, index), = mon.items()
|
||||
if index != 1: raise ShapeError
|
||||
d, r = divmod(d - const_coeff, linear_coeff)
|
||||
assert r == 0
|
||||
if env.setdefault(id, d) != d: raise ShapeError
|
||||
return env
|
||||
|
||||
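bind_shapes solves the polymorphic input shapes against the actual padded argument shapes, producing the padded size for each shape variable and raising ShapeError on any inconsistency. A minimal sketch with hypothetical sizes, using the raw string ids produced by parse_spec:

  from jax.interpreters.masking import parse_spec, bind_shapes

  specs = [parse_spec('n'), parse_spec('(n, 3)')]
  env = bind_shapes(specs, [(5,), (5, 3)])   # binds n -> 5
  # bind_shapes(specs, [(5,), (6, 3)]) would raise ShapeError instead.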
def check_shapes(specs, spec_tree, shapes, tree, message_prefix="Output"):
|
||||
if spec_tree != tree or not all(map(_shape_spec_consistent, specs, shapes)):
|
||||
specs = tree_unflatten(spec_tree, specs)
|
||||
shapes = tree_unflatten(tree, shapes)
|
||||
raise ShapeError(f"{message_prefix} shapes should be {specs} but are {shapes}.")
|
||||
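check_shapes produces the user-facing error text; for example (mirroring test_output_shape_error further down), annotating the identity function with the wrong output spec:

  from jax import shapecheck
  from jax.interpreters.masking import ShapeError

  try:
    shapecheck(['n'], 'n+-1')(lambda x: x)
  except ShapeError as e:
    print(e)   # Output shapes should be (n + -1,) but are (n,).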
|
jax/lax/lax.py (126 lines changed)
@ -36,7 +36,7 @@ from .. import dtypes
|
||||
from .. import lazy
|
||||
from .. import lib
|
||||
from ..config import flags
|
||||
from ..core import Primitive
|
||||
from ..core import Primitive, _canonicalize_dimension
|
||||
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
||||
AbstractToken, array_types, make_shaped_array,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
@ -1534,8 +1534,8 @@ def slice_in_dim(operand: Array, start_index: Optional[int],
|
||||
|
||||
# translate `None`
|
||||
len_axis = operand.shape[axis]
|
||||
start_index_int = int(start_index) if start_index is not None else 0
|
||||
limit_index_int = int(limit_index) if limit_index is not None else len_axis
|
||||
start_index_int = _canonicalize_dimension(start_index) if start_index is not None else 0
|
||||
limit_index_int = _canonicalize_dimension(limit_index) if limit_index is not None else len_axis
|
||||
|
||||
# translate negative indices
|
||||
if start_index_int < 0:
|
||||
@ -1700,7 +1700,7 @@ def _iter(tracer):
|
||||
if tracer.ndim == 0:
|
||||
raise TypeError("iteration over a 0-d array") # same as numpy error
|
||||
else:
|
||||
n = tracer.shape[0]
|
||||
n = int(tracer.shape[0])
|
||||
# return (index_in_dim(tracer, i, keepdims=False) for i in range(n))
|
||||
return iter([index_in_dim(tracer, i, keepdims=False) for i in range(n)])
|
||||
ShapedArray._iter = staticmethod(_iter)
|
||||
@ -2523,6 +2523,45 @@ def _conv_general_dilated_batch_rule(
|
||||
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
|
||||
return out, out_spec[1]
|
||||
|
||||
def _masked(padded_value, logical_shape, dimensions, value=0):
|
||||
"""
|
||||
Sets all padding to the given value (default is 0) in the given dimensions.
|
||||
All values outside the logical shape are considered padding.
|
||||
"""
|
||||
if len(dimensions) == 0:
|
||||
return padded_value
|
||||
|
||||
masks = [broadcasted_iota(onp.int32, padded_value.shape, d) < logical_shape[d]
|
||||
for d in dimensions]
|
||||
mask_intersection = masks[0]
|
||||
for mask in masks[1:]:
|
||||
mask_intersection &= mask
|
||||
return select(mask_intersection, padded_value, full_like(padded_value, value))
|
||||
|
||||
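The _masked helper above fills everything beyond the logical extent along the given dimensions with `value`. A self-contained sketch of the same pattern written with public lax ops (this mirrors the helper rather than importing it; the shapes and names are illustrative):

  import numpy as onp
  from jax import lax
  import jax.numpy as jnp

  def masked(padded, logical_shape, dimensions, value=0.):
    if not dimensions:
      return padded
    masks = [lax.broadcasted_iota(onp.int32, padded.shape, d) < logical_shape[d]
             for d in dimensions]
    mask = masks[0]
    for m in masks[1:]:
      mask &= m
    return lax.select(mask, padded, jnp.full_like(padded, value))

  x = jnp.arange(8.).reshape(2, 4)   # padded shape (2, 4)
  masked(x, (2, 3), (1,))            # last column (padding) becomes 0.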
def _conv_general_dilated_masking_rule(
|
||||
padded_vals, logical_shapes, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
|
||||
lhs_shape, rhs_shape, precision):
|
||||
lhs, rhs = padded_vals
|
||||
logical_lhs_shape, logical_rhs_shape = logical_shapes
|
||||
|
||||
o, i, *window_dimensions = dimension_numbers.rhs_spec
|
||||
assert (onp.all(onp.take(rhs.shape, window_dimensions)
|
||||
== onp.take(logical_rhs_shape, window_dimensions))), \
|
||||
"Conv filter masking not yet implemented."
|
||||
|
||||
n, c, *padded_dimensions = dimension_numbers.lhs_spec
|
||||
|
||||
return conv_general_dilated(
|
||||
_masked(lhs, logical_lhs_shape, padded_dimensions),
|
||||
_masked(rhs, logical_rhs_shape, (i,)),
|
||||
window_strides=window_strides, padding=padding,
|
||||
lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation,
|
||||
dimension_numbers=dimension_numbers,
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count,
|
||||
precision=precision)
|
||||
|
||||
conv_general_dilated_p = standard_primitive(
|
||||
_conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
|
||||
'conv_general_dilated', _conv_general_dilated_translation_rule)
|
||||
@ -2531,7 +2570,8 @@ ad.defbilinear(conv_general_dilated_p,
|
||||
_conv_general_dilated_transpose_rhs)
|
||||
batching.primitive_batchers[conv_general_dilated_p] = \
|
||||
_conv_general_dilated_batch_rule
|
||||
|
||||
masking.masking_rules[conv_general_dilated_p] = \
|
||||
_conv_general_dilated_masking_rule
|
||||
|
||||
def _reshape_axis_into(src, dst, x):
|
||||
perm = [i for i in range(x.ndim) if i != src]
|
||||
@ -2683,21 +2723,13 @@ def _dot_general_translation_rule(c, lhs, rhs, *, dimension_numbers, precision):
|
||||
def _dot_general_masking_rule(padded_vals, logical_shapes, *, dimension_numbers,
|
||||
precision):
|
||||
lhs, rhs = padded_vals
|
||||
lhs_shape, rhs_shape = logical_shapes
|
||||
lhs_ndim, rhs_ndim = len(lhs_shape), len(rhs_shape)
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
# we need only mask the lhs contraction dimensions
|
||||
if len(lhs_contract) == 0:
|
||||
return dot_general(lhs, rhs, dimension_numbers, precision=precision)
|
||||
else:
|
||||
masks = [broadcasted_iota(onp.int32, lhs.shape, d) < lhs_shape[d]
|
||||
for d in lhs_contract]
|
||||
mask_intersection = masks[0]
|
||||
for mask in masks[1:]:
|
||||
mask_intersection &= mask
|
||||
masked_lhs = select(mask_intersection, lhs, zeros_like_array(lhs))
|
||||
return dot_general(masked_lhs, rhs, dimension_numbers, precision=precision)
|
||||
# Only need to mask off contraction dims of one side - we mask the lhs here
|
||||
# but this is arbitrary. Could check the sizes of lhs and rhs and mask
|
||||
# whichever is smallest.
|
||||
lhs_shape, _ = logical_shapes
|
||||
(lhs_contract, _), _ = dimension_numbers
|
||||
return dot_general(_masked(lhs, lhs_shape, lhs_contract),
|
||||
rhs, dimension_numbers, precision=precision)
|
||||
|
||||
dot_general_p = standard_primitive(_dot_general_shape_rule,
|
||||
_dot_general_dtype_rule, 'dot_general',
|
||||
@ -2925,11 +2957,23 @@ def _pad_translation_rule(c, operand, padding_value, *, padding_config):
|
||||
return xops.Pad(operand, padding_value,
|
||||
xc.make_padding_config(padding_config))
|
||||
|
||||
def _pad_masking_rule(padded_vals, logical_shapes, padding_config):
|
||||
operand, padding_value = padded_vals
|
||||
shape, _ = logical_shapes
|
||||
|
||||
out = pad(operand, padding_value, padding_config)
|
||||
out_shape = [lo + shape[i] * (interior + 1)
|
||||
for i, (lo, hi, interior) in enumerate(padding_config)]
|
||||
padded_dims = [i for i, config in enumerate(padding_config)
|
||||
if config != (0, 0, 0)]
|
||||
return _masked(out, out_shape, padded_dims, padding_value)
|
||||
|
||||
pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad',
|
||||
translation_rule=_pad_translation_rule)
|
||||
ad.deflinear(pad_p, _pad_transpose)
|
||||
ad.primitive_transposes[pad_p] = _pad_transpose
|
||||
batching.primitive_batchers[pad_p] = _pad_batch_rule
|
||||
masking.masking_rules[pad_p] = _pad_masking_rule
|
||||
|
||||
|
||||
# The squeeze primitive exists for the benefit of masking and other
|
||||
@ -3183,12 +3227,16 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
|
||||
perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
|
||||
return transpose(operand, perm), 0
|
||||
|
||||
def _transpose_masking_rule(padded_vals, logical_shapes, permutation):
|
||||
return transpose(*padded_vals, permutation=permutation)
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
'transpose')
|
||||
transpose_p.def_impl(_transpose_impl)
|
||||
ad.deflinear(transpose_p,
|
||||
lambda t, permutation: [transpose(t, onp.argsort(permutation))])
|
||||
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
|
||||
masking.masking_rules[transpose_p] = _transpose_masking_rule
|
||||
|
||||
|
||||
def _select_shape_rule(pred, on_true, on_false):
|
||||
@ -3257,6 +3305,13 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
|
||||
on_false = broadcast(on_false, pred.shape)
|
||||
return select(pred, on_true, on_false), 0
|
||||
|
||||
def _select_masking_rule(padded_vals, logical_shapes):
|
||||
pred_shape, true_shape, false_shape = [
|
||||
masking.padded_shape_as_value(val.shape) for val in padded_vals]
|
||||
assert onp.array_equal(pred_shape, true_shape)
|
||||
assert onp.array_equal(pred_shape, false_shape)
|
||||
return select(*padded_vals)
|
||||
|
||||
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
|
||||
ad.defjvp(select_p,
|
||||
None,
|
||||
@ -3264,6 +3319,7 @@ ad.defjvp(select_p,
|
||||
lambda g, b, x, y: select(b, _zeros(g), g))
|
||||
ad.primitive_transposes[select_p] = _select_transpose_rule
|
||||
batching.primitive_batchers[select_p] = _select_batch_rule
|
||||
masking.masking_rules[select_p] = _select_masking_rule
|
||||
|
||||
|
||||
def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
|
||||
@ -3277,7 +3333,9 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
|
||||
msg = ("slice limit_indices must have the same length as start_indices, "
|
||||
"got start_inidices {} and limit_indices {}.")
|
||||
raise TypeError(msg.format(start_indices, limit_indices))
|
||||
if not onp.all(onp.less_equal(limit_indices, operand.shape)):
|
||||
if (not masking.is_polymorphic(limit_indices) and
|
||||
not masking.is_polymorphic(operand.shape) and
|
||||
not onp.all(onp.less_equal(limit_indices, operand.shape))):
|
||||
msg = ("slice limit_indices must be less than or equal to operand shape, "
|
||||
"got limit_indices {} for operand shape {}.")
|
||||
raise TypeError(msg.format(limit_indices, operand.shape))
|
||||
@ -3285,7 +3343,8 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
|
||||
msg = ("slice start_indices must be greater than or equal to zero, "
|
||||
"got start_indices of {}.")
|
||||
raise TypeError(msg.format(start_indices))
|
||||
if not onp.all(onp.greater_equal(limit_indices, start_indices)):
|
||||
if (not masking.is_polymorphic(limit_indices) and
|
||||
not onp.all(onp.greater_equal(limit_indices, start_indices))):
|
||||
msg = ("slice limit_indices must be greater than or equal to start_indices,"
|
||||
" got start_indices {} and limit_indices {}.")
|
||||
raise TypeError(msg.format(start_indices, limit_indices))
|
||||
@ -3345,10 +3404,19 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
|
||||
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
|
||||
return out, bdim
|
||||
|
||||
def _slice_masking_rule(
|
||||
padded_vals, logical_shapes, start_indices, limit_indices, strides):
|
||||
operand, = padded_vals
|
||||
return slice(operand,
|
||||
start_indices=masking.padded_shape_as_value(start_indices),
|
||||
limit_indices=masking.padded_shape_as_value(limit_indices),
|
||||
strides=strides)
|
||||
|
||||
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
|
||||
_slice_translation_rule)
|
||||
ad.deflinear2(slice_p, _slice_transpose_rule)
|
||||
batching.primitive_batchers[slice_p] = _slice_batching_rule
|
||||
masking.masking_rules[slice_p] = _slice_masking_rule
|
||||
|
||||
|
||||
def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
|
||||
@ -4056,14 +4124,15 @@ def _masking_defreducer(prim, identity):
|
||||
masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
|
||||
|
||||
def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
|
||||
axes):
|
||||
axes, input_shape=None):
|
||||
(padded_val,), (logical_shape,) = padded_vals, logical_shapes
|
||||
padded_shape = masking.padded_shape_as_value(padded_val.shape)
|
||||
masks = [broadcasted_iota(onp.int32, padded_shape, i) < d
|
||||
for i, d in enumerate(logical_shape) if i in axes]
|
||||
mask = _reduce(operator.and_, masks)
|
||||
masked_val = select(mask, padded_val, identity(padded_shape, padded_val.dtype))
|
||||
return prim.bind(masked_val, axes=axes)
|
||||
bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
|
||||
return bind(masked_val, axes=axes)
|
||||
|
||||
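The reducer masking rule replaces padding with the reduction's identity before reducing, so padded entries cannot leak into the result. An end-to-end sketch through the public mask API (mirroring test_reduce; the 99s stand in for arbitrary padding garbage):

  import jax.numpy as jnp
  from jax import mask

  padded_max = mask(jnp.max, in_shapes=['n'], out_shape='')
  x = jnp.array([3., 1., 4., 99., 99.])   # entries past n=3 are padding
  padded_max([x], dict(n=3))              # 4.0; padding is replaced by -inf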
reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
|
||||
_reduce_translation_rule)
|
||||
@ -4103,7 +4172,8 @@ _masking_defreducer(reduce_sum_p,
|
||||
lambda shape, dtype: onp.broadcast_to(onp.array(0, dtype), shape))
|
||||
|
||||
|
||||
def _reduce_op_shape_rule(operand, *, axes):
|
||||
def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
|
||||
del input_shape # unused.
|
||||
return tuple(onp.delete(operand.shape, axes))
|
||||
|
||||
def _reduce_prod_translation_rule(c, operand, *, axes):
|
||||
@ -4151,6 +4221,8 @@ reduce_prod_p = standard_primitive(
|
||||
'reduce_prod', _reduce_prod_translation_rule)
|
||||
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
|
||||
batching.defreducer(reduce_prod_p)
|
||||
_masking_defreducer(reduce_prod_p,
|
||||
lambda shape, dtype: onp.broadcast_to(onp.array(1, dtype), shape))
|
||||
|
||||
|
||||
def _reduce_chooser_shape_rule(operand, *, axes):
|
||||
@ -4178,6 +4250,8 @@ reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
|
||||
'reduce_max', _reduce_max_translation_rule)
|
||||
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
|
||||
batching.defreducer(reduce_max_p)
|
||||
_masking_defreducer(reduce_max_p,
|
||||
lambda shape, dtype: onp.broadcast_to(onp.array(-onp.inf, dtype), shape))
|
||||
|
||||
|
||||
_reduce_min_translation_rule = partial(
|
||||
@ -4186,6 +4260,8 @@ reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
|
||||
'reduce_min', _reduce_min_translation_rule)
|
||||
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
|
||||
batching.defreducer(reduce_min_p)
|
||||
_masking_defreducer(reduce_min_p,
|
||||
lambda shape, dtype: onp.broadcast_to(onp.array(onp.inf, dtype), shape))
|
||||
|
||||
|
||||
def _reduce_logical_shape_rule(operand, *, axes):
|
||||
|
tests/masking_test.py

@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from unittest import SkipTest
|
||||
@ -20,23 +19,31 @@ from unittest import SkipTest
|
||||
import numpy as np
|
||||
from absl.testing import absltest, parameterized
|
||||
from jax.interpreters.masking import shape_as_value, ShapeError, \
|
||||
parse_spec, Poly, Mon
|
||||
parse_spec, Poly, Mon, finalize_spec, eval_polymorphic_shape, remap_ids, \
|
||||
UniqueIds
|
||||
from jax import numpy as jnp, test_util as jtu, mask, vmap, jit, grad, lax, \
|
||||
shapecheck, api
|
||||
shapecheck, core
|
||||
from jax.config import config
|
||||
from jax.numpy.lax_numpy import _polymorphic_slice_indices
|
||||
from jax.scipy.special import expit
|
||||
from jax.util import safe_map, safe_zip
|
||||
from jax.test_util import rand_default, rand_int
|
||||
from jax.tree_util import tree_flatten
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
||||
# These are 'manual' tests for masking. The more exhaustive,
|
||||
# more systematic tests should live in lax_test.py.
|
||||
|
||||
# TODO:
|
||||
# These should be only the 'manual' tests for masking.
|
||||
# Move the more exhaustive, systematic tests into lax_test.py.
|
||||
|
||||
def constant_poly(c):
|
||||
return Poly({Mon(): c})
|
||||
|
||||
class ShapesTest(jtu.JaxTestCase):
|
||||
class PolyTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
['(m, n)', 'ShapeSpec(m, n)'],
|
||||
@ -49,14 +56,17 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
['(3 * m)', 'ShapeSpec(3 m)'],
|
||||
['m', 'ShapeSpec(m)'],
|
||||
['', 'ShapeSpec()'],
|
||||
['n + -1*n', 'ShapeSpec(0)'],
|
||||
['m + n', 'ShapeSpec(m + n)'],
|
||||
['m + n * k', 'ShapeSpec(m + k n)'],
|
||||
['m + n * k', 'ShapeSpec(k n + m)'],
|
||||
['m + 3 * k', 'ShapeSpec(3 k + m)'],
|
||||
['-3 + k + k * k', 'ShapeSpec(k**2 + k + -3)'],
|
||||
['', 'ShapeSpec()'],
|
||||
['_', 'ShapeSpec(_)'],
|
||||
])
|
||||
def test_parse_spec(self, spec, ans):
|
||||
self.assertEqual(str(parse_spec(spec)), ans)
|
||||
self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)
|
||||
|
||||
def test_Poly_equal(self):
|
||||
assert constant_poly(3) == 3
|
||||
@ -74,7 +84,12 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_Poly_hash(self):
|
||||
assert not len(set(hash(Poly({Mon(): i})) for i in range(10))) == 1
|
||||
assert hash(Poly({Mon(): 3, Mon({'n': 1}): 4})) == hash(Poly({Mon({'n': 1}): 4, Mon(): 3}))
|
||||
assert (hash(Poly({Mon(): 3, Mon({'n': 1}): 4}))
|
||||
== hash(Poly({Mon({'n': 1}): 4, Mon(): 3})))
|
||||
|
||||
def test_Mon_hash(self):
|
||||
assert not len(set(hash(Mon({'a': i})) for i in range(10))) == 1
|
||||
assert hash(Mon({'a': 1, 'b': 1})) == hash(Mon({'b': 1, 'a': 1}))
|
||||
|
||||
def test_Poly_compare(self):
|
||||
poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
|
||||
@ -100,208 +115,7 @@ class ShapesTest(jtu.JaxTestCase):
|
||||
n = Poly({Mon({'n': 1}): 1})
|
||||
assert -1 - n == -n - 1
|
||||
|
||||
def test_add_broadcast(self):
|
||||
@shapecheck(['n', '(m, n)'], '(m, n)')
|
||||
@shapecheck(['(m, n)', 'n'], '(m, n)')
|
||||
@shapecheck(['n', ''], 'n')
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
def test_sum(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def sum(x):
|
||||
return jnp.sum(x)
|
||||
|
||||
def test_prod(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return jnp.prod(x)
|
||||
|
||||
def test_max(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return jnp.max(x)
|
||||
|
||||
def test_min(self):
|
||||
@shapecheck(['(m, n)'], '')
|
||||
def prod(x):
|
||||
return jnp.min(x)
|
||||
|
||||
def test_dot(self):
|
||||
@shapecheck(['(m, n)', 'n'], 'm')
|
||||
def matvec(A, b):
|
||||
return jnp.dot(A, b)
|
||||
|
||||
def thunk():
|
||||
@shapecheck(['(m, n)', 'n'], 'm')
|
||||
def matvec(A, b):
|
||||
return lax.dot_general(A, b, [((0,), (0,)), ((), ())])
|
||||
self.assertRaisesRegex(TypeError, "", thunk)
|
||||
|
||||
def test_flatten(self):
|
||||
@shapecheck(['(m, n)'], 'm * n')
|
||||
def flatten(x):
|
||||
return lax.reshape(x, (x.shape[0] * x.shape[1],))
|
||||
|
||||
def test_concatenate(self):
|
||||
@shapecheck(['m', 'n', 'm'], '3*m + n')
|
||||
def cat(x, y, z):
|
||||
return lax.concatenate([x, y, x, z], 0)
|
||||
|
||||
def thunk():
|
||||
@shapecheck(['m', 'n', 'm'], '3*m + n')
|
||||
def cat(x, y, z):
|
||||
return lax.concatenate([x, y, x], 0)
|
||||
self.assertRaisesRegex(ShapeError, "", thunk)
|
||||
|
||||
def test_device_put(self):
|
||||
@shapecheck(['n'], 'n')
|
||||
def d_put(x):
|
||||
return api.device_put(x)
|
||||
|
||||
def test_broadcast_in_dim(self):
|
||||
|
||||
@shapecheck(['(n,)'], '(3, n, 4)')
|
||||
def broadcast_in_dim(x):
|
||||
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4), broadcast_dimensions=(1,))
|
||||
|
||||
@shapecheck(['(n, 1)'], '(3, n, 4, 1)')
|
||||
def broadcast_in_dim2(x):
|
||||
return lax.broadcast_in_dim(x, shape=(3, x.shape[0], 4, x.shape[1]), broadcast_dimensions=(1, 3))
|
||||
|
||||
def test_jit(self):
|
||||
@shapecheck(['n'], '2*n')
|
||||
@jit
|
||||
def concat(x):
|
||||
return lax.concatenate([x, x], 0)
|
||||
|
||||
# TODO:
|
||||
# @shapecheck(['n'], 'n')
|
||||
# @jit
|
||||
# @grad
|
||||
# def sum_square(x):
|
||||
# return jnp.sum(x ** 2)
|
||||
|
||||
def test_pad(self):
|
||||
@shapecheck(['n'], '2*n+1')
|
||||
def p(x):
|
||||
return lax.pad(x, jnp.array(0., x.dtype), [(1, 1, 1)])
|
||||
|
||||
def test_numpy_pad(self):
|
||||
@shapecheck(['n'], 'n+1')
|
||||
def p(x):
|
||||
return jnp.pad(x, (0, 1))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{
|
||||
'testcase_name': "strides={}_padding={}_lhs_dilation={}_dimension_numbers"
|
||||
"={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
|
||||
strides, padding, lhs_dilation, dimension_numbers, lhs_perm, rhs_perm, out_perm),
|
||||
'strides': strides, 'padding': padding, 'lhs_dilation': lhs_dilation,
|
||||
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
|
||||
'rhs_perm': rhs_perm, 'out_perm': out_perm}
|
||||
for strides in [(1, 1), (2, 1)]
|
||||
for padding in ['SAME', 'VALID', ((1, 0), (2, 0))]
|
||||
for lhs_dilation in (None, (1, 2))
|
||||
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
|
||||
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
|
||||
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
|
||||
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
|
||||
)
|
||||
# String padding is not implemented for transposed convolution, see conv_general_dilated implementation:
|
||||
if (lhs_dilation is None or not isinstance(padding, str)) and
|
||||
# only test strides with same padding:
|
||||
(strides[0] == 1 or padding == 'SAME')))
|
||||
def test_conv(self, strides, padding, lhs_dilation,
|
||||
dimension_numbers, lhs_perm, rhs_perm, out_perm):
|
||||
valid = padding == 'VALID'
|
||||
is_strided = strides[0] != 1
|
||||
lhs_shape = '({}, {}, {}, {})'.format(*np.take(['n', 'i', '2*h' if is_strided else 'h', 'w'], lhs_perm))
|
||||
rhs_shape = '({}, {}, {}, {})'.format(*np.take(['o', 'i', '2', '3'], rhs_perm))
|
||||
out_shape = '({}, {}, {}, {})'.format(*np.take([
|
||||
'n', 'o', 'h+-1' if valid and not is_strided else 'h',
|
||||
('w+-2' if valid else 'w') if lhs_dilation is None else '2*w+-1'], out_perm))
|
||||
|
||||
@shapecheck([lhs_shape, rhs_shape], out_shape)
|
||||
def conv(lhs, rhs):
|
||||
return lax.conv_general_dilated(
|
||||
lhs, rhs, strides, padding,
|
||||
lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers)
|
||||
|
||||
def test_indexing(self):
|
||||
@shapecheck(['n'], '')
|
||||
def first(x):
|
||||
return x[0]
|
||||
|
||||
@shapecheck(['n'], '')
|
||||
def last(x):
|
||||
return x[-1]
|
||||
|
||||
@shapecheck(['(n,m,a)'], 'n,m')
|
||||
@vmap
|
||||
@shapecheck(['(n,a)'], 'n')
|
||||
def last_column(x):
|
||||
return x[..., -1]
|
||||
|
||||
def test_slicing(self):
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def slice(x):
|
||||
return x[1:]
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def slice2(x):
|
||||
return x[:-1]
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def inverse(x):
|
||||
return x[:0:-1]
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def inverse2(x):
|
||||
return x[-2::-1]
|
||||
|
||||
def test_poly_slicing(self):
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def slice_poly_stop(x):
|
||||
return x[:x.shape[0] - 1]
|
||||
|
||||
# TODO: @shapecheck(['n'], '1')
|
||||
def slice_poly_start(x):
|
||||
return x[x.shape[0] - 1:]
|
||||
|
||||
def test_iota(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
# https://travis-ci.org/github/google/jax/jobs/682086351
|
||||
@shapecheck(['n'], 'n')
|
||||
def range_like(x):
|
||||
return lax.iota(jnp.int32, x.shape[0])
|
||||
|
||||
def test_arange(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
# https://travis-ci.org/github/google/jax/jobs/682086351
|
||||
@shapecheck(['n'], 'n')
|
||||
def arange_like(x):
|
||||
return jnp.arange(x.shape[0], dtype=jnp.int32)
|
||||
|
||||
def test_expit(self):
|
||||
@shapecheck(['n'], 'n')
|
||||
def expit_(x):
|
||||
return expit(x)
|
||||
|
||||
def test_reshape(self):
|
||||
@shapecheck(['n, a, b'], 'n, a*b')
|
||||
def flatten(x):
|
||||
return jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2]))
|
||||
|
||||
def test_ravel(self):
|
||||
a = jnp.array(1)
|
||||
|
||||
@shapecheck(['n'], '')
|
||||
def thunk(n):
|
||||
return -(a + n.ravel()[0] * 0)
|
||||
|
||||
class MaskingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_sum(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
@ -324,10 +138,43 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
expected = np.array([0, 1, 2, 3, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
|
||||
dtypes, rng, rtol=None, atol=None):
|
||||
shapecheck(in_shapes, out_shape)(fun)
|
||||
masked_fun = mask(fun, in_shapes, out_shape)
|
||||
padded_args = [rng(shape, dtype)
|
||||
for shape, dtype in zip(padded_in_shapes, dtypes)]
|
||||
padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))
|
||||
|
||||
out_specs, _ = tree_flatten(out_shape)
|
||||
out_specs = map(parse_spec, out_specs)
|
||||
out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs))
|
||||
logical_out_shapes = [eval_polymorphic_shape(s, logical_env)
|
||||
for s in out_specs]
|
||||
logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes]
|
||||
logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)]
|
||||
|
||||
in_specs = map(parse_spec, in_shapes)
|
||||
in_specs = map(finalize_spec, in_specs, padded_in_shapes)
|
||||
logical_in_shapes = [eval_polymorphic_shape(s, logical_env)
|
||||
for s in in_specs]
|
||||
logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes]
|
||||
logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)]
|
||||
logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
|
||||
assert outs_tree == logical_outs_tree
|
||||
self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
# Check that abstract evaluation works
|
||||
padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
|
||||
self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
|
||||
atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def test_add(self):
|
||||
@partial(mask, in_shapes=['n', 'n'], out_shape='n')
|
||||
def addvecs(x, y):
|
||||
return x + y
|
||||
self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'],
|
||||
rand_default(self.rng()))
|
||||
addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')
|
||||
|
||||
x = jnp.array([3, 1, 4, 1, 5, 9])
|
||||
y = jnp.array([2, 6, 5, 3, 5, 8])
|
||||
@ -384,35 +231,12 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
expected = 5
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_concatenate(self):
|
||||
@partial(mask, in_shapes=['n', 'm', 'n'], out_shape='m + 2 * n')
|
||||
def cat(x, y, z):
|
||||
return lax.concatenate([x, y, z], 0)
|
||||
|
||||
ans = cat([jnp.array([1, 9]), jnp.array([2, 4, 9]), jnp.array([3, 9])],
|
||||
dict(n=1, m=2))
|
||||
expected = np.array([1, 2, 4, 3])
|
||||
self.assertAllClose(ans[:4], expected, check_dtypes=False)
|
||||
|
||||
def test_dot(self):
|
||||
@partial(mask, in_shapes=['(m, k)', '(k, n)'], out_shape='(m, n)')
|
||||
def dot(x, y):
|
||||
return lax.dot(x, y)
|
||||
|
||||
x = np.arange(6, dtype=np.float32).reshape((2, 3))
|
||||
y = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
ans = dot([x, y], dict(m=2, k=2, n=2))
|
||||
expected = np.dot(x[:2, :2], y[:2, :2])
|
||||
self.assertAllClose(ans[:2, :2], expected, check_dtypes=False)
|
||||
|
||||
def test_mean(self):
|
||||
@partial(mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(x) / shape_as_value(x.shape)[0]
|
||||
|
||||
ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = 8 / 3
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
# TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
|
||||
raise SkipTest
|
||||
self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
|
||||
{'n': 3}, [(4,)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
def test_arithmetic(self):
|
||||
@partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
|
||||
@ -420,9 +244,11 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
return x * y
|
||||
|
||||
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
|
||||
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
|
||||
_ = times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
|
||||
dict(n=4, m=5))
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Masking rule for broadcast_in_dim not implemented yet.'):
|
||||
times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
|
||||
dict(n=4, m=5))
|
||||
# expected = np.array([[1, 2, 3], [8, 10, 12]])
|
||||
# self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -432,8 +258,10 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
return jnp.stack([x, y], 0)
|
||||
|
||||
# TODO(shoyer): enable this check when broadcast_in_dim supports masking
|
||||
with self.assertRaisesRegex(KeyError, 'broadcast_in_dim'):
|
||||
_ = stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Masking rule for broadcast_in_dim not implemented yet.'):
|
||||
stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
|
||||
# expected = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
# self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -464,6 +292,10 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
expected = jnp.array([3, 5])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@shapecheck(['(2*n, n)'], '_, n')
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
def test_rnn(self):
|
||||
n = 3
|
||||
|
||||
@ -544,8 +376,285 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)
|
||||
|
||||
self.assertAllClose(
|
||||
ans, expected, check_dtypes=False,
|
||||
rtol=2e-2 if jtu.device_under_test() == "tpu" else 1e-5)
|
||||
ans, expected, check_dtypes=False,
|
||||
rtol=2e-2 if jtu.device_under_test() == "tpu" else 1e-5)
|
||||
|
||||
def test_concatenate(self):
|
||||
self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
|
||||
['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
|
||||
[(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
def test_dot(self):
|
||||
self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)',
|
||||
dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'],
|
||||
rand_default(self.rng()))
|
||||
self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)],
|
||||
['float_', 'float_'], rand_default(self.rng()))
|
||||
|
||||
def test_jit(self):
|
||||
raise SkipTest
|
||||
@partial(mask, in_shapes=['n'], out_shape='2*n')
|
||||
@jit
|
||||
def duplicate(x):
|
||||
assert python_should_be_executing
|
||||
return lax.concatenate([x, x], 0)
|
||||
|
||||
python_should_be_executing = True
|
||||
out = duplicate([jnp.arange(3)], dict(n=2))
|
||||
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
|
||||
|
||||
python_should_be_executing = False
|
||||
out = duplicate([jnp.arange(3)], dict(n=2))
|
||||
assert np.all(np.array([0, 1, 0, 1]) == out[:4])
|
||||
|
||||
@parameterized.named_parameters({
|
||||
'testcase_name': "padding_config={}_shapes={}".format(padding_config,
|
||||
shape),
|
||||
'padding_config': padding_config,
|
||||
'shape': shape} for padding_config, shape in (
|
||||
(((1, 2, 0),), (2,)),
|
||||
(((1, 2, 0), (3, 4, 0)), (1, 2)),
|
||||
(((0, 0, 0), (0, 0, 0)), (1, 2)),
|
||||
(((1, 2, 3),), (2,)),
|
||||
(((1, 2, 1), (3, 4, 2)), (3, 2)),
|
||||
(((-1, 2, 0),), (2,)),
|
||||
(((-1, -2, 0), (1, 2, 0)), (4, 2)),
|
||||
(((-1, 2, 0), (1, 2, 2)), (4, 2)),
|
||||
(((-1, -2, 2),), (5,)),
|
||||
(((-1, -2, 1), (1, 2, 2)), (4, 2))))
|
||||
def test_pad(self, padding_config, shape):
|
||||
def pad(x):
|
||||
return lax.pad(x, jnp.array(1., x.dtype), padding_config)
|
||||
|
||||
if len(shape) == 1:
|
||||
padding_config_, = padding_config
|
||||
linear_coeff = padding_config_[2] + 1
|
||||
const_coeff = sum(padding_config_[:2]) - padding_config_[2]
|
||||
out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff)
|
||||
self.check(pad, ['h'], out_shape, dict(h=shape[0]),
|
||||
[tuple(np.add(shape, 1))], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
|
||||
def test_numpy_pad(self):
|
||||
# TODO (j-towns) requires mask(jit)
|
||||
raise SkipTest
|
||||
def numpy_pad(x):
|
||||
return jnp.pad(x, (0, 1), constant_values=5.)
|
||||
|
||||
self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': "padding={}_lhs_dilation={}_"
|
||||
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
|
||||
padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm),
|
||||
'padding': padding, 'lhs_dilation': lhs_dilation,
|
||||
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
|
||||
'rhs_perm': rhs_perm, 'out_perm': out_perm}
|
||||
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
|
||||
for lhs_dilation in (None, (1, 2))
|
||||
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
|
||||
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
|
||||
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
|
||||
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
|
||||
)
|
||||
# String padding is not implemented for transposed convolution, see
|
||||
# conv_general_dilated implementation:
|
||||
if (lhs_dilation is None or not isinstance(padding, str))))
|
||||
def test_conv(
|
||||
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm):
|
||||
def conv(lhs, rhs):
|
||||
return lax.conv_general_dilated(
|
||||
lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation,
|
||||
dimension_numbers=dimension_numbers)
|
||||
|
||||
template = '({}, {}, {}, {})'
|
||||
lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm))
|
||||
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
|
||||
if padding == 'VALID':
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm))
|
||||
elif lhs_dilation:
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
|
||||
else:
|
||||
out_shape = template.format(
|
||||
*np.take(['n', 'o', 'h', 'w'], out_perm))
|
||||
|
||||
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
|
||||
|
||||
self.check(conv, [lhs_shape, rhs_shape], out_shape,
|
||||
logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)),
|
||||
tuple(np.take([7, 3, 2, 3], rhs_perm))],
|
||||
['float_', 'float_'], rand_default(self.rng()), rtol=1e-4,
|
||||
atol=1e-4)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{'testcase_name': "padding={}_lhs_dilation={}_"
|
||||
"dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
|
||||
padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm),
|
||||
'padding': padding, 'lhs_dilation': lhs_dilation,
|
||||
'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
|
||||
'rhs_perm': rhs_perm, 'out_perm': out_perm}
|
||||
for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
|
||||
for lhs_dilation in (None, (1, 2))
|
||||
for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
|
||||
(("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
|
||||
(("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
|
||||
(("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
|
||||
)
|
||||
# String padding is not implemented for transposed convolution, see
|
||||
# conv_general_dilated implementation:
|
||||
if (lhs_dilation is None or not isinstance(padding, str))))
|
||||
def test_conv_strided(
|
||||
self, padding, lhs_dilation, dimension_numbers, lhs_perm,
|
||||
rhs_perm, out_perm):
|
||||
def conv(lhs, rhs):
|
||||
return lax.conv_general_dilated(
|
||||
lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation,
|
||||
dimension_numbers=dimension_numbers)
|
||||
|
||||
template = '({}, {}, {}, {})'
|
||||
rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
|
||||
if padding == 'VALID':
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm))
|
||||
elif lhs_dilation:
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
|
||||
else:
|
||||
lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
|
||||
lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
|
||||
out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm))
|
||||
|
||||
logical_env = dict(n=3, c=2, h=4, w=5, o=6)
|
||||
|
||||
self.check(conv, [lhs_shape, rhs_shape], out_shape,
|
||||
logical_env, [lhs_shape_padded,
|
||||
tuple(np.take([7, 3, 2, 3], rhs_perm))],
|
||||
['float_', 'float_'], rand_default(self.rng()), rtol=1e-4,
|
||||
atol=1e-4)
|
||||
|
||||
def test_indexing(self):
|
||||
# Requires gather support
|
||||
raise SkipTest
|
||||
self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
def test_slicing(self):
|
||||
raise SkipTest
|
||||
# Requires gather support
|
||||
self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
|
||||
self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
|
||||
self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_'])
|
||||
|
||||
def test_rev(self):
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def rev(x):
|
||||
return x[:0:-1]
|
||||
|
||||
@shapecheck(['n'], 'n+-1')
|
||||
def rev2(x):
|
||||
return x[-2::-1]
|
||||
|
||||
# TODO implement masking for rev_p:
|
||||
# self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
|
||||
# self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
|
||||
|
||||
def test_lax_slice(self):
|
||||
self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
|
||||
{'n': 2}, [(3,)], ['float_'], rand_default(self.rng()))
|
||||
# TODO: self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)), ['2*n'], dict(n=jnp.array([2, 3])), 'n')
|
||||
|
||||
def test_reshape(self):
|
||||
raise SkipTest
|
||||
|
||||
def test_transpose(self):
|
||||
self.check(lambda x: lax.transpose(x, (1, 0, 2)),
|
||||
['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)],
|
||||
['float_'], rand_default(self.rng()))
|
||||
|
||||
def test_sum_2d(self):
|
||||
self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
def test_expit(self):
|
||||
raise SkipTest("custom_jvp doesn't work with masking yet")
|
||||
self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
|
||||
for dtype in [np.float32, np.float64]))
|
||||
def test_uniform(self, dtype):
|
||||
raise SkipTest("not yet implemented")
|
||||
# TODO needs fix for https://github.com/google/jax/issues/2155
|
||||
|
||||
def test_broadcast_in_dim(self):
|
||||
raise SkipTest
|
||||
|
||||
def test_destructure(self):
|
||||
def d(key):
|
||||
key1, key2 = key
|
||||
return key1
|
||||
|
||||
self.check(d, ['2'], '', {}, [(2,)], ['int_'], rand_int(self.rng(), 0, 10))
|
||||
|
||||
def test_where(self):
|
||||
# Requires mask(jit)
|
||||
raise SkipTest
|
||||
self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n',
|
||||
{'n': 2}, [(3,)], ['float_'], rand_default(self.rng()))
|
||||
|
||||
def test_split(self):
|
||||
raise SkipTest
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list([{
|
||||
'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
|
||||
for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
|
||||
def test_reduce(self, operator):
|
||||
self.check(operator, ['(m, n)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
|
||||
rand_default(self.rng()))
|
||||
|
||||
def test_output_shape_error(self):
|
||||
def thunk():
|
||||
shapecheck(['n'], 'n+-1')(lambda x: x)
|
||||
|
||||
message = "Output shapes should be (n + -1,) but are (n,)."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def thunk():
|
||||
shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x))
|
||||
|
||||
message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def test_output_tree_error(self):
|
||||
def thunk():
|
||||
shapecheck(['n'], ('n', 'n'))(lambda x: [x, x])
|
||||
|
||||
message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]."
|
||||
self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)
|
||||
|
||||
def test_unsupported_op(self):
|
||||
p = core.Primitive('unsupported_op')
|
||||
p.def_abstract_eval(lambda x: x)
|
||||
p.def_impl(lambda x: x)
|
||||
|
||||
def thunk():
|
||||
mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})
|
||||
|
||||
message = "Masking rule for unsupported_op not implemented yet."
|
||||
self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)
|
||||
|
||||
def test_nesting(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
@ -568,16 +677,6 @@ class MaskingTest(jtu.JaxTestCase):
|
||||
expected = 3+1 + 5+9+2
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_arange(self):
|
||||
raise SkipTest("not yet implemented")
|
||||
|
||||
@partial(mask, in_shapes=['n'], out_shape='n')
|
||||
def padded_add(x):
|
||||
return x + lax.iota(x.shape[0])
|
||||
|
||||
ans = padded_add([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
|
||||
expected = np.array([3, 2, 6])
|
||||
self.assertAllClose(ans[:3], expected, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_start={}_stop={}_step={}_length={}"
|
||||