prototype unfettered element types in jaxpr arrays

From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2022-08-05 15:16:51 -07:00
parent 4b6d4a4ef7
commit 348da51dc6
12 changed files with 453 additions and 74 deletions

View File

@ -2938,7 +2938,7 @@ def _valid_jaxtype(arg):
try:
xla.abstractify(arg) # faster than core.get_aval
except TypeError:
return False
return core.valid_jaxtype(arg)
else:
return True

View File

@ -435,7 +435,11 @@ def _shaped_abstractify_slow(x):
weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype)
else:
dtype = dtypes.result_type(x) # TODO(frostig,mattjj): why this case?
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior

View File

@ -651,6 +651,8 @@ def array_result_handler(sticky_device: Optional[Device],
if aval.dtype == dtypes.float0:
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
if type(aval.dtype) in core.custom_eltypes:
return aval.dtype.result_handler(sticky_device, aval)
handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
handler.args = aval, sticky_device # for C++ dispatch path in api.py
return handler

View File

@ -25,6 +25,7 @@ from typing import Any, Dict, List
import numpy as np
import jax
from jax._src.config import flags, config
from jax._src.lib import xla_client
@ -85,6 +86,8 @@ def _to_complex_dtype(dtype):
@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled, dtype):
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
if type(dtype) in jax.core.custom_eltypes:
return dtype
try:
dtype = np.dtype(dtype)
except TypeError as e:

View File

@ -29,11 +29,13 @@ from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
@ -232,9 +234,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
x_shapes = [x.shape[1:] for x in xs_flat]
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
@ -242,7 +243,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(
f, in_tree, carry_avals + x_avals, "scan")
f, in_tree, (*carry_avals, *x_avals), "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
@ -414,7 +415,7 @@ def _index_array(i, aval, x):
return slicing.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
return lax.full((sz,) + aval.shape, 0, aval.dtype)
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
def _update_array(i, aval, xs, x):
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
@ -968,6 +969,28 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry
f'called with sequence of type\n{_avals_short(x_avals)}')
return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
printed_params = dict(eqn.params)
del printed_params['linear']
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
del printed_params['length']
if printed_params['unroll'] == 1:
del printed_params['unroll']
if printed_params['num_carry'] == 0:
del printed_params['num_carry']
if printed_params['num_consts'] == 0:
del printed_params['num_consts']
if not printed_params['reverse']:
del printed_params['reverse']
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
def scan_bind(*args, **params):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
@ -992,6 +1015,8 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
### while_loop
@ -1376,38 +1401,32 @@ def _while_transpose_error(*_, **kwargs):
"lax.while_loop or lax.fori_loop. "
"Try using lax.scan instead.")
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
def _pred_bcast_select_mhlo(
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
bcast_pred = mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
pred_aval = cond_jaxpr.out_avals[0]
@ -1415,32 +1434,6 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
core.ordered_effects]
if cond_ordered_effects:
# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
# x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
# token, pred = cond(token, x)
# if not pred:
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
# return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
# token, x = body(token, x)
# token, pred = cond(token, x)
# ```
def cond(args):
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
def body(args):
@ -1540,7 +1533,47 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
return z
def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
body_nconsts):
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
if joined_effects - allowed_effects:
raise NotImplementedError(
f'Effects not supported in `while`: {joined_effects - allowed_effects}')
return body_jaxpr.out_avals, joined_effects
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
def _pred_bcast_select_mhlo(
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
bcast_pred = mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
return mhlo.SelectOp(bcast_pred, x, y).results
### fori_loop

View File

@ -2808,6 +2808,10 @@ def _broadcast_in_dim_partial_eval(
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.broadcast_in_dim_mlir(
ctx, x, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
if dyn_shape:
shape = _merge_dyn_shape(shape, dyn_shape)
return mhlo.DynamicBroadcastInDimOp(
@ -3288,6 +3292,8 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.transpose_mlir(ctx, x, permutation=permutation)
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
@ -4532,6 +4538,8 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
"""Check that dtypes agree, possibly ignoring float precision."""
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
# allows mixed floating point precision, but the HLO verifier often rejects it
if any(type(t) in core.custom_eltypes for t in ttypes):
return # TODO(mattjj,frostig): do some checking, friend
types = map(np.dtype, ttypes) # canonicalize
if ignore_fp_precision:
types = [
@ -4691,3 +4699,14 @@ def _check_user_dtype_supported(dtype, fun_name=None):
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=3)
def empty(eltype):
return empty_p.bind(eltype=eltype)
empty_p = core.Primitive('empty')
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
def _empty_lower(ctx, *, eltype):
if type(eltype) in core.custom_eltypes:
return eltype.empty_mlir(ctx)
return mlir.ir_constants(np.empty((), np.dtype(eltype)))
mlir.register_lowering(empty_p, _empty_lower)

View File

@ -897,6 +897,9 @@ ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.dynamic_slice_mlir(ctx, x, start_indices, slice_sizes)
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results
@ -993,6 +996,9 @@ batching.primitive_batchers[dynamic_update_slice_p] = \
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.dynamic_update_slice_mlir(
ctx, x, update, *start_indices)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results

View File

@ -1132,7 +1132,7 @@ def lattice_join(x: Optional[AbstractValue],
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
def valid_jaxtype(x):
def valid_jaxtype(x) -> bool:
try:
concrete_aval(x)
except TypeError:
@ -1187,16 +1187,25 @@ def concrete_or_error(force: Any, val: Any, context=""):
return force(val)
def _short_dtype_name(dtype):
return (dtype.name.replace('float', 'f').replace('uint', 'u')
.replace('int', 'i').replace('complex', 'c'))
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
custom_eltypes: Set[Any] = set()
def _short_dtype_name(dtype) -> str:
if type(dtype) in custom_eltypes:
return str(dtype)
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))
def _dtype_object(dtype):
return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 4
def __init__(self, dtype, weak_type=False):
self.dtype = np.dtype(dtype)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
def update(self, dtype=None, weak_type=None):
@ -1264,7 +1273,7 @@ class ShapedArray(UnshapedArray):
def __init__(self, shape, dtype, weak_type=False, named_shape=None):
self.shape = canonicalize_shape(shape)
self.dtype = np.dtype(dtype)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.named_shape = {} if named_shape is None else dict(named_shape)
@ -1885,6 +1894,9 @@ class NamedShape:
return total
def __str__(self):
# TODO(mattjj,frostig): revise not to miss commas
if not self.__named:
return str(self.__positional)
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")

View File

@ -1791,6 +1791,15 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
def _empty(*, eltype):
if type(eltype) in core.custom_eltypes:
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
return tf.constant(np.array(0, dtype=eltype))
tf_impl[lax_internal.empty_p] = _empty
def _reshape(operand, *, new_sizes, dimensions):
if dimensions is None:
dimensions = tf.range(tf.rank(operand))

View File

@ -917,7 +917,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
with self.assertRaisesRegex(TypeError,
"Argument 'b' of type <class 'jax.experimental.jax2tf.shape_poly._DimPolynomial'> is not a valid JAX type"):
"Argument 'b' .*DimPoly.*not a valid JAX type"):
jax2tf.convert(lambda x: jnp.prod(x.shape),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))

View File

@ -129,6 +129,8 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
) -> Sequence[ir.Type]:
if type(aval.dtype) in core.custom_eltypes:
return aval.dtype.aval_to_ir_types(aval)
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:

