Merge pull request #23823 from mattjj:simplify-extended-dtype-convert-logic

PiperOrigin-RevId: 678456216
This commit is contained in:
jax authors 2024-09-24 17:29:32 -07:00
commit cfb4e85fcd
8 changed files with 295 additions and 107 deletions

View File

@ -1608,12 +1608,6 @@ def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray:
duck = edtype._rules.physical_element_aval(edtype) # type: ignore
return ShapedArray(duck.shape, dtypes.dtype(duck.dtype))
def _short_dtype_name(dtype) -> str:
if isinstance(dtype, dtypes.ExtendedDType):
return str(dtype)
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))
def _dtype_object(dtype):
  """Normalize `dtype` to a dtype object.

  Extended dtypes are handed back untouched; anything else is passed through
  `np.dtype`.
  """
  if isinstance(dtype, dtypes.ExtendedDType):
    return dtype
  return np.dtype(dtype)
@ -1672,7 +1666,7 @@ class UnshapedArray(AbstractValue):
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
def strip_weak_type(self):
"""Returns a copy of the aval with weak_type=False."""
@ -1811,7 +1805,7 @@ class ShapedArray(UnshapedArray):
raise TypeError(self, other)
def str_short(self, short_dtypes=False):
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dt_str.replace('void', 'float0')
shapestr = ','.join(map(str, self.shape))
if hasattr(self, 'sharding'):
@ -1872,7 +1866,7 @@ class ConcreteArray(ShapedArray):
raise TypeError(self, other)
def str_short(self, short_dtypes=False) -> str:
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'
_bool = partialmethod(_forward_to_value, bool)
@ -1922,7 +1916,7 @@ class DShapedArray(UnshapedArray):
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored
shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
dtype = _short_dtype_name(self.dtype)
dtype = dtypes.short_dtype_name(self.dtype)
return f'{dtype}[{shape}]'
__str__ = __repr__ = str_short
@ -1989,7 +1983,7 @@ class DArray:
# special-case scalar bints
return f'{int(self._data)}{{{self.dtype.bound}}}'
dtypestr = _short_dtype_name(self._aval.dtype)
dtypestr = dtypes.short_dtype_name(self._aval.dtype)
shapestr = ','.join(map(str, self.shape))
data = self.data
return f'{dtypestr}[{shapestr}] with value: {data}'
@ -3203,7 +3197,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
if isinstance(a, DShapedArray):
shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape]
dtype = _short_dtype_name(a.dtype)
dtype = dtypes.short_dtype_name(a.dtype)
return f'{dtype}[{",".join(shape)}]'
else:
return a.str_short(short_dtypes=True)

View File

