mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

Internal change

PiperOrigin-RevId: 465705931

This commit is contained in:
  parent 3f9c300521
  commit 6b0c0dc321
@@ -2938,7 +2938,7 @@ def _valid_jaxtype(arg):
  try:
    xla.abstractify(arg)  # faster than core.get_aval
  except TypeError:
    return core.valid_jaxtype(arg)
    return False
  else:
    return True
@@ -435,11 +435,7 @@ def _shaped_abstractify_slow(x):
  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  if hasattr(x, 'dtype'):
    dtype = dtypes.canonicalize_dtype(x.dtype)
  else:
    dtype = dtypes.result_type(x)  # TODO(frostig,mattjj): why this case?
  return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)

# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
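
For context, a rough sketch (not the exact JAX internals) of what the slow path above computes for plain NumPy inputs and Python scalars; `shaped_abstractify_sketch` is a made-up name for illustration:

    import numpy as np
    from jax import core, dtypes

    def shaped_abstractify_sketch(x):
      # Derive the shape via np.shape and the dtype via the same two branches
      # as _shaped_abstractify_slow; weak_type defaults to False.
      if hasattr(x, 'dtype'):
        dtype = dtypes.canonicalize_dtype(x.dtype)
      else:
        dtype = dtypes.result_type(x)  # e.g. the Python scalar 3.0 -> float32
      return core.ShapedArray(np.shape(x), dtype,
                              weak_type=getattr(x, 'weak_type', False))
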
@@ -651,8 +651,6 @@ def array_result_handler(sticky_device: Optional[Device],
  if aval.dtype == dtypes.float0:
    return lambda _, __: np.zeros(aval.shape, dtypes.float0)
  aval = core.raise_to_shaped(aval)
  if type(aval.dtype) in core.custom_eltypes:
    return aval.dtype.result_handler(sticky_device, aval)
  handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
  handler.args = aval, sticky_device  # for C++ dispatch path in api.py
  return handler
@@ -25,7 +25,6 @@ from typing import Any, Dict, List

import numpy as np

import jax
from jax._src.config import flags, config
from jax._src.lib import xla_client
@@ -86,8 +85,6 @@ def _to_complex_dtype(dtype):
@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled, dtype):
  """Convert from a dtype to a canonical dtype based on config.x64_enabled."""
  if type(dtype) in jax.core.custom_eltypes:
    return dtype
  try:
    dtype = np.dtype(dtype)
  except TypeError as e:
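
As a quick, runnable illustration of the canonicalization this function caches: with the default configuration (jax_enable_x64 off), 64-bit dtypes map to their 32-bit counterparts.

    import numpy as np
    from jax import dtypes

    print(dtypes.canonicalize_dtype(np.float64))  # float32 unless jax_enable_x64 is set
    print(dtypes.canonicalize_dtype(np.int32))    # int32 (already canonical)
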
@@ -29,13 +29,11 @@ from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           tree_map)
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
@@ -234,8 +232,9 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
    stacked_y = tree_map(stack, *maybe_reversed(ys))
    return carry, stacked_y

  xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat]
  x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
  x_shapes = [x.shape[1:] for x in xs_flat]
  x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
  x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

  def _create_jaxpr(init):
    init_flat, init_tree = tree_flatten(init)
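
For context, a small runnable example of the aval computation this hunk changes: scan maps away the leading axis of each xs leaf, so a (3, 4) input yields per-step slices of shape (4,).

    import jax
    import jax.numpy as jnp

    def step(carry, x):
      # x is one row of xs, shape (4,); carry keeps its shape across steps
      new_carry = carry + x.sum()
      return new_carry, new_carry  # second output is stacked along a new axis 0

    xs = jnp.arange(12.).reshape(3, 4)
    final, ys = jax.lax.scan(step, 0., xs)
    print(final, ys.shape)  # 66.0 (3,)
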
@@ -243,7 +242,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],

    carry_avals = tuple(_map(_abstractify, init_flat))
    jaxpr, consts, out_tree = _initial_style_jaxpr(
        f, in_tree, (*carry_avals, *x_avals), "scan")
        f, in_tree, carry_avals + x_avals, "scan")
    out_tree_children = out_tree.children()
    if len(out_tree_children) != 2:
      msg = "scan body output must be a pair, got {}."
@@ -415,7 +414,7 @@ def _index_array(i, aval, x):
  return slicing.index_in_dim(x, i, keepdims=False)

def _empty_array(sz, aval):
  return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape))
  return lax.full((sz,) + aval.shape, 0, aval.dtype)

def _update_array(i, aval, xs, x):
  return slicing.dynamic_update_index_in_dim(xs, x, i, 0)
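
Both versions of _empty_array preallocate the buffer that holds the stacked per-step outputs; a user-level sketch of that fill pattern, using the public lax API:

    import jax.numpy as jnp
    from jax import lax

    buf = lax.full((3, 4), 0, jnp.float32)                 # (sz, *aval.shape) zeros
    row = jnp.ones((4,), jnp.float32)                      # one step's output
    buf = lax.dynamic_update_index_in_dim(buf, row, 1, 0)  # write it at index 1
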
@@ -969,28 +968,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry
        f'called with sequence of type\n{_avals_short(x_avals)}')
  return [*init_avals, *y_avals], jaxpr.effects

def _scan_pp_rule(eqn, context, settings):
  printed_params = dict(eqn.params)
  del printed_params['linear']
  if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
    del printed_params['length']
  if printed_params['unroll'] == 1:
    del printed_params['unroll']
  if printed_params['num_carry'] == 0:
    del printed_params['num_carry']
  if printed_params['num_consts'] == 0:
    del printed_params['num_consts']
  if not printed_params['reverse']:
    del printed_params['reverse']
  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  rhs = [pp.text(eqn.primitive.name),
         core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
         pp.text(" ") + core.pp_vars(eqn.invars, context)]
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  return [lhs, pp.text(" = ", annotation=annotation), *rhs]


def scan_bind(*args, **params):
  if config.jax_enable_checks:
    avals = _map(core.get_aval, args)
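
For reference, what the default (non-custom) printer produces for a scan equation; runnable:

    import jax
    import jax.numpy as jnp

    f = lambda xs: jax.lax.scan(lambda c, x: (c + x, c), 0., xs)
    print(jax.make_jaxpr(f)(jnp.arange(3.)))
    # shows a single scan[...] equation with its body jaxpr and parameters
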
@@ -1015,8 +992,6 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule

### while_loop
@@ -1401,32 +1376,38 @@ def _while_transpose_error(*_, **kwargs):
      "lax.while_loop or lax.fori_loop. "
      "Try using lax.scan instead.")

# For a while loop with ordered effects in the cond, we need a special
# lowering. Fundamentally, we'd like to rewrite a while loop that looks like
# this:
# ```
# while cond(x):
#   x = body(x)
# ```
# into something that looks like this:
# ```
# while True:
#   token, pred = cond(token, x)
#   if not pred:
#     break
#   token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
# def new_cond(pred, token, x):
#   return pred
# token, pred = cond(token, x)
# while new_cond(pred, token, x):
#   token, x = body(token, x)
#   token, pred = cond(token, x)
# ```
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom


def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  if x_y_aval is core.abstract_token:
    x, = xs
    y, = ys
    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
  else:
    assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
    x, = xs
    y, = ys
    assert x.type == y.type, (x.type, y.type)
    assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
        pred_aval.shape, x_y_aval)
    bcast_pred = mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
        pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
    return mhlo.SelectOp(bcast_pred, x, y).results


def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
                    body_nconsts):
  pred_aval = cond_jaxpr.out_avals[0]
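
A plain-Python sketch of the rewrite strategy described in the comment above: evaluate cond once before the loop and carry its predicate, so the loop condition only returns a boolean. Here `cond` and `body` are placeholders for the token-threaded jaxpr evaluations, not JAX APIs.

    def rewritten_while(cond, body, token, x):
      token, pred = cond(token, x)
      while pred:                     # new_cond just returns the carried pred
        token, x = body(token, x)
        token, pred = cond(token, x)  # recompute pred for the next iteration
      return token, x
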
@@ -1434,6 +1415,32 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
  cond_ordered_effects = [eff for eff in cond_jaxpr.effects if eff in
                          core.ordered_effects]
  if cond_ordered_effects:
    # For a while loop with ordered effects in the cond, we need a special
    # lowering. Fundamentally, we'd like to rewrite a while loop that looks like
    # this:
    # ```
    # while cond(x):
    #   x = body(x)
    # ```
    # into something that looks like this:
    # ```
    # while True:
    #   token, pred = cond(token, x)
    #   if not pred:
    #     break
    #   token, x = body(token, x)
    # ```
    # Unfortunately, with an MHLO while we can't (1) return multiple values
    # from a `cond` and (2) can't break a while loop. We thus adopt the
    # following rewrite strategy:
    # ```
    # def new_cond(pred, token, x):
    #   return pred
    # token, pred = cond(token, x)
    # while new_cond(pred, token, x):
    #   token, x = body(token, x)
    #   token, pred = cond(token, x)
    # ```
    def cond(args):
      return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
    def body(args):
@@ -1533,47 +1540,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
  ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
  return z

def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
                     body_nconsts):
  # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types
  joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects)
  if joined_effects - allowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `while`: {joined_effects - allowed_effects}')
  return body_jaxpr.out_avals, joined_effects

while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
xla.register_initial_style_primitive(while_p)
ad.primitive_transposes[while_p] = _while_transpose_error
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom
mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck


def _pred_bcast_select_mhlo(
    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
  if x_y_aval is core.abstract_token:
    x, = xs
    y, = ys
    return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
  else:
    assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
    x, = xs
    y, = ys
    assert x.type == y.type, (x.type, y.type)
    assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
        pred_aval.shape, x_y_aval)
    bcast_pred = mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(x_y_aval.update(dtype=np.dtype(np.bool_))),
        pred, mlir.dense_int_elements(list(range(len(pred_aval.shape))))).result
    return mhlo.SelectOp(bcast_pred, x, y).results

### fori_loop
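
A user-level analogue (a sketch, not the MHLO lowering above) of predicate broadcast-and-select: the predicate covers only the leading batch dimensions, so it is broadcast across the carry's trailing dimensions before selecting.

    import jax.numpy as jnp

    pred = jnp.array([True, False, True])  # leading batch dims only
    x = jnp.ones((3, 2))
    y = jnp.zeros((3, 2))
    out = jnp.where(pred[:, None], x, y)   # broadcast pred, then select
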
@@ -2808,10 +2808,6 @@ def _broadcast_in_dim_partial_eval(

def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
  aval_out, = ctx.avals_out
  if type(aval_out.dtype) in core.custom_eltypes:
    return aval_out.dtype.broadcast_in_dim_mlir(
        ctx, x, *dyn_shape, shape=shape,
        broadcast_dimensions=broadcast_dimensions)
  if dyn_shape:
    shape = _merge_dyn_shape(shape, dyn_shape)
  return mhlo.DynamicBroadcastInDimOp(
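
The public entry point for this lowering, for reference; broadcast_dimensions maps each operand axis to an output axis:

    import jax.numpy as jnp
    from jax import lax

    v = jnp.arange(3.)
    m = lax.broadcast_in_dim(v, shape=(2, 3), broadcast_dimensions=(1,))
    print(m.shape)  # (2, 3): v's only axis became output axis 1
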
@@ -3292,8 +3288,6 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):

def _transpose_lower(ctx, x, *, permutation):
  aval_out, = ctx.avals_out
  if type(aval_out.dtype) in core.custom_eltypes:
    return aval_out.dtype.transpose_mlir(ctx, x, permutation=permutation)
  return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results

transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
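
The corresponding user-level op, for reference: permutation lists, for each output axis, which input axis it comes from.

    import jax.numpy as jnp
    from jax import lax

    x = jnp.zeros((2, 3, 4))
    print(lax.transpose(x, (2, 0, 1)).shape)  # (4, 2, 3)
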
@@ -4538,8 +4532,6 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
  """Check that dtypes agree, possibly ignoring float precision."""
  # the `ignore_fp_precision` flag exists because the XLA shape inference logic
  # allows mixed floating point precision, but the HLO verifier often rejects it
  if any(type(t) in core.custom_eltypes for t in ttypes):
    return  # TODO(mattjj,frostig): do some checking, friend
  types = map(np.dtype, ttypes)  # canonicalize
  if ignore_fp_precision:
    types = [
@@ -4699,14 +4691,3 @@ def _check_user_dtype_supported(dtype, fun_name=None):
    fun_name = f"requested in {fun_name}" if fun_name else ""
    truncated_dtype = dtypes.canonicalize_dtype(dtype).name
    warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=3)


def empty(eltype):
  return empty_p.bind(eltype=eltype)
empty_p = core.Primitive('empty')
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
def _empty_lower(ctx, *, eltype):
  if type(eltype) in core.custom_eltypes:
    return eltype.empty_mlir(ctx)
  return mlir.ir_constants(np.empty((), np.dtype(eltype)))
mlir.register_lowering(empty_p, _empty_lower)
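
The empty primitive above follows JAX's standard recipe for defining a new primitive: an abstract-eval rule plus a registered lowering. A self-contained sketch of the same recipe with a made-up `zero_p` primitive; the registration APIs used here are JAX internals and may differ across versions:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import core
    from jax.interpreters import mlir

    zero_p = core.Primitive('zero')  # hypothetical primitive for illustration
    zero_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
    zero_p.def_impl(lambda *, dtype: np.zeros((), dtype))
    mlir.register_lowering(
        zero_p, mlir.lower_fun(lambda *, dtype: jnp.zeros((), dtype), False))

    def zero(dtype):
      return zero_p.bind(dtype=np.dtype(dtype))

    print(jax.jit(zero, static_argnums=0)(np.float32))  # 0.0
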
@@ -897,9 +897,6 @@ ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule

def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
  aval_out, = ctx.avals_out
  if type(aval_out.dtype) in core.custom_eltypes:
    return aval_out.dtype.dynamic_slice_mlir(ctx, x, start_indices, slice_sizes)
  return mhlo.DynamicSliceOp(x, start_indices,
                             mlir.dense_int_elements(slice_sizes)).results
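
The user-level op this lowers, for reference:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(12).reshape(3, 4)
    print(lax.dynamic_slice(x, (1, 1), (2, 2)))  # rows 1:3, cols 1:3
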
@@ -996,9 +993,6 @@ batching.primitive_batchers[dynamic_update_slice_p] = \

def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
  aval_out, = ctx.avals_out
  if type(aval_out.dtype) in core.custom_eltypes:
    return aval_out.dtype.dynamic_update_slice_mlir(
        ctx, x, update, *start_indices)
  return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
                                   start_indices).results
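
And its user-level counterpart, for reference:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.zeros((3, 4))
    u = jnp.ones((2, 2))
    print(lax.dynamic_update_slice(x, u, (1, 1)))  # writes u at offset (1, 1)
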
jax/core.py (24 changed lines)
@@ -1132,7 +1132,7 @@ def lattice_join(x: Optional[AbstractValue],
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any

def valid_jaxtype(x) -> bool:
def valid_jaxtype(x):
  try:
    concrete_aval(x)
  except TypeError:
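
A runnable illustration of the check above (valid_jaxtype asks whether a value can be abstracted to an aval; jax.core was the public alias for this module in JAX versions of this era):

    import jax.numpy as jnp
    from jax import core

    print(core.valid_jaxtype(jnp.ones(3)))  # True
    print(core.valid_jaxtype(3.0))          # True (Python scalars are jaxtypes)
    print(core.valid_jaxtype('hello'))      # False
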
@@ -1187,25 +1187,16 @@ def concrete_or_error(force: Any, val: Any, context=""):
  return force(val)


# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
custom_eltypes: Set[Any] = set()

def _short_dtype_name(dtype) -> str:
  if type(dtype) in custom_eltypes:
    return str(dtype)
  else:
    return (dtype.name.replace('float', 'f').replace('uint' , 'u')
                      .replace('int' , 'i').replace('complex', 'c'))

def _dtype_object(dtype):
  return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
def _short_dtype_name(dtype):
  return (dtype.name.replace('float', 'f').replace('uint', 'u')
                    .replace('int', 'i').replace('complex', 'c'))

class UnshapedArray(AbstractValue):
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 4

  def __init__(self, dtype, weak_type=False):
    self.dtype = _dtype_object(dtype)
    self.dtype = np.dtype(dtype)
    self.weak_type = weak_type

  def update(self, dtype=None, weak_type=None):
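
The name-shortening in _short_dtype_name, extracted as a standalone runnable snippet:

    import numpy as np

    def short(dtype):
      return (dtype.name.replace('float', 'f').replace('uint', 'u')
              .replace('int', 'i').replace('complex', 'c'))

    print(short(np.dtype('float32')), short(np.dtype('uint8')),
          short(np.dtype('complex64')))  # f32 u8 c64
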
@@ -1273,7 +1264,7 @@ class ShapedArray(UnshapedArray):

  def __init__(self, shape, dtype, weak_type=False, named_shape=None):
    self.shape = canonicalize_shape(shape)
    self.dtype = _dtype_object(dtype)
    self.dtype = np.dtype(dtype)
    self.weak_type = weak_type
    self.named_shape = {} if named_shape is None else dict(named_shape)
@@ -1894,9 +1885,6 @@ class NamedShape:
    return total

  def __str__(self):
    # TODO(mattjj,frostig): revise not to miss commas
    if not self.__named:
      return str(self.__positional)
    return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
            f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")
@@ -1791,15 +1791,6 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim


def _empty(*, eltype):
  if type(eltype) in core.custom_eltypes:
    raise NotImplementedError  # TODO(frostig,mattjj): jax2tf handlers
  return tf.constant(np.array(0, dtype=eltype))


tf_impl[lax_internal.empty_p] = _empty


def _reshape(operand, *, new_sizes, dimensions):
  if dimensions is None:
    dimensions = tf.range(tf.rank(operand))
@@ -917,7 +917,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
                   polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))

    with self.assertRaisesRegex(TypeError,
                                "Argument 'b' .*DimPoly.*not a valid JAX type"):
                                "Argument 'b' of type <class 'jax.experimental.jax2tf.shape_poly._DimPolynomial'> is not a valid JAX type"):
      jax2tf.convert(lambda x: jnp.prod(x.shape),
                     polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
@@ -129,8 +129,6 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:

def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
                    ) -> Sequence[ir.Type]:
  if type(aval.dtype) in core.custom_eltypes:
    return aval.dtype.aval_to_ir_types(aval)
  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
@@ -11,16 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
from functools import partial
import itertools
import operator
import types
import unittest
from unittest import SkipTest
from typing import Tuple

from absl.testing import absltest
from absl.testing import parameterized
@@ -35,17 +33,13 @@ from jax.test_util import check_grads
from jax import tree_util
import jax.util

from jax.interpreters import xla
from jax.interpreters import mlir
from jax.interpreters import batching
from jax._src.lib.mlir.dialects import mhlo
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src.util import prod
from jax._src.lax import lax as lax_internal


from jax.config import config
config.parse_flags_with_absl()
@@ -2976,288 +2970,5 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
    (out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
    self.assertEqual(out, expected)


class FooTy:
  name = 'foo'
  def __hash__(self) -> int:
    return hash(FooTy)
  def __eq__(self, other) -> bool:
    return type(other) is FooTy
  def __repr__(self) -> str:
    return self.name
  __str__ = __repr__

  # handlers

  @staticmethod
  def aval_to_ir_types(aval):
    aval2 = core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))
    return mlir.aval_to_ir_types(aval2)

  @staticmethod
  def result_handler(sticky_device, aval):
    def handler(_, buf):
      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
      return FooArray(aval.shape, buf)
    return handler

  # eltype-polymorphic primitive lowering rules

  @staticmethod
  def empty_mlir(ctx):
    return mlir.ir_constants(np.empty((2,), dtype=np.dtype('uint32')))

  @staticmethod
  def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
    dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
    start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
    slice_sizes_ = mlir.dense_int_elements((*slice_sizes, 2))
    return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).results

  @staticmethod
  def dynamic_update_slice_mlir(ctx, x, update, *start_indices):
    aval_out, = ctx.avals_out
    dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
    start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
    return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
                                     start_indices).results

  @staticmethod
  def broadcast_in_dim_mlir(ctx, x, *dyn_shape, shape, broadcast_dimensions):
    if dyn_shape: raise NotImplementedError
    aval_out, = ctx.avals_out
    broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
    return mhlo.BroadcastInDimOp(
        mlir.aval_to_ir_type(aval_out), x,
        mlir.dense_int_elements(broadcast_dimensions)).results

  @staticmethod
  def transpose_mlir(ctx, x, *, permutation):
    perm = [*permutation, len(permutation)]
    return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).results

# primitives

make_p = core.Primitive('make')
bake_p = core.Primitive('bake')
take_p = core.Primitive('take')

def make(shape): return make_p.bind(shape=tuple(shape))
def bake(k): return bake_p.bind(k)
def take(k): return take_p.bind(k)

@make_p.def_abstract_eval
def make_abstract_eval(*, shape):
  return core.ShapedArray(shape, FooTy())

@bake_p.def_abstract_eval
def bake_abstract_eval(x):
  if type(x.dtype) != FooTy: raise TypeError
  return core.ShapedArray(tuple(reversed(x.shape)), FooTy())

@take_p.def_abstract_eval
def take_abstract_eval(x):
  return core.ShapedArray(x.shape, jnp.dtype('float32'))

# runtime ('outside jit') data types

class FooArray:
  shape: Tuple[int, ...]
  data: jnp.ndarray

  def __init__(self, shape, data):
    assert data.shape == (*shape, 2)
    self.shape = shape
    self.data = data

  def __repr__(self) -> str:
    shape = ','.join(map(str, self.shape))
    return f'foo[{shape}] with value\n{self.data}'

  size = property(lambda self: self.data.size // 2)
  ndim = property(lambda self: self.data.ndim - 1)

def device_put_foo_array(x: FooArray, device):
  return dispatch._device_put_array(x.data, device)

def foo_array_constant_handler(x, c):
  return mlir._device_array_constant_handler(x.data, c)

def make_lowering(*, shape):
  return jnp.zeros((*shape, 2), 'uint32')

def bake_lowering(k):
  return k.T

def take_lowering(k):
  return jnp.broadcast_to(jnp.float32(k.size), k.shape)


def bake_vmap(batched_args, batch_dims):
  xs, = batched_args
  bdim_in, = batch_dims
  ys = bake(xs)
  perm = list(reversed(range(xs.ndim)))
  bdim_out = perm[bdim_in]
  return ys, bdim_out


class CustomElementTypesTest(jtu.JaxTestCase):

  def setUp(self):
    core.custom_eltypes.add(FooTy)
    core.pytype_aval_mappings[FooArray] = \
        lambda x: core.ShapedArray(x.shape, FooTy())
    xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
    xla.pytype_aval_mappings[FooArray] = \
        lambda x: core.ShapedArray(x.shape, FooTy())
    dispatch.device_put_handlers[FooArray] = device_put_foo_array
    mlir._constant_handlers[FooArray] = foo_array_constant_handler
    mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False))
    mlir.register_lowering(bake_p, mlir.lower_fun(bake_lowering, False))
    mlir.register_lowering(take_p, mlir.lower_fun(take_lowering, False))
    batching.defvectorized(take_p)
    batching.primitive_batchers[bake_p] = bake_vmap

  def tearDown(self):
    core.custom_eltypes.remove(FooTy)
    del core.pytype_aval_mappings[FooArray]
    del xla.canonicalize_dtype_handlers[FooArray]
    del xla.pytype_aval_mappings[FooArray]
    del dispatch.device_put_handlers[FooArray]
    del mlir._constant_handlers[FooArray]
    del mlir._lowerings[make_p]
    del mlir._lowerings[bake_p]
    del mlir._lowerings[take_p]
    del batching.primitive_batchers[take_p]
    del batching.primitive_batchers[bake_p]

  def test_shaped_array_construction(self):
    aval = core.ShapedArray((), FooTy())
    self.assertEqual(aval.str_short(), 'foo[]')
    aval = core.ShapedArray((3, 4), FooTy())
    self.assertEqual(aval.str_short(), 'foo[3,4]')

  def test_make_jaxpr_identity(self):
    x = types.SimpleNamespace(shape=(3,), dtype=FooTy())
    jaxpr = jax.make_jaxpr(lambda x: x)(x).jaxpr
    # { lambda ; a:foo[3]. let in (a,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))
    self.assertLen(jaxpr.outvars, 1)
    a, = jaxpr.outvars
    self.assertEqual(a.aval, core.ShapedArray((3,), FooTy()))

  # tests after here need the primitives

  def test_make_jaxpr_with_primitives(self):
    def f():
      k1 = make((3, 4))
      k2 = bake(k1)
      x = take(k2)
      return x

    jaxpr = jax.make_jaxpr(f)().jaxpr
    # { lambda ; . let
    #     a:foo[3,4] = make[shape=(3, 4)]
    #     b:foo[4,3] = bake a
    #     c:f32[4,3] = take b
    #   in (c,) }
    self.assertLen(jaxpr.invars, 0)
    self.assertLen(jaxpr.eqns, 3)
    e1, e2, e3 = jaxpr.eqns

    self.assertIs(e1.primitive, make_p)
    self.assertLen(e1.outvars, 1)
    a, = e1.outvars
    self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))

    self.assertIs(e2.primitive, bake_p)
    self.assertLen(e2.outvars, 1)
    b, = e2.outvars
    self.assertEqual(b.aval, core.ShapedArray((4, 3), FooTy()))

    self.assertIs(e3.primitive, take_p)
    self.assertLen(e3.outvars, 1)
    c, = e3.outvars
    self.assertEqual(c.aval, core.ShapedArray((4, 3), np.dtype('float32')))

  # tests after here need FooArray and lowerings

  def test_jit_closure(self):
    k = FooArray((), jnp.arange(2, dtype='uint32'))

    @jax.jit
    def f():
      jnp.add(1, 1)  # make jit not hit trivial dispatch path
      return k

    y = f()  # doesn't crash
    self.assertIsInstance(y, FooArray)
    self.assertEqual(y.shape, ())

  def test_jit_identity(self):
    k = FooArray((), jnp.arange(2, dtype='uint32'))

    @jax.jit
    def f(k):
      jnp.add(1, 1)  # make jit not hit trivial dispatch path
      return k

    y = f(k)  # doesn't crash
    self.assertIsInstance(y, FooArray)
    self.assertEqual(y.shape, ())

  def test_jit_multiple_primitives(self):
    @jax.jit
    def f():
      k1 = make((3,))
      k2 = bake(k1)
      y = take(k2)
      return y

    y = f()
    self.assertArraysAllClose(y, jnp.array([3., 3., 3.]), check_dtypes=False)

  def test_scan_jaxpr(self):
    ks = jax.jit(lambda: make((3, 4)))()
    f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
    jaxpr = jax.make_jaxpr(f)(ks).jaxpr
    # { lambda ; a:foo[3,4]. let
    #     b:foo[3,4] = scan[
    #       jaxpr={ lambda ; c:foo[4]. let d:foo[4] = bake c in (d,) }
    #     ] a
    #   in (b,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.ShapedArray((3, 4), FooTy()))
    self.assertLen(jaxpr.eqns, 1)
    e, = jaxpr.eqns
    self.assertLen(e.outvars, 1)
    b, = e.outvars
    self.assertEqual(b.aval, core.ShapedArray((3, 4), FooTy()))

  def test_scan_lowering(self):
    ks = jax.jit(lambda: make((3, 4)))()
    f = lambda ks: jax.lax.scan(lambda _, k: (None, bake(k)), None, ks)
    _, out = jax.jit(f)(ks)  # doesn't crash
    self.assertIsInstance(out, FooArray)
    self.assertEqual(out.shape, (3, 4))

  def test_vmap(self):
    ks = jax.jit(lambda: make((3, 4, 5)))()
    ys = jax.vmap(jax.jit(lambda k: take(bake(k))))(ks)
    expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
    self.assertAllClose(ys, expected)

  def test_transpose(self):
    ks = jax.jit(lambda: make((3, 4)))()
    ys = jax.jit(lambda x: x.T)(ks)
    self.assertIsInstance(ys, FooArray)
    self.assertEqual(ys.shape, (4, 3))

  # TODO(frostig,mattjj): more polymorphic primitives tests

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())