internal rename: swap mentions of "custom eltypes" for "opaque dtypes"

Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
This commit is contained in:
Roy Frostig 2022-08-30 14:47:15 -07:00
parent c26c7fddad
commit 8f045b12d6
14 changed files with 54 additions and 66 deletions

View File

@ -1107,9 +1107,9 @@ def _check_scalar(x):
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
if core.is_opaque_dtype(aval.dtype):
raise TypeError(
f"{name} with input element type {core.aval_eltype(aval).name}")
f"{name} with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
@ -1128,9 +1128,9 @@ _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
if core.is_opaque_dtype(aval.dtype):
raise TypeError(
f"{name} with output element type {core.aval_eltype(aval).name}")
f"{name} with output element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
@ -1208,9 +1208,9 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
if core.aval_has_custom_eltype(aval):
if core.is_opaque_dtype(aval.dtype):
raise TypeError(
f"jacfwd with input element type {core.aval_eltype(aval).name}")
f"jacfwd with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
@ -2927,7 +2927,7 @@ class ShapeDtypeStruct:
__slots__ = ["shape", "dtype", "named_shape"]
def __init__(self, shape, dtype, named_shape=None):
self.shape = shape
self.dtype = dtype if core.is_custom_eltype(dtype) else np.dtype(dtype)
self.dtype = dtype if core.is_opaque_dtype(dtype) else np.dtype(dtype)
self.named_shape = {} if named_shape is None else dict(named_shape)
size = property(lambda self: prod(self.shape))

View File