@ -839,13 +839,14 @@ def safe_to_cast(input_dtype_or_value: Any,
def primal_tangent_dtype(primal_dtype, tangent_dtype,
name: str | None = None) -> ExtendedDType:
name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}'
primal_dtype, tangent_dtype = map(dtype, (primal_dtype, tangent_dtype))
name_ = name or (f'PrimalTangentDType{{{short_dtype_name(primal_dtype)}'
f'/{short_dtype_name(tangent_dtype)}}}')
rules = types.SimpleNamespace(
physical_element_aval=
lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype),
tangent_dtype=lambda dtype: tangent_dtype,
convert_from=lambda _, other: other == primal_dtype,
convert_to=lambda other, _: other == primal_dtype)
allow_conversion=True)
class primal_tangent_dtype_scalar(extended): ...
@ -854,5 +855,13 @@ def primal_tangent_dtype(primal_dtype, tangent_dtype,
name = name_
_rules = rules
type = primal_tangent_dtype_scalar
__repr__ = lambda _: name_
return PrimalTangentDType()
def short_dtype_name(dtype) -> str:
  """Return a compact display name for `dtype`.

  Extended dtypes are rendered with `str`. For any other dtype the `name`
  attribute is abbreviated: 'float' -> 'f', 'uint' -> 'u', 'int' -> 'i' and
  'complex' -> 'c'.
  """
  if not isinstance(dtype, ExtendedDType):
    abbreviated = dtype.name
    # Order matters: 'uint' has to be rewritten before the shorter 'int'.
    for long_form, short_form in (('float', 'f'), ('uint', 'u'),
                                  ('int', 'i'), ('complex', 'c')):
      abbreviated = abbreviated.replace(long_form, short_form)
    return abbreviated
  return str(dtype)

View File

@ -137,7 +137,7 @@ def asarray(x: ArrayLike) -> Array:
if isinstance(x, Array):
return x
if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)):
return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x))
return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type]
else:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
@ -520,7 +520,7 @@ def convert_element_type(operand: ArrayLike,
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
return _convert_element_type(operand, new_dtype, weak_type=False)
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
def _convert_element_type(
operand: ArrayLike,
@ -530,17 +530,30 @@ def _convert_element_type(
if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()
if (dtypes.issubdtype(new_dtype, dtypes.extended) or
dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
return convert_element_type_p.bind(
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
sharding=sharding)
new_dtype = type_cast(DTypeLike | None, new_dtype)
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = dtypes.dtype(operand, canonicalize=False)
if (isinstance(new_dtype, dtypes.ExtendedDType) or
isinstance(old_dtype, dtypes.ExtendedDType)):
if sharding is not None or weak_type: raise NotImplementedError
if new_dtype == old_dtype: return operand
if (isinstance(new_dtype, dtypes.ExtendedDType) and
isinstance(old_dtype, dtypes.ExtendedDType)):
old_rep_dtype = core.physical_element_aval(old_dtype).dtype
new_rep_dtype = core.physical_element_aval(new_dtype).dtype
raise ValueError(
"cannot directly convert between extended dtypes: from "
f"{dtype_to_string(old_dtype)} to {dtype_to_string(new_dtype)}. "
"Instead, convert to and from their representation dtypes, e.g.:\n"
f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} "
f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}")
if isinstance(new_dtype, dtypes.ExtendedDType):
return to_edtype_p.bind(operand, edtype=new_dtype)
return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype))
new_dtype = type_cast(DTypeLike | None, new_dtype)
old_weak_type = dtypes.is_weakly_typed(operand)
if new_dtype is None:
new_dtype = old_dtype
@ -2560,14 +2573,6 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
sharding):
if (operand.dtype != new_dtype and
((dtypes.issubdtype(operand.dtype, dtypes.extended) and
not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or
(dtypes.issubdtype(new_dtype, dtypes.extended) and
not new_dtype._rules.convert_to(operand.dtype, new_dtype)))):
raise ValueError(
f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} "
f"to {dtype_to_string(new_dtype)}")
return new_dtype
def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
@ -2587,13 +2592,13 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
return [convert_element_type_p.bind(
ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]
def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type,
sharding):
if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0:
tangent_aval = core.raise_to_shaped(core.get_aval(tangent))
return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False))
def _convert_element_type_jvp_rule(tangent, primal_result, operand, *,
new_dtype, weak_type, sharding):
new_tangent_dtype = core.primal_dtype_to_tangent_dtype(new_dtype)
if new_tangent_dtype == dtypes.float0:
return ad_util.Zero.from_primal_value(primal_result)
else:
return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype,
weak_type=weak_type, sharding=sharding)
def _convert_elt_type_folding_rule(consts, eqn):
@ -2653,7 +2658,7 @@ convert_element_type_p.def_abstract_eval(
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
_convert_element_type_weak_type_rule,
_convert_element_type_sharding_rule))
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
@ -2676,6 +2681,91 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
def _to_edtype_abstract_eval(x, *, edtype):
  """Abstract evaluation rule for converting a representation array to `edtype`.

  Checks, in order, that the extended dtype's rules permit the conversion,
  that `x` has the representation dtype of `edtype`, and that the trailing
  dimensions of `x` equal the representation shape. The result keeps the
  leading dimensions of `x` and carries dtype `edtype`.

  Raises:
    ValueError: if the conversion is disallowed, or if the dtype, rank or
      trailing shape of `x` does not match the representation of `edtype`.
  """
  assert isinstance(edtype, dtypes.ExtendedDType)
  assert not isinstance(x.dtype, dtypes.ExtendedDType)

  # For backward compatibility, a `convert_to` method on the edtype rules takes
  # precedence over the `allow_conversion: bool` attribute.
  legacy_convert_to = getattr(edtype._rules, 'convert_to', None)
  if legacy_convert_to:
    permitted = legacy_convert_to(x.dtype, edtype)
  else:
    permitted = edtype._rules.allow_conversion
  if not permitted:
    raise ValueError(
        f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
        f"to {dtype_to_string(edtype)}")

  rep_aval = core.physical_element_aval(edtype)
  if x.dtype != rep_aval.dtype:
    raise ValueError(
        "can only convert to extended dtype from its representation dtype, "
        f"but tried to convert from {dtype_to_string(x.dtype)} to "
        f"{dtype_to_string(edtype)} which doesn't match the representation type "
        f"{dtype_to_string(rep_aval.dtype)}.")
  if x.ndim < rep_aval.ndim:
    raise ValueError(
        "can only convert to extended dtype from an array of its "
        f"representation type, but the extended dtype {dtype_to_string(edtype)}"
        f" has a representation shape {rep_aval.shape} (rank {rep_aval.ndim}) "
        f"while the given representation array has shape {x.shape} (rank "
        f"{x.ndim} < {rep_aval.ndim}).")

  # Split the operand shape into the dimensions that survive the conversion
  # and the trailing dimensions that must match the representation shape.
  split = x.ndim - rep_aval.ndim
  leading_shape = x.shape[:split]
  shape_suffix = x.shape[split:]
  if shape_suffix != rep_aval.shape:
    raise ValueError(
        "can only convert to extended dtype from an array of its "
        f"representation type, but the extended dtype {dtype_to_string(edtype)}"
        f" has a representation shape {rep_aval.shape} while the given "
        f"representation array has shape {x.shape}, so the shape suffix "
        f"does not match: given {shape_suffix} but required {rep_aval.shape}.")
  return core.raise_to_shaped(x).update(shape=leading_shape, dtype=edtype)