View File

@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from functools import partial
import itertools
import operator
import types
import unittest
from unittest import SkipTest
from typing import Tuple
from absl.testing import absltest
from absl.testing import parameterized
@ -33,13 +35,17 @@ from jax.test_util import check_grads
from jax import tree_util
import jax.util
from jax.interpreters import xla
from jax.interpreters import mlir
from jax.interpreters import batching
from jax._src.lib.mlir.dialects import mhlo
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src.util import prod
from jax._src.lax import lax as lax_internal
from jax.config import config
config.parse_flags_with_absl()
@ -2970,5 +2976,288 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
self.assertEqual(out, expected)
class FooTy:
name = 'foo'
def __hash__(self) -> int:
return hash(FooTy)
def __eq__(self, other) -> bool:
return type(other) is FooTy
def __repr__(self) -> str:
return self.name
__str__ = __repr__
# handlers
@staticmethod
def aval_to_ir_types(aval):
aval2 = core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))
return mlir.aval_to_ir_types(aval2)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf)
return handler
# eltype-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
return mlir.ir_constants(np.empty((2,), dtype=np.dtype('uint32')))
@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, 2))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
@staticmethod
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
start_indices).results
@staticmethod
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
if dyn_shape: raise NotImplementedError
aval_out, = ctx.avals_out
broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
return mhlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).results
@staticmethod
def transpose_mlir(ctx, x, *, permutation):
perm = [*permutation, len(permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
# primitives
make_p = core.Primitive('make')
bake_p = core.Primitive('bake')
take_p = core.Primitive('take')
def make(shape): return make_p.bind(shape=tuple(shape))
def bake(k): return bake_p.bind(k)
def take(k): return take_p.bind(k)
@make_p.def_abstract_eval
def make_abstract_eval(*, shape):
return core.ShapedArray(shape, FooTy())
@bake_p.def_abstract_eval
def bake_abstract_eval(x):
if type(x.dtype) != FooTy: raise TypeError
return core.ShapedArray(tuple(reversed(x.shape)), FooTy())
@take_p.def_abstract_eval
def take_abstract_eval(x):
return core.ShapedArray(x.shape, jnp.dtype('float32'))
# runtime ('outside jit') data types
class FooArray:
shape: Tuple[int, ...]
data: jnp.ndarray
def __init__(self, shape, data):
assert data.shape == (*shape, 2)
self.shape = shape
self.data = data
def __repr__(self) -> str:
shape = ','.join(map(str, self.shape))
return f'foo[{shape}] with value\n{self.data}'
size = property(lambda self: self.data.size // 2)
ndim = property(lambda self: self.data.ndim - 1)
def device_put_foo_array(x: FooArray, device):
return dispatch._device_put_array(x.data, device)
def foo_array_constant_handler(x, c):
return mlir._device_array_constant_handler(x.data, c)
def make_lowering(*, shape):
return jnp.zeros((*shape, 2), 'uint32')
def bake_lowering(k):
return k.T
def take_lowering(k):
return jnp.broadcast_to(jnp.float32(k.size), k.shape)
def bake_vmap(batched_args, batch_dims):
xs, = batched_args
bdim_in, = batch_dims
ys = bake(xs)
perm = list(reversed(range(xs.ndim)))
bdim_out = perm[bdim_in]
return ys, bdim_out
class CustomElementTypesTest(jtu.JaxTestCase):
def setUp(self):
core.custom_eltypes.add(FooTy)
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
xla.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
dispatch.device_put_handlers[FooArray] = device_put_foo_array
mlir._constant_handlers[FooArray] = foo_array_constant_handler
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False))
batching.defvectorized(take_p)
batching.primitive_batchers[bake_p] = bake_vmap
def tearDown(self):
core.custom_eltypes.remove(FooTy)
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del xla.pytype_aval_mappings[FooArray]
del dispatch.device_put_handlers[FooArray]
del mlir._constant_handlers[FooArray]
del mlir._lowerings[make_p]
del mlir._lowerings[bake_p]
del mlir._lowerings[take_p]
del batching.primitive_batchers[take_p]
del batching.primitive_batchers[bake_p]
def test_shaped_array_construction(self):
aval = core.ShapedArray((), FooTy())
self.assertEqual(aval.str_short(), 'foo[]')
aval = core.ShapedArray((3, 4), FooTy())
self.assertEqual(aval.str_short(), 'foo[3,4]')
def test_make_jaxpr_identity(self):
x = types.SimpleNamespace(shape=(3,), dtype=FooTy())
jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr
# { lambda ; a:foo[3]. let in (a,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
self.assertLen(jaxpr.outvars, 1)
a, = jaxpr.outvars
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
# tests after here need the primitives
def test_make_jaxpr_with_primitives(self):
def f():
k1 = make((3, 4))
k2 = bake(k1)
x = take(k2)
return x
jaxpr = jax.make_jaxpr(f)().jaxpr
# { lambda ; . let
# a:foo[3,4] = make[shape=(3, 4)]
# b:foo[4,3] = bake a
# c:f32[4,3] = take b
# in (c,) }
self.assertLen(jaxpr.invars, 0)
self.assertLen(jaxpr.eqns, 3)
e1, e2, e3 = jaxpr.eqns
self.assertIs(e1.primitive, make_p)
self.assertLen(e1.outvars, 1)
a, = e1.outvars
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
self.assertIs(e2.primitive, bake_p)
self.assertLen(e2.outvars, 1)
b, = e2.outvars
self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy()))
self.assertIs(e3.primitive, take_p)
self.assertLen(e3.outvars, 1)
c, = e3.outvars
self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32')))
# tests after here need FooArray and lowerings
def test_jit_closure(self):
k = FooArray((), jnp.arange(2, dtype='uint32'))
@jax.jit
def f():
jnp.add(1, 1) # make jit not hit trivial dispatch path
return k
y = f() # doesn't crash
self.assertIsInstance(y, FooArray)
self.assertEqual(y.shape, ())
def test_jit_identity(self):
k = FooArray((), jnp.arange(2, dtype='uint32'))
@jax.jit
def f(k):
jnp.add(1, 1) # make jit not hit trivial dispatch path
return k
y = f(k) # doesn't crash
self.assertIsInstance(y, FooArray)
self.assertEqual(y.shape, ())
def test_jit_multiple_primitives(self):
@jax.jit
def f():
k1 = make((3,))
k2 = bake(k1)
y = take(k2)
return y
y = f()
self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False)
def test_scan_jaxpr(self):
ks = jax.jit(lambda: make((3, 4)))()
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
jaxpr = jax.make_jaxpr(f)(ks).jaxpr
# { lambda ; a:foo[3,4]. let
# b:foo[3,4] = scan[
# jaxpr={ lambda ; c:foo[4]. let d:foo[4] = bake c in (d,) }
# ] a
# in (b,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
self.assertLen(jaxpr.eqns, 1)
e, = jaxpr.eqns
self.assertLen(e.outvars, 1)
b, = e.outvars
self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy()))
def test_scan_lowering(self):
ks = jax.jit(lambda: make((3, 4)))()
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
_, out = jax.jit(f)(ks) # doesn't crash
self.assertIsInstance(out, FooArray)
self.assertEqual(out.shape, (3, 4))
def test_vmap(self):
ks = jax.jit(lambda: make((3, 4, 5)))()
ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks)
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
self.assertAllClose(ys, expected)
def test_transpose(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: x.T)(ks)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (4, 3))
# TODO(frostig,mattjj): more polymorphic primitives tests
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())