Internal change

PiperOrigin-RevId: 465705931
jax authors 2022-08-05 21:08:16 -07:00
parent 3f9c300521
commit 6b0c0dc321
12 changed files with 74 additions and 453 deletions

View File

@@ -2938,7 +2938,7 @@ def _valid_jaxtype(arg):
try:
xla.abstractify(arg) # faster than core.get_aval
except TypeError:
return core.valid_jaxtype(arg)
return False
else:
return True
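With this rollback, anything whose abstractification raises TypeError is simply not a valid jaxtype. A minimal sketch of the predicate's observable behavior via the public `jax.core.valid_jaxtype` wrapper (the exact accepted set depends on registered handlers):

import numpy as np
import jax

assert jax.core.valid_jaxtype(np.float32(1.0))   # numpy scalars are jaxtypes
assert jax.core.valid_jaxtype(np.zeros((2, 3)))  # ndarrays are jaxtypes
assert not jax.core.valid_jaxtype('hello')       # strings are not
assert not jax.core.valid_jaxtype(object())      # arbitrary objects are not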

View File

@@ -435,11 +435,7 @@ def _shaped_abstractify_slow(x):
weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype)
else:
dtype = dtypes.result_type(x) # TODO(frostig,mattjj): why this case?
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
named_shape=named_shape)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
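For orientation, the same abstractification machinery is exercised from the public API by `jax.eval_shape`, which abstracts its inputs and reports the output aval's shape and dtype:

import numpy as np
import jax

# eval_shape never materializes data; it only propagates avals:
out = jax.eval_shape(lambda a: a.sum(axis=0), np.ones((2, 3), np.float32))
print(out.shape, out.dtype)  # (3,) float32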

View File

@@ -651,8 +651,6 @@ def array_result_handler(sticky_device: Optional[Device],
if aval.dtype == dtypes.float0:
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
if type(aval.dtype) in core.custom_eltypes:
return aval.dtype.result_handler(sticky_device, aval)
handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
handler.args = aval, sticky_device # for C++ dispatch path in api.py
return handler
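The `float0` branch above is the one exercised by gradients with respect to integer arguments; a small check using only public API:

import jax
from jax import dtypes

# Differentiating w.r.t. an integer argument yields a float0 zero:
g = jax.grad(lambda x, i: x * 1.0, argnums=1, allow_int=True)(2.0, 3)
print(g.dtype == dtypes.float0)  # True
print(g.shape)                   # ()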

View File

@@ -25,7 +25,6 @@ from typing import Any, Dict, List
import numpy as np
import jax
from jax._src.config import flags, config
from jax._src.lib import xla_client
@@ -86,8 +85,6 @@ def _to_complex_dtype(dtype):
@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled, dtype):
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
if type(dtype) in jax.core.custom_eltypes:
return dtype
try:
dtype = np.dtype(dtype)
except TypeError as e:
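With the custom-eltype escape hatch gone, every dtype must pass through `np.dtype`. The canonicalization itself is observable from the public API; under the default config (x64 disabled), 64-bit dtypes map to their 32-bit counterparts:

import numpy as np
from jax import dtypes

print(dtypes.canonicalize_dtype(np.float64))  # float32 under default config
print(dtypes.canonicalize_dtype(np.int64))    # int32 under default config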

View File

@@ -29,13 +29,11 @@ from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
@@ -234,8 +232,9 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
x_shapes = [x.shape[1:] for x in xs_flat]
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
@@ -243,7 +242,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(
f, in_tree, (*carry_avals, *x_avals), "scan")
f, in_tree, carry_avals + x_avals, "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
@@ -415,7 +414,7 @@ def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
return lax.full((sz,) + aval.shape, 0, aval.dtype)
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
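For reference, the aval and shape bookkeeping above backs the usual scan contract: the carry type is fixed across iterations, and the per-step outputs are stacked along a new leading axis. A minimal public-API example:

import jax
import jax.numpy as jnp

def step(carry, x):
  carry = carry + x
  return carry, carry        # (new carry, per-step output to stack)

total, cumsum = jax.lax.scan(step, 0.0, jnp.arange(4.0))
print(total)   # 6.0
print(cumsum)  # [0. 1. 3. 6.]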
@@ -969,28 +968,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry
f'called with sequence of type\n{_avals_short(x_avals)}')
return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
printed_params = dict(eqn.params)
del printed_params['linear']
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
del printed_params['length']
if printed_params['unroll'] == 1:
del printed_params['unroll']
if printed_params['num_carry'] == 0:
del printed_params['num_carry']
if printed_params['num_consts'] == 0:
del printed_params['num_consts']
if not printed_params['reverse']:
del printed_params['reverse']
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
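This rule only affects how a scan equation renders; the default rendering, with all params shown, can be inspected with `jax.make_jaxpr`:

import jax
import jax.numpy as jnp

f = lambda xs: jax.lax.scan(lambda c, x: (c + x, c), 0.0, xs)
print(jax.make_jaxpr(f)(jnp.arange(3.0)))  # scan eqn with its full params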
def scan_bind(*args, **params):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
@@ -1015,8 +992,6 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
### while_loop
@@ -1401,32 +1376,38 @@ def _while_transpose_error(*_, **kwargs):
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while op we can't (1) return multiple values
# from a `cond` or (2) break out of a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
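The rewrite sketched in this comment can be emulated in plain Python; an illustration only, assuming `cond` and `body` already thread a token as in the lowering below:

def run_rewritten(cond, body, token, x):
  token, pred = cond(token, x)
  while pred:                      # new_cond just returns the carried pred
    token, x = body(token, x)
    token, pred = cond(token, x)
  return token, x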
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
def _pred_bcast_select_mhlo(
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
bcast_pred = mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
@@ -1434,6 +1415,32 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
core.ordered_effects]
if cond_ordered_effects:
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while op we can't (1) return multiple values
# from a `cond` or (2) break out of a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
def cond(args):
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
def body(args):
@@ -1533,47 +1540,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
return z
def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
if joined_effects - allowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {joined_effects - allowed_effects}')
return body_jaxpr.out_avals, joined_effects
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
def _pred_bcast_select_mhlo(
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
bcast_pred = mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results
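End to end, the primitive registered above is reached through the public `lax.while_loop`; standard usage:

import jax

# doubles x until the cond fails: 1 -> 2 -> 4 -> 8 -> 16
print(jax.lax.while_loop(lambda x: x < 10, lambda x: x * 2, 1))  # 16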
### fori_loop

View File

@@ -2808,10 +2808,6 @@ def _broadcast_in_dim_partial_eval(
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.broadcast_in_dim_mlir(
ctx, x, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
if dyn_shape:
shape = _merge_dyn_shape(shape, dyn_shape)
return mhlo.DynamicBroadcastInDimOp(
@@ -3292,8 +3288,6 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.transpose_mlir(ctx, x, permutation=permutation)
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
@@ -4538,8 +4532,6 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
"""Check that dtypes agree, possibly ignoring float precision."""
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
# allows mixed floating point precision, but the HLO verifier often rejects it
if any(type(t) in core.custom_eltypes for t in ttypes):
return # TODO(mattjj,frostig): do some checking, friend
types = map(np.dtype, ttypes) # canonicalize
if ignore_fp_precision:
types = [
@@ -4699,14 +4691,3 @@ def _check_user_dtype_supported(dtype, fun_name=None):
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)
def empty(eltype):
return empty_p.bind(eltype=eltype)
empty_p = core.Primitive('empty')
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
def _empty_lower(ctx, *, eltype):
if type(eltype) in core.custom_eltypes:
return eltype.empty_mlir(ctx)
return mlir.ir_constants(np.empty((), np.dtype(eltype)))
mlir.register_lowering(empty_p, _empty_lower)
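The lowerings touched in this file sit behind ordinary `lax` calls; for instance, the broadcast_in_dim and transpose rules above are what these public-API calls ultimately reach:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)
y = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
z = lax.transpose(y, permutation=(1, 0))
print(z.shape)  # (3, 2)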

View File

@@ -897,9 +897,6 @@ ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.dynamic_slice_mlir(ctx, x, start_indices, slice_sizes)
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results
@@ -996,9 +993,6 @@ batching.primitive_batchers[dynamic_update_slice_p] = \
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.dynamic_update_slice_mlir(
ctx, x, update, *start_indices)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
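Both lowerings above are reached through the public slicing API:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)
print(lax.dynamic_slice(x, (2,), (3,)))                  # [2. 3. 4.]
print(lax.dynamic_update_slice(x, jnp.zeros(2), (1,)))   # [0. 0. 0. 3. 4. 5.]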

View File

@@ -1132,7 +1132,7 @@ def lattice_join(x: Optional[AbstractValue],
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
def valid_jaxtype(x) -> bool:
def valid_jaxtype(x):
try:
concrete_aval(x)
except TypeError:
@@ -1187,25 +1187,16 @@ def concrete_or_error(force: Any, val: Any, context=""):
return force(val)
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
custom_eltypes: Set[Any] = set()
def _short_dtype_name(dtype) -> str:
if type(dtype) in custom_eltypes:
return str(dtype)
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))
def _dtype_object(dtype):
return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
def _short_dtype_name(dtype):
return (dtype.name.replace('float', 'f').replace('uint', 'u')
.replace('int', 'i').replace('complex', 'c'))
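These short names are what appear in aval pretty-printing; a quick check against the public `jax.core`:

import numpy as np
from jax import core

# 'float32' -> 'f32', 'uint8' -> 'u8', 'complex64' -> 'c64', and so on:
print(core.ShapedArray((3, 4), np.dtype('float32')).str_short())  # f32[3,4]
print(core.ShapedArray((2,), np.dtype('uint8')).str_short())      # u8[2]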
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 4
def __init__(self, dtype, weak_type=False):
self.dtype = _dtype_object(dtype)
self.dtype = np.dtype(dtype)
self.weak_type = weak_type
def update(self, dtype=None, weak_type=None):
@@ -1273,7 +1264,7 @@ class ShapedArray(UnshapedArray):
def __init__(self, shape, dtype, weak_type=False, named_shape=None):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.dtype = np.dtype(dtype)
self.weak_type = weak_type
self.named_shape = {} if named_shape is None else dict(named_shape)
@@ -1894,9 +1885,6 @@ class NamedShape:
return total
def __str__(self):
# TODO(mattjj,frostig): revise not to miss commas
if not self.__named:
return str(self.__positional)
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")

View File

@@ -1791,15 +1791,6 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
def _empty(*, eltype):
if type(eltype) in core.custom_eltypes:
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
return tf.constant(np.array(0, dtype=eltype))
tf_impl[lax_internal.empty_p] = _empty
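For context, these `tf_impl` handlers back the public `jax2tf.convert`; typical usage (requires TensorFlow installed):

import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(jnp.sin)
print(f_tf(np.float32(1.0)))  # ~0.8415, as a tf.Tensor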
def _reshape(operand, *, new_sizes, dimensions):
if dimensions is None:
dimensions = tf.range(tf.rank(operand))

View File

@@ -917,7 +917,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
with self.assertRaisesRegex(TypeError,
"Argument 'b' .*DimPoly.*not a valid JAX type"):
"Argument 'b' of type <class 'jax.experimental.jax2tf.shape_poly._DimPolynomial'> is not a valid JAX type"):
jax2tf.convert(lambda x: jnp.prod(x.shape),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))

View File

@@ -129,8 +129,6 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
) -> Sequence[ir.Type]:
if type(aval.dtype) in core.custom_eltypes:
return aval.dtype.aval_to_ir_types(aval)
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
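A quick check of the default aval-to-IR-type mapping; a sketch only, assuming an ambient MLIR context suffices here, as it does during lowering:

import numpy as np
from jax import core
from jax.interpreters import mlir

with mlir.make_ir_context():
  types = mlir.aval_to_ir_types(core.ShapedArray((2, 3), np.dtype('float32')))
  print(types)  # e.g. (tensor<2x3xf32>,)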

View File

@@ -11,16 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from functools import partial
import itertools
import operator
import types
import unittest
from unittest import SkipTest
from typing import Tuple
from absl.testing import absltest
from absl.testing import parameterized
@@ -35,17 +33,13 @@ from jax.test_util import check_grads
from jax import tree_util
import jax.util
from jax.interpreters import xla
from jax.interpreters import mlir
from jax.interpreters import batching
from jax._src.lib.mlir.dialects import mhlo
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src.util import prod
from jax._src.lax import lax as lax_internal
from jax.config import config
config.parse_flags_with_absl()
@@ -2976,288 +2970,5 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
self.assertEqual(out, expected)
class FooTy:
name = 'foo'
def __hash__(self) -> int:
return hash(FooTy)
def __eq__(self, other) -> bool:
return type(other) is FooTy
def __repr__(self) -> str:
return self.name
__str__ = __repr__
# handlers
@staticmethod
def aval_to_ir_types(aval):
aval2 = core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))
return mlir.aval_to_ir_types(aval2)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf)
return handler
# eltype-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
return mlir.ir_constants(np.empty((2,), dtype=np.dtype('uint32')))
@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, 2))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
@staticmethod
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
@staticmethod
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
aval_out, = ctx.avals_out
broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).results
@staticmethod
def transpose_mlir(ctx, x, *, permutation):
perm = [*permutation, len(permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
# primitives
make_p = core.Primitive('make')
bake_p = core.Primitive('bake')
take_p = core.Primitive('take')
def make(shape): return make_p.bind(shape=tuple(shape))
def bake(k): return bake_p.bind(k)
def take(k): return take_p.bind(k)
@make_p.def_abstract_eval
def make_abstract_eval(*, shape):
return core.ShapedArray(shape, FooTy())
@bake_p.def_abstract_eval
def bake_abstract_eval(x):
if type(x.dtype) != FooTy: raise TypeError
return core.ShapedArray(tuple(reversed(x.shape)), FooTy())
@take_p.def_abstract_eval
def take_abstract_eval(x):
return core.ShapedArray(x.shape, jnp.dtype('float32'))
# runtime ('outside jit') data types
class FooArray:
shape: Tuple[int, ...]
data: jnp.ndarray
def __init__(self, shape, data):
assert data.shape == (*shape, 2)
self.shape = shape
self.data = data
def __repr__(self) -> str:
shape = ','.join(map(str, self.shape))
return f'foo[{shape}] with value\n{self.data}'
size = property(lambda self: self.data.size // 2)
ndim = property(lambda self: self.data.ndim - 1)
def device_put_foo_array(x: FooArray, device):
return dispatch._device_put_array(x.data, device)
def foo_array_constant_handler(x, c):
return mlir._device_array_constant_handler(x.data, c)
def make_lowering(*, shape):
return jnp.zeros((*shape, 2), 'uint32')
def bake_lowering(k):
return k.T
def take_lowering(k):
return jnp.broadcast_to(jnp.float32(k.size), k.shape)
def bake_vmap(batched_args, batch_dims):
xs, = batched_args
bdim_in, = batch_dims
ys = bake(xs)
perm = list(reversed(range(xs.ndim)))
bdim_out = perm[bdim_in]
return ys, bdim_out
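Taken together, and assuming the registrations performed in `setUp` below are active, these pieces compose like so (a sketch of what the tests exercise):

import jax

def demo():
  k1 = make((3, 4))   # foo[3,4], via make_abstract_eval
  k2 = bake(k1)       # foo[4,3], shape reversed
  return take(k2)     # f32[4,3]

print(jax.make_jaxpr(demo)())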
class CustomElementTypesTest(jtu.JaxTestCase):
def setUp(self):
core.custom_eltypes.add(FooTy)
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
xla.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
dispatch.device_put_handlers[FooArray] = device_put_foo_array
mlir._constant_handlers[FooArray] = foo_array_constant_handler
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False))
batching.defvectorized(take_p)
batching.primitive_batchers[bake_p] = bake_vmap
def tearDown(self):
core.custom_eltypes.remove(FooTy)
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del xla.pytype_aval_mappings[FooArray]
del dispatch.device_put_handlers[FooArray]
del mlir._constant_handlers[FooArray]
del mlir._lowerings[make_p]
del mlir._lowerings[bake_p]
del mlir._lowerings[take_p]
del batching.primitive_batchers[take_p]
del batching.primitive_batchers[bake_p]
def test_shaped_array_construction(self):
aval = core.ShapedArray((), FooTy())
self.assertEqual(aval.str_short(), 'foo[]')
aval = core.ShapedArray((3, 4), FooTy())
self.assertEqual(aval.str_short(), 'foo[3,4]')
def test_make_jaxpr_identity(self):
x = types.SimpleNamespace(shape=(3,), dtype=FooTy())
jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr
# { lambda ; a:foo[3]. let in (a,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
self.assertLen(jaxpr.outvars, 1)
a, = jaxpr.outvars
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
# tests after here need the primitives
def test_make_jaxpr_with_primitives(self):
def f():
k1 = make((3, 4))
k2 = bake(k1)
x = take(k2)
return x
jaxpr = jax.make_jaxpr(f)().jaxpr
# { lambda ; . let
# a:foo[3,4] = make[shape=(3, 4)]
# b:foo[4,3] = bake a
# c:f32[4,3] = take b
# in (c,) }
self.assertLen(jaxpr.invars, 0)
self.assertLen(jaxpr.eqns, 3)
e1, e2, e3 = jaxpr.eqns
self.assertIs(e1.primitive, make_p)
self.assertLen(e1.outvars, 1)
a, = e1.outvars
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
self.assertIs(e2.primitive, bake_p)
self.assertLen(e2.outvars, 1)
b, = e2.outvars
self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy()))
self.assertIs(e3.primitive, take_p)
self.assertLen(e3.outvars, 1)
c, = e3.outvars
self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32')))
# tests after here need FooArray and lowerings
def test_jit_closure(self):
k = FooArray((), jnp.arange(2, dtype='uint32'))
@jax.jit
def f():
jnp.add(1, 1) # make jit not hit trivial dispatch path
return k
y = f() # doesn't crash
self.assertIsInstance(y, FooArray)
self.assertEqual(y.shape, ())
def test_jit_identity(self):
k = FooArray((), jnp.arange(2, dtype='uint32'))
@jax.jit
def f(k):
jnp.add(1, 1) # make jit not hit trivial dispatch path
return k
y = f(k) # doesn't crash
self.assertIsInstance(y, FooArray)
self.assertEqual(y.shape, ())
def test_jit_multiple_primitives(self):
@jax.jit
def f():
k1 = make((3,))
k2 = bake(k1)
y = take(k2)
return y
y = f()
self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False)
def test_scan_jaxpr(self):
ks = jax.jit(lambda: make((3, 4)))()
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
jaxpr = jax.make_jaxpr(f)(ks).jaxpr
# { lambda ; a:foo[3,4]. let
# b:foo[3,4] = scan[
# jaxpr={ lambda ; c:foo[4]. let d:foo[4] = bake c in (d,) }
# ] a
# in (b,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
self.assertLen(jaxpr.eqns, 1)
e, = jaxpr.eqns
self.assertLen(e.outvars, 1)
b, = e.outvars
self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy()))
def test_scan_lowering(self):
ks = jax.jit(lambda: make((3, 4)))()
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
_, out = jax.jit(f)(ks) # doesn't crash
self.assertIsInstance(out, FooArray)
self.assertEqual(out.shape, (3, 4))
def test_vmap(self):
ks = jax.jit(lambda: make((3, 4, 5)))()
ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks)
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
self.assertAllClose(ys, expected)
def test_transpose(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: x.T)(ks)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (4, 3))
# TODO(frostig,mattjj): more polymorphic primitives tests
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())