# Primitive that turns an array of an extended dtype's representation type
# into an array of that extended dtype (see _to_edtype_abstract_eval).
to_edtype_p = Primitive('to_edtype')
to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
to_edtype_p.def_abstract_eval(_to_edtype_abstract_eval)
# JVP: convert the tangent `t` to the tangent dtype associated with `edtype`.
ad.defjvp(to_edtype_p,
lambda t, x, edtype:
convert_element_type(t, core.primal_dtype_to_tangent_dtype(edtype)))
# Transpose: send the cotangent through from_edtype_p, targeting the dtype
# recorded on the operand's aval.
ad.primitive_transposes[to_edtype_p] = \
lambda ct, x, edtype: [from_edtype_p.bind(ct, dtype=x.aval.dtype)] # type: ignore
batching.defvectorized(to_edtype_p)
# The lowering rule returns the operand value as-is.
mlir.register_lowering(to_edtype_p, lambda _, x, **__: [x])
def _from_edtype_abstract_eval(x, *, dtype):
  """Abstract evaluation rule for converting an extended-dtype array to `dtype`.

  Checks that the extended dtype's rules permit the conversion and that
  `dtype` is the representation dtype of `x.dtype`. The result shape is the
  shape of `x` followed by the representation shape.

  Raises:
    ValueError: if the conversion is disallowed or `dtype` is not the
      representation dtype.
    NotImplementedError: if any dimension of `x` is not a plain `int`.
  """
  assert isinstance(x.dtype, dtypes.ExtendedDType)
  assert not isinstance(dtype, dtypes.ExtendedDType)

  # A `convert_from` method on the rules, when present, takes precedence over
  # the `allow_conversion` attribute.
  legacy_convert_from = getattr(x.dtype._rules, 'convert_from', None)
  if legacy_convert_from:
    permitted = legacy_convert_from(x.dtype, dtype)
  else:
    permitted = x.dtype._rules.allow_conversion
  if not permitted:
    raise ValueError(
        f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
        f"to {dtype_to_string(dtype)}")

  rep_aval = core.physical_element_aval(x.dtype)
  if rep_aval.dtype != dtype:
    raise ValueError(
        "can only convert from extended dtype to its representation dtype, "
        f"but tried to convert from {dtype_to_string(x.dtype)} to "
        f"{dtype_to_string(dtype)} which doesn't match the representation type "
        f"{dtype_to_string(rep_aval.dtype)}.")

  if not all(isinstance(d, int) for d in x.shape):
    raise NotImplementedError
  return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
