mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy and XLA element types. But why would jaxprs be constrained to those? After all, jaxprs are just symbols, my friends. Those symbols need to be grounded when we translate to another compiler's IR, or when we have input or output values with a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to lowered IR types, and also ways to map any operations allowed on these types to lowered IR operations. And we may want Python objects representing values of these types. But once we have those mappings we don't need to be limited by NumPy/XLA element types. Within jaxprs, we also need to handle transformations with these types. In this change we started unfettering jaxpr element types from their vestigial NumPy/XLA constraints. Concretely, that means: * allowing ShapedArray to have any object for its 'dtype' attribute * added core.custom_eltype set * extended existing handlers for ShapedArray to call the corresponding custom element type handlers * mlir lowerings of some fully-element-type-polymorphic primitives * tests In this PR, we only actually use these new extension points in tests. The applications to come that we have in mind are: * arrays of prngkeys (and even custom prngs, as well as reuse error checking) * arrays of bounded int type for dynamic shapes (and especially raggedness) * float0 arrays We do *not* have in mind opening these mechanisms up to users. Think of these as yet another JAX-internal extension point, like all our existing 'handler' tables. Jargon-wise, we may want to distinguish: * 'eltype' meaning jaxpr element types * 'dtype' meaning numpy dtypes (an existing convention) * 'etype' meaning hlo/mhlo element types (an existing convention) But the code doesn't model this jargon at the moment, since we left a lot of attributes and helper functions referring to 'dtype'. We haven't yet handled all the element-type-polymorphic primitives. Here's the list we've thought of so far: * [x] broadcast * [ ] reshape * [x] transpose * [ ] pad * [x] slice, dynamic_slice, dynamic_update_slice * [ ] concatenate * [ ] all_to_all, gather, scatter, all_gather, collective_permute * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect) That last one is interesting: we introduced it so that the scan lowering rule, which lowers first to a "lowered jaxpr dialect" involving only those eltypes which correspond to etypes and involving only while_loop, ds/dus, etc, can be made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic primitive, have a more complicated lowering rule. We also haven't handled AD. Our main applications (at least the first two listed above) don't involve AD types, so it seemed good to skip for now. Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
4b6d4a4ef7
commit
348da51dc6
@ -2938,7 +2938,7 @@ def _valid_jaxtype(arg):
|
||||
try:
|
||||
xla.abstractify(arg) # faster than core.get_aval
|
||||
except TypeError:
|
||||
return False
|
||||
return core.valid_jaxtype(arg)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
@ -435,7 +435,11 @@ def _shaped_abstractify_slow(x):
|
||||
|
||||
weak_type = getattr(x, 'weak_type', False)
|
||||
named_shape = getattr(x, 'named_shape', {})
|
||||
return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
|
||||
if hasattr(x, 'dtype'):
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype)
|
||||
else:
|
||||
dtype = dtypes.result_type(x) # TODO(frostig,mattjj): why this case?
|
||||
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
|
||||
named_shape=named_shape)
|
||||
|
||||
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
|
||||
|
@ -651,6 +651,8 @@ def array_result_handler(sticky_device: Optional[Device],
|
||||
if aval.dtype == dtypes.float0:
|
||||
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
|
||||
aval = core.raise_to_shaped(aval)
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
return aval.dtype.result_handler(sticky_device, aval)
|
||||
handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
|
||||
handler.args = aval, sticky_device # for C++ dispatch path in api.py
|
||||
return handler
|
||||
|
@ -25,6 +25,7 @@ from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src.config import flags, config
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
@ -85,6 +86,8 @@ def _to_complex_dtype(dtype):
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _canonicalize_dtype(x64_enabled, dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
if type(dtype) in jax.core.custom_eltypes:
|
||||
return dtype
|
||||
try:
|
||||
dtype = np.dtype(dtype)
|
||||
except TypeError as e:
|
||||
|
@ -29,11 +29,13 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
tree_map)
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
@ -232,9 +234,8 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
stacked_y = tree_map(stack, *maybe_reversed(ys))
|
||||
return carry, stacked_y
|
||||
|
||||
x_shapes = [x.shape[1:] for x in xs_flat]
|
||||
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
|
||||
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
|
||||
xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
|
||||
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
||||
|
||||
def _create_jaxpr(init):
|
||||
init_flat, init_tree = tree_flatten(init)
|
||||
@ -242,7 +243,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
|
||||
carry_avals = tuple(_map(_abstractify, init_flat))
|
||||
jaxpr, consts, out_tree = _initial_style_jaxpr(
|
||||
f, in_tree, carry_avals + x_avals, "scan")
|
||||
f, in_tree, (*carry_avals, *x_avals), "scan")
|
||||
out_tree_children = out_tree.children()
|
||||
if len(out_tree_children) != 2:
|
||||
msg = "scan body output must be a pair, got {}."
|
||||
@ -414,7 +415,7 @@ def _index_array(i, aval, x):
|
||||
return slicing.index_in_dim(x, i, keepdims=False)
|
||||
|
||||
def _empty_array(sz, aval):
|
||||
return lax.full((sz,) + aval.shape, 0, aval.dtype)
|
||||
return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
|
||||
|
||||
def _update_array(i, aval, xs, x):
|
||||
return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
|
||||
@ -968,6 +969,28 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry
|
||||
f'called with sequence of type\n{_avals_short(x_avals)}')
|
||||
return [*init_avals, *y_avals], jaxpr.effects
|
||||
|
||||
def _scan_pp_rule(eqn, context, settings):
|
||||
printed_params = dict(eqn.params)
|
||||
del printed_params['linear']
|
||||
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
|
||||
del printed_params['length']
|
||||
if printed_params['unroll'] == 1:
|
||||
del printed_params['unroll']
|
||||
if printed_params['num_carry'] == 0:
|
||||
del printed_params['num_carry']
|
||||
if printed_params['num_consts'] == 0:
|
||||
del printed_params['num_consts']
|
||||
if not printed_params['reverse']:
|
||||
del printed_params['reverse']
|
||||
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
|
||||
rhs = [pp.text(eqn.primitive.name),
|
||||
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
|
||||
pp.text(" ") + core.pp_vars(eqn.invars, context)]
|
||||
annotation = (source_info_util.summarize(eqn.source_info)
|
||||
if settings.source_info else None)
|
||||
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
|
||||
|
||||
|
||||
def scan_bind(*args, **params):
|
||||
if config.jax_enable_checks:
|
||||
avals = _map(core.get_aval, args)
|
||||
@ -992,6 +1015,8 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
|
||||
pe.padding_rules[scan_p] = _scan_padding_rule
|
||||
pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
# TODO(mattjj,frostig): un-comment this pp rule
|
||||
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
|
||||
|
||||
### while_loop
|
||||
|
||||
@ -1376,38 +1401,32 @@ def _while_transpose_error(*_, **kwargs):
|
||||
"lax.while_loop or lax.fori_loop. "
|
||||
"Try using lax.scan instead.")
|
||||
|
||||
while_p = core.AxisPrimitive('while')
|
||||
while_p.multiple_results = True
|
||||
while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.register_initial_style_primitive(while_p)
|
||||
ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
|
||||
|
||||
|
||||
def _pred_bcast_select_mhlo(
|
||||
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
|
||||
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
|
||||
if x_y_aval is core.abstract_token:
|
||||
x, = xs
|
||||
y, = ys
|
||||
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
|
||||
else:
|
||||
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
|
||||
x, = xs
|
||||
y, = ys
|
||||
assert x.type == y.type, (x.type, y.type)
|
||||
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
|
||||
pred_aval.shape, x_y_aval)
|
||||
bcast_pred = mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
|
||||
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
|
||||
return mhlo.SelectOp(bcast_pred, x, y).results
|
||||
|
||||
|
||||
# For a while loop with ordered effects in the cond, we need a special
|
||||
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
|
||||
# this:
|
||||
# ```
|
||||
# while cond(x):
|
||||
# x = body(x)
|
||||
# ```
|
||||
# into something that looks like this:
|
||||
# ```
|
||||
# while True:
|
||||
# token, pred = cond(token, x)
|
||||
# if not pred:
|
||||
# break
|
||||
# token, x = body(token, x)
|
||||
# ```
|
||||
# Unfortunately, with an MHLO while we can't (1) return multiple values
|
||||
# from a `cond` and (2) can't break a while loop. We thus adopt the
|
||||
# following rewrite strategy:
|
||||
# ```
|
||||
# def new_cond(pred, token, x):
|
||||
# return pred
|
||||
# token, pred = cond(token, x)
|
||||
# while new_cond(pred, token, x):
|
||||
# token, x = body(token, x)
|
||||
# token, pred = cond(token, x)
|
||||
# ```
|
||||
def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_nconsts):
|
||||
pred_aval = cond_jaxpr.out_avals[0]
|
||||
@ -1415,32 +1434,6 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
|
||||
core.ordered_effects]
|
||||
if cond_ordered_effects:
|
||||
# For a while loop with ordered effects in the cond, we need a special
|
||||
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
|
||||
# this:
|
||||
# ```
|
||||
# while cond(x):
|
||||
# x = body(x)
|
||||
# ```
|
||||
# into something that looks like this:
|
||||
# ```
|
||||
# while True:
|
||||
# token, pred = cond(token, x)
|
||||
# if not pred:
|
||||
# break
|
||||
# token, x = body(token, x)
|
||||
# ```
|
||||
# Unfortunately, with an MHLO while we can't (1) return multiple values
|
||||
# from a `cond` and (2) can't break a while loop. We thus adopt the
|
||||
# following rewrite strategy:
|
||||
# ```
|
||||
# def new_cond(pred, token, x):
|
||||
# return pred
|
||||
# token, pred = cond(token, x)
|
||||
# while new_cond(pred, token, x):
|
||||
# token, x = body(token, x)
|
||||
# token, pred = cond(token, x)
|
||||
# ```
|
||||
def cond(args):
|
||||
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
|
||||
def body(args):
|
||||
@ -1540,7 +1533,47 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
|
||||
return z
|
||||
|
||||
def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
body_nconsts):
|
||||
# TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
|
||||
joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
|
||||
if joined_effects - allowed_effects:
|
||||
raise NotImplementedError(
|
||||
f'Effects not supported in `while`: {joined_effects - allowed_effects}')
|
||||
return body_jaxpr.out_avals, joined_effects
|
||||
|
||||
while_p = core.AxisPrimitive('while')
|
||||
while_p.multiple_results = True
|
||||
while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.register_initial_style_primitive(while_p)
|
||||
ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
|
||||
mlir.register_lowering(while_p, _while_lowering)
|
||||
core.custom_typechecks[while_p] = _while_typecheck
|
||||
|
||||
|
||||
def _pred_bcast_select_mhlo(
|
||||
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
|
||||
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
|
||||
if x_y_aval is core.abstract_token:
|
||||
x, = xs
|
||||
y, = ys
|
||||
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
|
||||
else:
|
||||
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
|
||||
x, = xs
|
||||
y, = ys
|
||||
assert x.type == y.type, (x.type, y.type)
|
||||
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
|
||||
pred_aval.shape, x_y_aval)
|
||||
bcast_pred = mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
|
||||
pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
|
||||
return mhlo.SelectOp(bcast_pred, x, y).results
|
||||
|
||||
### fori_loop
|
||||
|
||||
|
@ -2808,6 +2808,10 @@ def _broadcast_in_dim_partial_eval(
|
||||
|
||||
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
return aval_out.dtype.broadcast_in_dim_mlir(
|
||||
ctx, x, *dyn_shape, shape=shape,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
if dyn_shape:
|
||||
shape = _merge_dyn_shape(shape, dyn_shape)
|
||||
return mhlo.DynamicBroadcastInDimOp(
|
||||
@ -3288,6 +3292,8 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
|
||||
|
||||
def _transpose_lower(ctx, x, *, permutation):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
return aval_out.dtype.transpose_mlir(ctx, x, permutation=permutation)
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
@ -4532,6 +4538,8 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
|
||||
"""Check that dtypes agree, possibly ignoring float precision."""
|
||||
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
|
||||
# allows mixed floating point precision, but the HLO verifier often rejects it
|
||||
if any(type(t) in core.custom_eltypes for t in ttypes):
|
||||
return # TODO(mattjj,frostig): do some checking, friend
|
||||
types = map(np.dtype, ttypes) # canonicalize
|
||||
if ignore_fp_precision:
|
||||
types = [
|
||||
@ -4691,3 +4699,14 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
fun_name = f"requested in {fun_name}" if fun_name else ""
|
||||
truncated_dtype = dtypes.canonicalize_dtype(dtype).name
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=3)
|
||||
|
||||
|
||||
def empty(eltype):
|
||||
return empty_p.bind(eltype=eltype)
|
||||
empty_p = core.Primitive('empty')
|
||||
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
|
||||
def _empty_lower(ctx, *, eltype):
|
||||
if type(eltype) in core.custom_eltypes:
|
||||
return eltype.empty_mlir(ctx)
|
||||
return mlir.ir_constants(np.empty((), np.dtype(eltype)))
|
||||
mlir.register_lowering(empty_p, _empty_lower)
|
||||
|
@ -897,6 +897,9 @@ ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
|
||||
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
|
||||
|
||||
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
return aval_out.dtype.dynamic_slice_mlir(ctx, x, start_indices, slice_sizes)
|
||||
return mhlo.DynamicSliceOp(x, start_indices,
|
||||
mlir.dense_int_elements(slice_sizes)).results
|
||||
|
||||
@ -993,6 +996,9 @@ batching.primitive_batchers[dynamic_update_slice_p] = \
|
||||
|
||||
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
return aval_out.dtype.dynamic_update_slice_mlir(
|
||||
ctx, x, update, *start_indices)
|
||||
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
|
||||
start_indices).results
|
||||
|
||||
|
24
jax/core.py
24
jax/core.py
@ -1132,7 +1132,7 @@ def lattice_join(x: Optional[AbstractValue],
|
||||
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
|
||||
Value = Any
|
||||
|
||||
def valid_jaxtype(x):
|
||||
def valid_jaxtype(x) -> bool:
|
||||
try:
|
||||
concrete_aval(x)
|
||||
except TypeError:
|
||||
@ -1187,16 +1187,25 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
return force(val)
|
||||
|
||||
|
||||
def _short_dtype_name(dtype):
|
||||
return (dtype.name.replace('float', 'f').replace('uint', 'u')
|
||||
.replace('int', 'i').replace('complex', 'c'))
|
||||
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
|
||||
custom_eltypes: Set[Any] = set()
|
||||
|
||||
def _short_dtype_name(dtype) -> str:
|
||||
if type(dtype) in custom_eltypes:
|
||||
return str(dtype)
|
||||
else:
|
||||
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
||||
.replace('int' , 'i').replace('complex', 'c'))
|
||||
|
||||
def _dtype_object(dtype):
|
||||
return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
|
||||
|
||||
class UnshapedArray(AbstractValue):
|
||||
__slots__ = ['dtype', 'weak_type']
|
||||
array_abstraction_level = 4
|
||||
|
||||
def __init__(self, dtype, weak_type=False):
|
||||
self.dtype = np.dtype(dtype)
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
|
||||
def update(self, dtype=None, weak_type=None):
|
||||
@ -1264,7 +1273,7 @@ class ShapedArray(UnshapedArray):
|
||||
|
||||
def __init__(self, shape, dtype, weak_type=False, named_shape=None):
|
||||
self.shape = canonicalize_shape(shape)
|
||||
self.dtype = np.dtype(dtype)
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
self.named_shape = {} if named_shape is None else dict(named_shape)
|
||||
|
||||
@ -1885,6 +1894,9 @@ class NamedShape:
|
||||
return total
|
||||
|
||||
def __str__(self):
|
||||
# TODO(mattjj,frostig): revise not to miss commas
|
||||
if not self.__named:
|
||||
return str(self.__positional)
|
||||
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
|
||||
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")
|
||||
|
||||
|
@ -1791,6 +1791,15 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
|
||||
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
|
||||
|
||||
|
||||
def _empty(*, eltype):
|
||||
if type(eltype) in core.custom_eltypes:
|
||||
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
|
||||
return tf.constant(np.array(0, dtype=eltype))
|
||||
|
||||
|
||||
tf_impl[lax_internal.empty_p] = _empty
|
||||
|
||||
|
||||
def _reshape(operand, *, new_sizes, dimensions):
|
||||
if dimensions is None:
|
||||
dimensions = tf.range(tf.rank(operand))
|
||||
|
@ -917,7 +917,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"Argument 'b' of type <class 'jax.experimental.jax2tf.shape_poly._DimPolynomial'> is not a valid JAX type"):
|
||||
"Argument 'b' .*DimPoly.*not a valid JAX type"):
|
||||
jax2tf.convert(lambda x: jnp.prod(x.shape),
|
||||
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
||||
|
||||
|
@ -129,6 +129,8 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
|
||||
|
||||
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
|
||||
) -> Sequence[ir.Type]:
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
return aval.dtype.aval_to_ir_types(aval)
|
||||
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
|
||||
|
||||
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
|
||||
|
@ -11,14 +11,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from functools import partial
|
||||
import itertools
|
||||
import operator
|
||||
import types
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
from typing import Tuple
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -33,13 +35,17 @@ from jax.test_util import check_grads
|
||||
from jax import tree_util
|
||||
import jax.util
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import batching
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import lax_reference
|
||||
from jax._src.util import prod
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -2970,5 +2976,288 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
|
||||
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
|
||||
class FooTy:
|
||||
name = 'foo'
|
||||
def __hash__(self) -> int:
|
||||
return hash(FooTy)
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(other) is FooTy
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
__str__ = __repr__
|
||||
|
||||
# handlers
|
||||
|
||||
@staticmethod
|
||||
def aval_to_ir_types(aval):
|
||||
aval2 = core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))
|
||||
return mlir.aval_to_ir_types(aval2)
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
def handler(_, buf):
|
||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
||||
return FooArray(aval.shape, buf)
|
||||
return handler
|
||||
|
||||
# eltype-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def empty_mlir(ctx):
|
||||
return mlir.ir_constants(np.empty((2,), dtype=np.dtype('uint32')))
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
|
||||
slice_sizes_ = mlir.dense_int_elements((*slice_sizes, 2))
|
||||
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results
|
||||
|
||||
@staticmethod
|
||||
def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.avals_out
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
|
||||
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
|
||||
start_indices).results
|
||||
|
||||
@staticmethod
|
||||
def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
|
||||
if dyn_shape: raise NotImplementedError
|
||||
aval_out, = ctx.avals_out
|
||||
broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
|
||||
return mhlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(aval_out), x,
|
||||
mlir.dense_int_elements(broadcast_dimensions)).results
|
||||
|
||||
@staticmethod
|
||||
def transpose_mlir(ctx, x, *, permutation):
|
||||
perm = [*permutation, len(permutation)]
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results
|
||||
|
||||
# primitives
|
||||
|
||||
make_p = core.Primitive('make')
|
||||
bake_p = core.Primitive('bake')
|
||||
take_p = core.Primitive('take')
|
||||
|
||||
def make(shape): return make_p.bind(shape=tuple(shape))
|
||||
def bake(k): return bake_p.bind(k)
|
||||
def take(k): return take_p.bind(k)
|
||||
|
||||
@make_p.def_abstract_eval
|
||||
def make_abstract_eval(*, shape):
|
||||
return core.ShapedArray(shape, FooTy())
|
||||
|
||||
@bake_p.def_abstract_eval
|
||||
def bake_abstract_eval(x):
|
||||
if type(x.dtype) != FooTy: raise TypeError
|
||||
return core.ShapedArray(tuple(reversed(x.shape)), FooTy())
|
||||
|
||||
@take_p.def_abstract_eval
|
||||
def take_abstract_eval(x):
|
||||
return core.ShapedArray(x.shape, jnp.dtype('float32'))
|
||||
|
||||
# runtime ('outside jit') data types
|
||||
|
||||
class FooArray:
|
||||
shape: Tuple[int, ...]
|
||||
data: jnp.ndarray
|
||||
|
||||
def __init__(self, shape, data):
|
||||
assert data.shape == (*shape, 2)
|
||||
self.shape = shape
|
||||
self.data = data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
shape = ','.join(map(str, self.shape))
|
||||
return f'foo[{shape}] with value\n{self.data}'
|
||||
|
||||
size = property(lambda self: self.data.size // 2)
|
||||
ndim = property(lambda self: self.data.ndim - 1)
|
||||
|
||||
def device_put_foo_array(x: FooArray, device):
|
||||
return dispatch._device_put_array(x.data, device)
|
||||
|
||||
def foo_array_constant_handler(x, c):
|
||||
return mlir._device_array_constant_handler(x.data, c)
|
||||
|
||||
def make_lowering(*, shape):
|
||||
return jnp.zeros((*shape, 2), 'uint32')
|
||||
|
||||
def bake_lowering(k):
|
||||
return k.T
|
||||
|
||||
def take_lowering(k):
|
||||
return jnp.broadcast_to(jnp.float32(k.size), k.shape)
|
||||
|
||||
|
||||
def bake_vmap(batched_args, batch_dims):
|
||||
xs, = batched_args
|
||||
bdim_in, = batch_dims
|
||||
ys = bake(xs)
|
||||
perm = list(reversed(range(xs.ndim)))
|
||||
bdim_out = perm[bdim_in]
|
||||
return ys, bdim_out
|
||||
|
||||
|
||||
class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
core.custom_eltypes.add(FooTy)
|
||||
core.pytype_aval_mappings[FooArray] = \
|
||||
lambda x: core.ShapedArray(x.shape, FooTy())
|
||||
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
|
||||
xla.pytype_aval_mappings[FooArray] = \
|
||||
lambda x: core.ShapedArray(x.shape, FooTy())
|
||||
dispatch.device_put_handlers[FooArray] = device_put_foo_array
|
||||
mlir._constant_handlers[FooArray] = foo_array_constant_handler
|
||||
mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
|
||||
mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
|
||||
mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False))
|
||||
batching.defvectorized(take_p)
|
||||
batching.primitive_batchers[bake_p] = bake_vmap
|
||||
|
||||
def tearDown(self):
|
||||
core.custom_eltypes.remove(FooTy)
|
||||
del core.pytype_aval_mappings[FooArray]
|
||||
del xla.canonicalize_dtype_handlers[FooArray]
|
||||
del xla.pytype_aval_mappings[FooArray]
|
||||
del dispatch.device_put_handlers[FooArray]
|
||||
del mlir._constant_handlers[FooArray]
|
||||
del mlir._lowerings[make_p]
|
||||
del mlir._lowerings[bake_p]
|
||||
del mlir._lowerings[take_p]
|
||||
del batching.primitive_batchers[take_p]
|
||||
del batching.primitive_batchers[bake_p]
|
||||
|
||||
def test_shaped_array_construction(self):
|
||||
aval = core.ShapedArray((), FooTy())
|
||||
self.assertEqual(aval.str_short(), 'foo[]')
|
||||
aval = core.ShapedArray((3, 4), FooTy())
|
||||
self.assertEqual(aval.str_short(), 'foo[3,4]')
|
||||
|
||||
def test_make_jaxpr_identity(self):
|
||||
x = types.SimpleNamespace(shape=(3,), dtype=FooTy())
|
||||
jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr
|
||||
# { lambda ; a:foo[3]. let in (a,) }
|
||||
self.assertLen(jaxpr.invars, 1)
|
||||
a, = jaxpr.invars
|
||||
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
|
||||
self.assertLen(jaxpr.outvars, 1)
|
||||
a, = jaxpr.outvars
|
||||
self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
|
||||
|
||||
# tests after here need the primitives
|
||||
|
||||
def test_make_jaxpr_with_primitives(self):
|
||||
def f():
|
||||
k1 = make((3, 4))
|
||||
k2 = bake(k1)
|
||||
x = take(k2)
|
||||
return x
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)().jaxpr
|
||||
# { lambda ; . let
|
||||
# a:foo[3,4] = make[shape=(3, 4)]
|
||||
# b:foo[4,3] = bake a
|
||||
# c:f32[4,3] = take b
|
||||
# in (c,) }
|
||||
self.assertLen(jaxpr.invars, 0)
|
||||
self.assertLen(jaxpr.eqns, 3)
|
||||
e1, e2, e3 = jaxpr.eqns
|
||||
|
||||
self.assertIs(e1.primitive, make_p)
|
||||
self.assertLen(e1.outvars, 1)
|
||||
a, = e1.outvars
|
||||
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
|
||||
|
||||
self.assertIs(e2.primitive, bake_p)
|
||||
self.assertLen(e2.outvars, 1)
|
||||
b, = e2.outvars
|
||||
self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy()))
|
||||
|
||||
self.assertIs(e3.primitive, take_p)
|
||||
self.assertLen(e3.outvars, 1)
|
||||
c, = e3.outvars
|
||||
self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32')))
|
||||
|
||||
# tests after here need FooArray and lowerings
|
||||
|
||||
def test_jit_closure(self):
|
||||
k = FooArray((), jnp.arange(2, dtype='uint32'))
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
jnp.add(1, 1) # make jit not hit trivial dispatch path
|
||||
return k
|
||||
|
||||
y = f() # doesn't crash
|
||||
self.assertIsInstance(y, FooArray)
|
||||
self.assertEqual(y.shape, ())
|
||||
|
||||
def test_jit_identity(self):
|
||||
k = FooArray((), jnp.arange(2, dtype='uint32'))
|
||||
|
||||
@jax.jit
|
||||
def f(k):
|
||||
jnp.add(1, 1) # make jit not hit trivial dispatch path
|
||||
return k
|
||||
|
||||
y = f(k) # doesn't crash
|
||||
self.assertIsInstance(y, FooArray)
|
||||
self.assertEqual(y.shape, ())
|
||||
|
||||
def test_jit_multiple_primitives(self):
|
||||
@jax.jit
|
||||
def f():
|
||||
k1 = make((3,))
|
||||
k2 = bake(k1)
|
||||
y = take(k2)
|
||||
return y
|
||||
|
||||
y = f()
|
||||
self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False)
|
||||
|
||||
def test_scan_jaxpr(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
|
||||
jaxpr = jax.make_jaxpr(f)(ks).jaxpr
|
||||
# { lambda ; a:foo[3,4]. let
|
||||
# b:foo[3,4] = scan[
|
||||
# jaxpr={ lambda ; c:foo[4]. let d:foo[4] = bake c in (d,) }
|
||||
# ] a
|
||||
# in (b,) }
|
||||
self.assertLen(jaxpr.invars, 1)
|
||||
a, = jaxpr.invars
|
||||
self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
e, = jaxpr.eqns
|
||||
self.assertLen(e.outvars, 1)
|
||||
b, = e.outvars
|
||||
self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy()))
|
||||
|
||||
def test_scan_lowering(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
|
||||
_, out = jax.jit(f)(ks) # doesn't crash
|
||||
self.assertIsInstance(out, FooArray)
|
||||
self.assertEqual(out.shape, (3, 4))
|
||||
|
||||
def test_vmap(self):
|
||||
ks = jax.jit(lambda: make((3, 4, 5)))()
|
||||
ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks)
|
||||
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
|
||||
self.assertAllClose(ys, expected)
|
||||
|
||||
def test_transpose(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
ys = jax.jit(lambda x: x.T)(ks)
|
||||
self.assertIsInstance(ys, FooArray)
|
||||
self.assertEqual(ys.shape, (4, 3))
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user