@ -732,7 +732,7 @@ def array_result_handler(sticky_device: Optional[Device],
if aval.dtype == dtypes.float0:
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
aval = core.raise_to_shaped(aval)
if type(aval.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.result_handler(sticky_device, aval)
handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
handler.args = aval, sticky_device # for C++ dispatch path in api.py

View File

@ -91,7 +91,7 @@ def to_complex_dtype(dtype):
@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled, dtype):
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
if type(dtype) in jax.core.custom_eltypes:
if jax.core.is_opaque_dtype(dtype):
return dtype
try:
dtype = np.dtype(dtype)

View File

@ -1246,7 +1246,7 @@ def stop_gradient(x: T) -> T:
"""
def stop(x):
# only bind primitive on inexact dtypes, to avoid some staging
if core.has_custom_eltype(x):
if core.has_opaque_dtype(x):
return x
elif (dtypes.issubdtype(_dtype(x), np.floating) or
dtypes.issubdtype(_dtype(x), np.complexfloating)):
@ -2826,7 +2826,7 @@ def _broadcast_in_dim_partial_eval(
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.broadcast_in_dim_mlir(
ctx, x, *dyn_shape, shape=shape,
broadcast_dimensions=broadcast_dimensions)
@ -3310,7 +3310,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.transpose_mlir(ctx, x, permutation=permutation)
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
@ -4557,7 +4557,7 @@ def _check_same_dtypes(name, ignore_fp_precision, *ttypes):
"""Check that dtypes agree, possibly ignoring float precision."""
# the `ignore_fp_precision` flag exists because the XLA shape inference logic
# allows mixed floating point precision, but the HLO verifier often rejects it
if any(type(t) in core.custom_eltypes for t in ttypes):
if any(core.is_opaque_dtype(t) for t in ttypes):
return # TODO(mattjj,frostig): do some checking, friend
types = map(np.dtype, ttypes) # canonicalize
if ignore_fp_precision:
@ -4720,12 +4720,12 @@ def _check_user_dtype_supported(dtype, fun_name=None):
warnings.warn(msg.format(dtype, fun_name , truncated_dtype), stacklevel=3)
def empty(eltype):
return empty_p.bind(eltype=eltype)
def empty(dtype):
return empty_p.bind(dtype=dtype)
empty_p = core.Primitive('empty')
empty_p.def_abstract_eval(lambda *, eltype: core.ShapedArray((), eltype))
def _empty_lower(ctx, *, eltype):
if type(eltype) in core.custom_eltypes:
return eltype._rules.empty_mlir(ctx)
return mlir.ir_constants(np.zeros((), np.dtype(eltype)))
empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
def _empty_lower(ctx, *, dtype):
if core.is_opaque_dtype(dtype):
return dtype._rules.empty_mlir(ctx)
return mlir.ir_constants(np.zeros((), np.dtype(dtype)))
mlir.register_lowering(empty_p, _empty_lower)

View File

@ -803,7 +803,7 @@ batching.primitive_batchers[slice_p] = _slice_batching_rule
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.slice_mlir(
ctx, x, start_indices, limit_indices, strides)
return mhlo.SliceOp(x,
@ -904,7 +904,7 @@ batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_slice_mlir(
ctx, x, start_indices, slice_sizes)
return mhlo.DynamicSliceOp(x, start_indices,
@ -1003,7 +1003,7 @@ batching.primitive_batchers[dynamic_update_slice_p] = \
def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.dynamic_update_slice_mlir(
ctx, x, update, *start_indices)
return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update,
@ -1318,7 +1318,7 @@ def _gather_lower(ctx, operand, indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval_out.dtype):
return aval_out.dtype._rules.gather_mlir(
ctx, operand, indices, dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes, unique_indices=unique_indices,

View File

@ -189,7 +189,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
# * unpack upfront into shape[0] many keyarray slices
# * return iter over these unpacked slices
# Whatever we do, we'll want to do it by overriding
# ShapedArray._iter when the eltype is KeyTy...
# ShapedArray._iter when the element type is KeyTy...
return (PRNGKeyArray(self.impl, k) for k in iter(self._base_array))
# TODO(frostig): are all of the stackable methods below (reshape,
@ -373,7 +373,7 @@ class KeyTyRules:
return PRNGKeyArray(aval.dtype.impl, phys_handler(bufs))
return handler
# eltype-polymorphic primitive lowering rules
# element-type-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
@ -476,7 +476,7 @@ class KeyTy:
return hash((self.__class__, self.impl))
core.custom_eltypes.add(KeyTy)
core.opaque_dtypes.add(KeyTy)
core.pytype_aval_mappings[PRNGKeyArray] = (

View File

@ -1192,33 +1192,24 @@ def concrete_or_error(force: Any, val: Any, context=""):
# TODO(frostig,mattjj): achieve this w/ a protocol instead of registry?
custom_eltypes: Set[Any] = set()
opaque_dtypes: Set[Any] = set()
# TODO(frostig): update inliners of the four functions below to call them
def has_custom_eltype(x: Any):
return aval_has_custom_eltype(get_aval(x))
def has_opaque_dtype(x: Any):
return is_opaque_dtype(get_aval(x).dtype)
def eltype(x: Any):
return aval_eltype(get_aval(x))
def aval_has_custom_eltype(aval: UnshapedArray):
return is_custom_eltype(aval.dtype)
def aval_eltype(aval: UnshapedArray):
return aval.dtype
def is_custom_eltype(eltype):
return type(eltype) in custom_eltypes
def is_opaque_dtype(dtype):
return type(dtype) in opaque_dtypes
def _short_dtype_name(dtype) -> str:
if type(dtype) in custom_eltypes:
if type(dtype) in opaque_dtypes:
return str(dtype)
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))
def _dtype_object(dtype):
return dtype if type(dtype) in custom_eltypes else np.dtype(dtype)
return dtype if type(dtype) in opaque_dtypes else np.dtype(dtype)
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
@ -1422,13 +1413,12 @@ class ConcreteArray(ShapedArray):
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
# TODO(frostig,mattjj): rename to primal_eltype_to_tangent_eltype
def primal_dtype_to_tangent_dtype(primal_dtype):
# TODO(frostig,mattjj): determines that all custom eltypes have
# float0 tangent type, which works fine for all our current custom
# eltype applications. We may some day want to delegate this
# decision to the eltype.
if (type(primal_dtype) in custom_eltypes or
# TODO(frostig,mattjj): determines that all opaque dtypes have
# float0 tangent type, which works fine for all our current opaque
# dtype applications. We may some day want to delegate this
# decision to the dtype rules.
if (is_opaque_dtype(primal_dtype) or
not dtypes.issubdtype(primal_dtype, np.inexact)):
return dtypes.float0
else:

View File

@ -461,7 +461,7 @@ pxla.shard_arg_handlers[Array] = _array_shard_arg
def _array_global_result_handler(global_aval, out_sharding, committed):
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
if core.aval_has_custom_eltype(global_aval):
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed)
return lambda bufs: Array(global_aval, out_sharding, bufs,
@ -472,7 +472,7 @@ pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambd
def _array_local_result_handler(aval, sharding, indices):
if core.aval_has_custom_eltype(aval):
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
else:

View File

@ -619,7 +619,7 @@ pxla.shard_arg_handlers[GlobalDeviceArray] = _gda_shard_arg
def _gda_array_result_handler(global_aval, out_sharding, committed):
if core.aval_has_custom_eltype(global_aval):
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed)
global_mesh, out_axis_resources = out_sharding.mesh, out_sharding.spec

View File

@ -778,7 +778,7 @@ def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""
if type(aval.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval.dtype):
aval, = aval.dtype._rules.physical_avals(aval)
return aval
return aval
@ -881,8 +881,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
return tf_val, jax_dtype
# TODO(frostig,mattjj): rename dtype argument to eltype, for now just
# being consistent.
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
@ -1845,10 +1843,10 @@ def _broadcast_in_dim(operand, *, shape, broadcast_dimensions,
tf_impl_with_avals[lax.broadcast_in_dim_p] = _broadcast_in_dim
def _empty(*, eltype):
if type(eltype) in core.custom_eltypes:
def _empty(*, dtype):
if core.is_opaque_dtype(dtype):
raise NotImplementedError # TODO(frostig,mattjj): jax2tf handlers
return tf.constant(np.array(0, dtype=eltype))
return tf.constant(np.array(0, dtype=dtype))
tf_impl[lax_internal.empty_p] = _empty

View File

@ -138,7 +138,7 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
) -> Sequence[ir.Type]:
if type(aval.dtype) in core.custom_eltypes:
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.aval_to_ir_types(aval)
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)

View File

@ -578,7 +578,7 @@ local_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRes
def sda_array_result_handler(aval: ShapedArray, sharding, indices):
sharding_spec = _get_sharding_specs([sharding], [aval])[0]
if core.aval_has_custom_eltype(aval):
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
else:

View File

@ -3038,7 +3038,7 @@ class FooTyRules:
return FooArray(aval.shape, buf)
return handler
# eltype-polymorphic primitive lowering rules
# element-type-polymorphic primitive lowering rules
@staticmethod
def empty_mlir(ctx):
@ -3194,7 +3194,7 @@ def bake_vmap(batched_args, batch_dims):
class CustomElementTypesTest(jtu.JaxTestCase):
def setUp(self):
core.custom_eltypes.add(FooTy)
core.opaque_dtypes.add(FooTy)
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
@ -3210,7 +3210,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
batching.primitive_batchers[bake_p] = bake_vmap
def tearDown(self):
core.custom_eltypes.remove(FooTy)
core.opaque_dtypes.remove(FooTy)
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del xla.pytype_aval_mappings[FooArray]

View File

@ -1496,7 +1496,7 @@ class LaxRandomTest(jtu.JaxTestCase):
class KeyArrayTest(jtu.JaxTestCase):
# Key arrays involve:
# * a Python key array type, backed by an underlying uint32 "base" array,
# * an abstract shaped array with key eltype,
# * an abstract shaped array with key element type,
# * primitives that return or operate on such shaped arrays,
# * compiler lowerings,
# * a device-side data representation...
@ -1504,8 +1504,8 @@ class KeyArrayTest(jtu.JaxTestCase):
#
# A handful of these tests follow CustomElementTypesTest in
# lax_tests.py as an example. If you add a test here (e.g. testing
# lowering of an key-eltyped shaped array), consider whether it
# might also be a more general test of extended/custom eltypes. If
# lowering of an key-dtyped shaped array), consider whether it
# might also be a more general test of opaque element types. If
# so, add a corresponding test to to CustomElementTypesTest as well.
def make_keys(self, *shape, seed=None):