# Primitive that turns an array of an extended dtype into an array of its
# representation dtype (see _from_edtype_abstract_eval).
from_edtype_p = Primitive('from_edtype')
from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p))
from_edtype_p.def_abstract_eval(_from_edtype_abstract_eval)
# JVP: convert the tangent `t` to the tangent dtype associated with `dtype`.
ad.defjvp(from_edtype_p,
          lambda t, x, dtype:
          convert_element_type(t, core.primal_dtype_to_tangent_dtype(dtype)))
# Transpose: send the cotangent back through to_edtype_p. The extended dtype is
# read from `x.aval`, matching the to_edtype transpose rule; the linear operand
# handed to a transpose rule is not guaranteed to expose `.dtype` directly.
ad.primitive_transposes[from_edtype_p] = \
    lambda ct, x, dtype: [to_edtype_p.bind(ct, edtype=x.aval.dtype)]  # type: ignore
batching.defvectorized(from_edtype_p)
# The lowering rule returns the operand value as-is.
mlir.register_lowering(from_edtype_p, lambda _, x, **__: [x])
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
new_dtype = dtypes.canonicalize_dtype(new_dtype)
@ -5343,6 +5433,8 @@ batching.defvectorized(tie_p)
class BIntRules:
allow_conversion: bool = True
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray((), np.dtype('int32'))
@ -5369,14 +5461,6 @@ class BIntRules:
return core.DArray(aval, phys_handler(bufs))
return handler
@staticmethod
def convert_from(bint_dtype, other_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
@staticmethod
def convert_to(other_dtype, bint_dtype) -> bool:
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
core.bint._rules = BIntRules

View File

@ -149,11 +149,8 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
raise NotImplementedError
def str_short(self, short_dtypes=False):
dt_str = (
jax_core._short_dtype_name(self.dtype)
if short_dtypes
else self.dtype.name
)
dt_str = \
dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dt_str.replace("void", "float0")
shapestr = ",".join(map(str, self.shape))
if hasattr(self, "sharding"):

View File

@ -321,6 +321,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
class KeyTyRules:
allow_conversion: bool = False
@staticmethod
def full(shape, fill_value, dtype):
@ -425,14 +426,6 @@ class KeyTyRules:
def zero(_):
return np.zeros((), dtypes.float0)
@staticmethod
def convert_from(key_dtype, other_dtype) -> bool:
return False
@staticmethod
def convert_to(other_dtype, key_dtype) -> bool:
return False
class KeyTy(dtypes.ExtendedDType):
_impl: PRNGImpl # TODO(mattjj,frostig): protocol really

View File

@ -1549,6 +1549,8 @@ tf_not_yet_impl = [
"ragged_dot",
"cholesky_update",
"symmetric_update",
"from_edtype",
"to_edtype",
# Pallas TPU primitives
"bitcast",
"repeat",

View File

@ -19,6 +19,7 @@ import functools
from functools import partial
import itertools
import operator
import types
from absl.testing import absltest
from absl.testing import parameterized
@ -300,16 +301,6 @@ class DtypesTest(jtu.JaxTestCase):
self.assertEqual(dtypes.issubdtype(t, category),
np.issubdtype(np.dtype(t).type, category))
def testIsSubdtypeExtended(self):
self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended))
self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic))
self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number))
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key))
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended))
self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic))
self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number))
@parameterized.product(dtype=custom_float_dtypes)
def testIsSubdtypeCustomFloats(self, dtype):
for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]:
@ -408,6 +399,34 @@ class DtypesTest(jtu.JaxTestCase):
self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128)
def test_check_dtype_non_hashable(self):
  """Regression test: checking a non-hashable custom dtype must not fail."""
  class MyDtype:
    dtype = np.dtype('float32')
    __hash__ = None

  unhashable = MyDtype()
  dtypes.check_user_dtype_supported(unhashable)
