mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update callers to use `core.{is,has}_opaque_dtype` predicates instead.
This commit is contained in:
parent
c26c7fddad
commit
8f045b12d6
@ -1107,9 +1107,9 @@ def _check_scalar(x):
|
||||
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
|
||||
_check_arg(x)
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
raise TypeError(
|
||||
f"{name} with input element type {core.aval_eltype(aval).name}")
|
||||
f"{name} with input element type {aval.dtype.name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
|
||||
@ -1128,9 +1128,9 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
|
||||
|
||||
def _check_output_dtype_revderiv(name, holomorphic, x):
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
raise TypeError(
|
||||
f"{name} with output element type {core.aval_eltype(aval).name}")
|
||||
f"{name} with output element type {aval.dtype.name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
|
||||
@ -1208,9 +1208,9 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
|
||||
_check_arg(x)
|
||||
aval = core.get_aval(x)
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
raise TypeError(
|
||||
f"jacfwd with input element type {core.aval_eltype(aval).name}")
|
||||
f"jacfwd with input element type {aval.dtype.name}")
|
||||
if holomorphic:
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
|
||||
@ -2927,7 +2927,7 @@ class ShapeDtypeStruct:
|
||||
__slots__ = ["shape", "dtype", "named_shape"]
|
||||
def __init__(self, shape, dtype, named_shape=None):
|
||||
self.shape = shape
|
||||
self.dtype = dtype if core.is_custom_eltype(dtype) else np.dtype(dtype)
|
||||
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
|
||||
self.named_shape = {} if named_shape is None else dict(named_shape)
|
||||
|
||||
size = property(lambda self: prod(self.shape))
|
||||
|
@ -732,7 +732,7 @@ def array_result_handler(sticky_device: Optional[Device],
|
||||
if aval.dtype == dtypes.float0:
|
||||
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
|
||||
aval = core.raise_to_shaped(aval)
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.result_handler(sticky_device, aval)
|
||||
handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
|
||||
handler.args = aval, sticky_device # for C++ dispatch path in api.py
|
||||
|
@ -91,7 +91,7 @@ def to_complex_dtype(dtype):
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _canonicalize_dtype(x64_enabled, dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
if type(dtype) in jax.core.custom_eltypes:
|
||||
if jax.core.is_opaque_dtype(dtype):
|
||||
return dtype
|
||||
try:
|
||||
dtype = np.dtype(dtype)
|
||||
|
@ -1246,7 +1246,7 @@ def stop_gradient(x: T) -> T:
|
||||
"""
|
||||
def stop(x):
|
||||
# only bind primitive on inexact dtypes, to avoid some staging
|
||||
if core.has_custom_eltype(x):
|
||||
if core.has_opaque_dtype(x):
|
||||
return x
|
||||
elif (dtypes.issubdtype(_dtype(x), np.floating) or
|
||||
dtypes.issubdtype(_dtype(x), np.complexfloating)):
|
||||
@ -2826,7 +2826,7 @@ def _broadcast_in_dim_partial_eval(
|
||||
|
||||
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.broadcast_in_dim_mlir(
|
||||
ctx, x, *dyn_shape, shape=shape,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
@ -3310,7 +3310,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
|
||||
|
||||
def _transpose_lower(ctx, x, *, permutation):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.transpose_mlir(ctx, x, permutation=permutation)
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
|
||||
|
||||
@ -4557,7 +4557,7 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
|
||||
"""Check that dtypes agree, possibly ignoring float precision."""
|
||||
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
|
||||
# allows mixed floating point precision, but the HLO verifier often rejects it
|
||||
if any(type(t) in core.custom_eltypes for t in ttypes):
|
||||
if any(core.is_opaque_dtype(t) for t in ttypes):
|
||||
return # TODO(mattjj,frostig): do some checking, friend
|
||||
types = map(np.dtype, ttypes) # canonicalize
|
||||
if ignore_fp_precision:
|
||||
@ -4720,12 +4720,12 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=3)
|
||||
|
||||
|
||||
def empty(eltype):
|
||||
return empty_p.bind(eltype=eltype)
|
||||
def empty(dtype):
|
||||
return empty_p.bind(dtype=dtype)
|
||||
empty_p = core.Primitive('empty')
|
||||
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
|
||||
def _empty_lower(ctx, *, eltype):
|
||||
if type(eltype) in core.custom_eltypes:
|
||||
return eltype._rules.empty_mlir(ctx)
|
||||
return mlir.ir_constants(np.zeros((), np.dtype(eltype)))
|
||||
empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
|
||||
def _empty_lower(ctx, *, dtype):
|
||||
if core.is_opaque_dtype(dtype):
|
||||
return dtype._rules.empty_mlir(ctx)
|
||||
return mlir.ir_constants(np.zeros((), np.dtype(dtype)))
|
||||
mlir.register_lowering(empty_p, _empty_lower)
|
||||
|
@ -803,7 +803,7 @@ batching.primitive_batchers[slice_p] = _slice_batching_rule
|
||||
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
|
||||
strides = strides or [1] * len(start_indices)
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.slice_mlir(
|
||||
ctx, x, start_indices, limit_indices, strides)
|
||||
return mhlo.SliceOp(x,
|
||||
@ -904,7 +904,7 @@ batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
|
||||
|
||||
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.dynamic_slice_mlir(
|
||||
ctx, x, start_indices, slice_sizes)
|
||||
return mhlo.DynamicSliceOp(x, start_indices,
|
||||
@ -1003,7 +1003,7 @@ batching.primitive_batchers[dynamic_update_slice_p] = \
|
||||
|
||||
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.dynamic_update_slice_mlir(
|
||||
ctx, x, update, *start_indices)
|
||||
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
|
||||
@ -1318,7 +1318,7 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return aval_out.dtype._rules.gather_mlir(
|
||||
ctx, operand, indices, dimension_numbers=dimension_numbers,
|
||||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||||
|
@ -189,7 +189,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
# * unpack upfront into shape[0] many keyarray slices
|
||||
# * return iter over these unpacked slices
|
||||
# Whatever we do, we'll want to do it by overriding
|
||||
# ShapedArray._iter when the eltype is KeyTy...
|
||||
# ShapedArray._iter when the element type is KeyTy...
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
|
||||
|
||||
# TODO(frostig): are all of the stackable methods below (reshape,
|
||||
@ -373,7 +373,7 @@ class KeyTyRules:
|
||||
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
# eltype-polymorphic primitive lowering rules
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def empty_mlir(ctx):
|
||||
@ -476,7 +476,7 @@ class KeyTy:
|
||||
return hash((self.__class__, self.impl))
|
||||
|
||||
|
||||
core.custom_eltypes.add(KeyTy)
|
||||
core.opaque_dtypes.add(KeyTy)
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArray] = (
|
||||
|
34
jax/core.py
34
jax/core.py
@ -1192,33 +1192,24 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
|
||||
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
|
||||
|
||||
custom_eltypes: Set[Any] = set()
|
||||
opaque_dtypes: Set[Any] = set()
|
||||
|
||||
# TODO(frostig): update inliners of the four functions below to call them
|
||||
def has_custom_eltype(x: Any):
|
||||
return aval_has_custom_eltype(get_aval(x))
|
||||
def has_opaque_dtype(x: Any):
|
||||
return is_opaque_dtype(get_aval(x).dtype)
|
||||
|
||||
def eltype(x: Any):
|
||||
return aval_eltype(get_aval(x))
|
||||
|
||||
def aval_has_custom_eltype(aval: UnshapedArray):
|
||||
return is_custom_eltype(aval.dtype)
|
||||
|
||||
def aval_eltype(aval: UnshapedArray):
|
||||
return aval.dtype
|
||||
|
||||
def is_custom_eltype(eltype):
|
||||
return type(eltype) in custom_eltypes
|
||||
def is_opaque_dtype(dtype):
|
||||
return type(dtype) in opaque_dtypes
|
||||
|
||||
def _short_dtype_name(dtype) -> str:
|
||||
if type(dtype) in custom_eltypes:
|
||||
if type(dtype) in opaque_dtypes:
|
||||
return str(dtype)
|
||||
else:
|
||||
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
||||
.replace('int' , 'i').replace('complex', 'c'))
|
||||
|
||||
def _dtype_object(dtype):
|
||||
return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
|
||||
return dtype if type(dtype) in opaque_dtypes else np.dtype(dtype)
|
||||
|
||||
class UnshapedArray(AbstractValue):
|
||||
__slots__ = ['dtype', 'weak_type']
|
||||
@ -1422,13 +1413,12 @@ class ConcreteArray(ShapedArray):
|
||||
_float = concretization_function_error(float, True)
|
||||
_complex = concretization_function_error(complex, True)
|
||||
|
||||
# TODO(frostig,mattjj): rename to primal_eltype_to_tangent_eltype
|
||||
def primal_dtype_to_tangent_dtype(primal_dtype):
|
||||
# TODO(frostig,mattjj): determines that all custom eltypes have
|
||||
# float0 tangent type, which works fine for all our current custom
|
||||
# eltype applications. We may some day want to delegate this
|
||||
# decision to the eltype.
|
||||
if (type(primal_dtype) in custom_eltypes or
|
||||
# TODO(frostig,mattjj): determines that all opaque dtypes have
|
||||
# float0 tangent type, which works fine for all our current opaque
|
||||
# dtype applications. We may some day want to delegate this
|
||||
# decision to the dtype rules.
|
||||
if (is_opaque_dtype(primal_dtype) or
|
||||
not dtypes.issubdtype(primal_dtype, np.inexact)):
|
||||
return dtypes.float0
|
||||
else:
|
||||
|
@ -461,7 +461,7 @@ pxla.shard_arg_handlers[Array] = _array_shard_arg
|
||||
def _array_global_result_handler(global_aval, out_sharding, committed):
|
||||
if global_aval.dtype == dtypes.float0:
|
||||
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
|
||||
if core.aval_has_custom_eltype(global_aval):
|
||||
if core.is_opaque_dtype(global_aval.dtype):
|
||||
return global_aval.dtype._rules.global_sharded_result_handler(
|
||||
global_aval, out_sharding, committed)
|
||||
return lambda bufs: Array(global_aval, out_sharding, bufs,
|
||||
@ -472,7 +472,7 @@ pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambd
|
||||
|
||||
|
||||
def _array_local_result_handler(aval, sharding, indices):
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.local_sharded_result_handler(
|
||||
aval, sharding, indices)
|
||||
else:
|
||||
|
@ -619,7 +619,7 @@ pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
|
||||
|
||||
|
||||
def _gda_array_result_handler(global_aval, out_sharding, committed):
|
||||
if core.aval_has_custom_eltype(global_aval):
|
||||
if core.is_opaque_dtype(global_aval.dtype):
|
||||
return global_aval.dtype._rules.global_sharded_result_handler(
|
||||
global_aval, out_sharding, committed)
|
||||
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec
|
||||
|
@ -778,7 +778,7 @@ def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
|
||||
physical avals, but we don't support those here. Instead we assert
|
||||
there is only one and return it.
|
||||
"""
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
aval, = aval.dtype._rules.physical_avals(aval)
|
||||
return aval
|
||||
return aval
|
||||
@ -881,8 +881,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
|
||||
return tf_val, jax_dtype
|
||||
|
||||
|
||||
# TODO(frostig,mattjj): rename dtype argument to eltype, for now just
|
||||
# being consistent.
|
||||
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
|
||||
assert all(map(lambda x: x is not None, shape)), (
|
||||
f"Argument shape should be a valid JAX shape but got {shape}")
|
||||
@ -1845,10 +1843,10 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
|
||||
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
|
||||
|
||||
|
||||
def _empty(*, eltype):
|
||||
if type(eltype) in core.custom_eltypes:
|
||||
def _empty(*, dtype):
|
||||
if core.is_opaque_dtype(dtype):
|
||||
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
|
||||
return tf.constant(np.array(0, dtype=eltype))
|
||||
return tf.constant(np.array(0, dtype=dtype))
|
||||
|
||||
|
||||
tf_impl[lax_internal.empty_p] = _empty
|
||||
|
@ -138,7 +138,7 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
|
||||
|
||||
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
|
||||
) -> Sequence[ir.Type]:
|
||||
if type(aval.dtype) in core.custom_eltypes:
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.aval_to_ir_types(aval)
|
||||
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
|
||||
|
||||
|
@ -578,7 +578,7 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes
|
||||
|
||||
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
|
||||
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
|
||||
if core.aval_has_custom_eltype(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.local_sharded_result_handler(
|
||||
aval, sharding, indices)
|
||||
else:
|
||||
|
@ -3038,7 +3038,7 @@ class FooTyRules:
|
||||
return FooArray(aval.shape, buf)
|
||||
return handler
|
||||
|
||||
# eltype-polymorphic primitive lowering rules
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
def empty_mlir(ctx):
|
||||
@ -3194,7 +3194,7 @@ def bake_vmap(batched_args, batch_dims):
|
||||
class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
core.custom_eltypes.add(FooTy)
|
||||
core.opaque_dtypes.add(FooTy)
|
||||
core.pytype_aval_mappings[FooArray] = \
|
||||
lambda x: core.ShapedArray(x.shape, FooTy())
|
||||
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
|
||||
@ -3210,7 +3210,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
batching.primitive_batchers[bake_p] = bake_vmap
|
||||
|
||||
def tearDown(self):
|
||||
core.custom_eltypes.remove(FooTy)
|
||||
core.opaque_dtypes.remove(FooTy)
|
||||
del core.pytype_aval_mappings[FooArray]
|
||||
del xla.canonicalize_dtype_handlers[FooArray]
|
||||
del xla.pytype_aval_mappings[FooArray]
|
||||
|
@ -1496,7 +1496,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
class KeyArrayTest(jtu.JaxTestCase):
|
||||
# Key arrays involve:
|
||||
# * a Python key array type, backed by an underlying uint32 "base" array,
|
||||
# * an abstract shaped array with key eltype,
|
||||
# * an abstract shaped array with key element type,
|
||||
# * primitives that return or operate on such shaped arrays,
|
||||
# * compiler lowerings,
|
||||
# * a device-side data representation...
|
||||
@ -1504,8 +1504,8 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
#
|
||||
# A handful of these tests follow CustomElementTypesTest in
|
||||
# lax_tests.py as an example. If you add a test here (e.g. testing
|
||||
# lowering of an key-eltyped shaped array), consider whether it
|
||||
# might also be a more general test of extended/custom eltypes. If
|
||||
# lowering of a key-dtyped shaped array), consider whether it
|
||||
# might also be a more general test of opaque element types. If
|
||||
# so, add a corresponding test to CustomElementTypesTest as well.
|
||||
|
||||
def make_keys(self, *shape, seed=None):
|
||||
|
Loading…
x
Reference in New Issue
Block a user