def test_check_dtype_array(self):
  """Passing an array as a dtype warns, both eagerly and under `jax.jit`."""
  arr = jnp.arange(4)
  expected = "Passing an array as a dtype argument is deprecated"

  with self.assertWarnsRegex(DeprecationWarning, expected):
    dtypes.check_user_dtype_supported(arr)

  jitted_check = jax.jit(dtypes.check_user_dtype_supported)
  with self.assertWarnsRegex(DeprecationWarning, expected):
    jitted_check(arr)
class ExtendedDTypeTest(jtu.JaxTestCase):
def testIsSubdtypeExtended(self):
  """Checks where `extended` and `prng_key` sit in the scalar type hierarchy."""
  ext = dtypes.extended
  key = dtypes.prng_key

  # `extended` is a subtype of itself and of np.generic, but not a number.
  self.assertTrue(dtypes.issubdtype(ext, ext))
  self.assertTrue(dtypes.issubdtype(ext, np.generic))
  self.assertFalse(dtypes.issubdtype(ext, np.number))

  # `prng_key` is a subtype of itself, `extended` and np.generic, but not a
  # number.
  self.assertTrue(jnp.issubdtype(key, key))
  self.assertTrue(jnp.issubdtype(key, ext))
  self.assertTrue(jnp.issubdtype(key, np.generic))
  self.assertFalse(dtypes.issubdtype(key, np.number))
def test_custom_tangent_dtype(self):
from jax._src import core
@ -415,6 +434,8 @@ class DtypesTest(jtu.JaxTestCase):
pass
class ScalesTyRules:
allow_conversion: bool = True
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray((), dtype.float_dtype)
@ -435,14 +456,6 @@ class DtypesTest(jtu.JaxTestCase):
else dtypes.finfo(dt.float_dtype).min, dt.float_dtype)
return jax.lax.convert_element_type(neginf, dt)
@staticmethod
def convert_from(dtype, other_dtype) -> bool:
return dtype.float_dtype == other_dtype
@staticmethod
def convert_to(other_dtype, dtype) -> bool:
return dtype.float_dtype == other_dtype
@dataclasses.dataclass(frozen=True)
class ScaleTy(dtypes.ExtendedDType):
float_dtype: dtypes.DType
@ -485,19 +498,13 @@ class DtypesTest(jtu.JaxTestCase):
from jax._src import core
class ScalesTyRules:
# tell JAX how to lower this dtype to an HLO dtype
# tell JAX how to lower this dtype to an HLO representation dtype
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray((), dtype.float_dtype)
# allow conversions to and from the corresponding float type
@staticmethod
def convert_from(scale_dtype, other_dtype) -> bool:
return scale_dtype.float_dtype == other_dtype
@staticmethod
def convert_to(other_dtype, scale_dtype) -> bool:
return scale_dtype.float_dtype == other_dtype
# allow conversions to and from the corresponding representation type
allow_conversion: bool = True
# define how autodiff should accumulate these values
@staticmethod
@ -563,21 +570,6 @@ class DtypesTest(jtu.JaxTestCase):
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
self.assertAllClose(new_scale, jnp.float32(1.0))
def test_check_dtype_non_hashable(self):
# regression test for issue with checking non-hashable custom dtype
class MyDtype:
__hash__ = None
dtype = np.dtype('float32')
dtypes.check_user_dtype_supported(MyDtype())
def test_check_dtype_array(self):
x = jnp.arange(4)
msg = "Passing an array as a dtype argument is deprecated"
with self.assertWarnsRegex(DeprecationWarning, msg):
dtypes.check_user_dtype_supported(x)
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(dtypes.check_user_dtype_supported)(x)
@parameterized.parameters([True]) # TODO(mattjj): make jit=False work
def test_primal_tangent_dtype(self, jit):
dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16)
@ -605,6 +597,123 @@ class DtypesTest(jtu.JaxTestCase):
self.assertEqual(result.dtype, jnp.bfloat16)
self.assertEqual(bwd_result.dtype, jnp.bfloat16)
self.assertAllClose(bwd_result, 2 * g)
self.assertEqual(repr(dt), 'PrimalTangentDType{i8/bf16}')
# Round-trips an int32 array through an extended dtype for every combination
# of leading shape (shape_prefix) and representation shape (shape_suffix)
# drawn from (), (2,) and (3, 4).
@parameterized.parameters(itertools.product([(), (2,), (3, 4)], repeat=2))
def test_edtype_conversion(self, shape_prefix, shape_suffix):
class scalar(dtypes.extended): ...
# Extended dtype whose representation is an int32 array of shape shape_suffix
# per element; conversions are enabled via allow_conversion.
@dataclasses.dataclass(frozen=True)
class DType(dtypes.ExtendedDType):
name = 'dt'
type = scalar
_rules = types.SimpleNamespace(
physical_element_aval=
lambda _: types.SimpleNamespace(shape=shape_suffix, dtype='int32'),
allow_conversion=True)
dtype = DType()
@jax.jit
def f(x):
# The input is the raw representation array.
self.assertEqual(x.shape, shape_prefix + shape_suffix)
self.assertEqual(x.dtype, jnp.dtype('int32'))
# Converting to the extended dtype drops the trailing representation dims.
x = jax.lax.convert_element_type(x, dtype)
self.assertEqual(x.shape, shape_prefix)
self.assertEqual(x.dtype, dtype)
# Converting back to int32 restores them.
x = jax.lax.convert_element_type(x, 'int32')
self.assertEqual(x.shape, shape_prefix + shape_suffix)
self.assertEqual(x.dtype, jnp.dtype('int32'))
f(jnp.zeros(shape_prefix + shape_suffix, dtype='int32'))
# Checks the ValueErrors raised for invalid conversions to and from an
# extended dtype.
def test_edtype_conversion_errors(self):
class scalar(dtypes.extended): ...
# Extended dtype represented by an int32 array of shape (3,) per element.
@dataclasses.dataclass(frozen=True)
class DType(dtypes.ExtendedDType):
name = 'dt'
type = scalar
_rules = types.SimpleNamespace(
physical_element_aval=
lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'),
allow_conversion=True)
dtype = DType()
class scalar2(dtypes.extended): ...
# A second, distinct extended dtype with the same representation.
@dataclasses.dataclass(frozen=True)
class DType2(dtypes.ExtendedDType):
name = 'dt2'
type = scalar2
_rules = types.SimpleNamespace(
physical_element_aval=
lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'),
allow_conversion=True)
dtype2 = DType2()
@jax.jit
def f(x):
y = jax.lax.convert_element_type(x, dtype)
# Extended dtype -> a different extended dtype.
with self.assertRaisesRegex(ValueError, "cannot directly"):
jax.lax.convert_element_type(y, dtype2)
# Operand dtype float32 is not the int32 representation dtype.
with self.assertRaisesRegex(ValueError, "can only convert"):
jax.lax.convert_element_type(x.astype('float32'), dtype)
# Trailing dimension 2 does not match the representation shape (3,).
with self.assertRaisesRegex(ValueError, "can only convert"):
jax.lax.convert_element_type(x[:, :2], dtype)
# Operand of shape (5,) does not end in the representation shape (3,).
with self.assertRaisesRegex(ValueError, "can only convert"):
jax.lax.convert_element_type(x[:, 0], dtype)
# Target float32 is not the int32 representation dtype.
with self.assertRaisesRegex(ValueError, "can only convert"):
jax.lax.convert_element_type(y, 'float32')
f(jnp.zeros((5, 3), dtype='int32'))
# Checks primal and tangent dtypes and shapes on both sides of a round trip
# through an extended dtype under jax.grad.
def test_edtype_conversion_autodiff(self):
class scalar(dtypes.extended): ...
# Extended dtype represented by a float32 scalar whose tangent dtype is
# bfloat16.
@dataclasses.dataclass(frozen=True)
class DType(dtypes.ExtendedDType):
name = 'dt'
type = scalar
_rules = types.SimpleNamespace(
physical_element_aval=
lambda _: types.SimpleNamespace(shape=(), dtype='float32'),
tangent_dtype=lambda dtype: jnp.dtype('bfloat16'),
allow_conversion=True)
dtype = DType()
@jax.jit
@jax.grad
def f(x):
x = jax.lax.convert_element_type(x, dtype)
# Identity function whose custom JVP rule only inspects its inputs.
@jax.custom_jvp
def g(x): return x
@g.defjvp
def g_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
# After conversion to the extended dtype the tangent is bfloat16.
self.assertEqual(x.shape, (5,))
self.assertEqual(x.dtype, dtype)
self.assertEqual(x_dot.shape, (5,))
self.assertEqual(x_dot.dtype, jnp.dtype('bfloat16'))
return x, x_dot
x = g(x)
x = jax.lax.convert_element_type(x, 'float32')
# Second identity probe, placed after converting back to float32.
@jax.custom_jvp
def h(x): return x
@h.defjvp
def h_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
# Back in float32, primal and tangent are both float32.
self.assertEqual(x.shape, (5,))
self.assertEqual(x.dtype, jnp.dtype('float32'))
self.assertEqual(x_dot.shape, (5,))
self.assertEqual(x_dot.dtype, jnp.dtype('float32'))
return x, x_dot
x = h(x)
# The differentiated output is a constant; only the assertions above matter.
return 0.
f(jnp.zeros(5, dtype='float32')) # test assertions in the function
class EArrayTest(jtu.JaxTestCase):
@ -618,10 +727,7 @@ class EArrayTest(jtu.JaxTestCase):
class foo(dtypes.extended): pass
class FooTyRules:
@staticmethod
def convert_to(foo_dtype, target_dtype):
return True
allow_conversion: bool = True
@staticmethod
def physical_element_aval(foo_dtype):

View File

@ -1486,6 +1486,9 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
jax_traceback_filtering='off')
class JumbleTest(jtu.JaxTestCase):
def setUp(self):
  """Run the base-class fixture setup, then skip when x64 mode is enabled.

  Raises:
    unittest.SkipTest: if `jax.config.x64_enabled` is true.
  """
  # An overriding setUp replaces the parent's unless it chains to it
  # explicitly, so call the base implementation first.
  super().setUp()
  if jax.config.x64_enabled:
    raise unittest.SkipTest()
@parameterized.parameters((True,), (False,))
def test_internal_jumble(self, disable_jit):
with jax.disable_jit(disable_jit):