mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23823 from mattjj:simplify-extended-dtype-convert-logic
PiperOrigin-RevId: 678456216
This commit is contained in:
commit
cfb4e85fcd
@ -1608,12 +1608,6 @@ def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray:
|
||||
duck = edtype._rules.physical_element_aval(edtype) # type: ignore
|
||||
return ShapedArray(duck.shape, dtypes.dtype(duck.dtype))
|
||||
|
||||
def _short_dtype_name(dtype) -> str:
|
||||
if isinstance(dtype, dtypes.ExtendedDType):
|
||||
return str(dtype)
|
||||
else:
|
||||
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
||||
.replace('int' , 'i').replace('complex', 'c'))
|
||||
|
||||
def _dtype_object(dtype):
|
||||
return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype)
|
||||
@ -1672,7 +1666,7 @@ class UnshapedArray(AbstractValue):
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
|
||||
def strip_weak_type(self):
|
||||
"""Returns a copy of the aval with weak_type=False."""
|
||||
@ -1811,7 +1805,7 @@ class ShapedArray(UnshapedArray):
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False):
|
||||
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
if hasattr(self, 'sharding'):
|
||||
@ -1872,7 +1866,7 @@ class ConcreteArray(ShapedArray):
|
||||
raise TypeError(self, other)
|
||||
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
return f'{self.val}, dtype={dt_str}'
|
||||
|
||||
_bool = partialmethod(_forward_to_value, bool)
|
||||
@ -1922,7 +1916,7 @@ class DShapedArray(UnshapedArray):
|
||||
def str_short(self, short_dtypes=False) -> str:
|
||||
del short_dtypes # ignored
|
||||
shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
|
||||
dtype = _short_dtype_name(self.dtype)
|
||||
dtype = dtypes.short_dtype_name(self.dtype)
|
||||
return f'{dtype}[{shape}]'
|
||||
__str__ = __repr__ = str_short
|
||||
|
||||
@ -1989,7 +1983,7 @@ class DArray:
|
||||
# special-case scalar bints
|
||||
return f'{int(self._data)}{{≤{self.dtype.bound}}}'
|
||||
|
||||
dtypestr = _short_dtype_name(self._aval.dtype)
|
||||
dtypestr = dtypes.short_dtype_name(self._aval.dtype)
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
data = self.data
|
||||
return f'{dtypestr}[{shapestr}] with value: {data}'
|
||||
@ -3203,7 +3197,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext) -> str:
|
||||
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
|
||||
if isinstance(a, DShapedArray):
|
||||
shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape]
|
||||
dtype = _short_dtype_name(a.dtype)
|
||||
dtype = dtypes.short_dtype_name(a.dtype)
|
||||
return f'{dtype}[{",".join(shape)}]'
|
||||
else:
|
||||
return a.str_short(short_dtypes=True)
|
||||
|
@ -839,13 +839,14 @@ def safe_to_cast(input_dtype_or_value: Any,
|
||||
|
||||
def primal_tangent_dtype(primal_dtype, tangent_dtype,
|
||||
name: str | None = None) -> ExtendedDType:
|
||||
name_ = name or f'PTDtype{{{primal_dtype}:{tangent_dtype}}}'
|
||||
primal_dtype, tangent_dtype = map(dtype, (primal_dtype, tangent_dtype))
|
||||
name_ = name or (f'PrimalTangentDType{{{short_dtype_name(primal_dtype)}'
|
||||
f'/{short_dtype_name(tangent_dtype)}}}')
|
||||
rules = types.SimpleNamespace(
|
||||
physical_element_aval=
|
||||
lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype),
|
||||
tangent_dtype=lambda dtype: tangent_dtype,
|
||||
convert_from=lambda _, other: other == primal_dtype,
|
||||
convert_to=lambda other, _: other == primal_dtype)
|
||||
allow_conversion=True)
|
||||
|
||||
class primal_tangent_dtype_scalar(extended): ...
|
||||
|
||||
@ -854,5 +855,13 @@ def primal_tangent_dtype(primal_dtype, tangent_dtype,
|
||||
name = name_
|
||||
_rules = rules
|
||||
type = primal_tangent_dtype_scalar
|
||||
__repr__ = lambda _: name_
|
||||
|
||||
return PrimalTangentDType()
|
||||
|
||||
def short_dtype_name(dtype) -> str:
|
||||
if isinstance(dtype, ExtendedDType):
|
||||
return str(dtype)
|
||||
else:
|
||||
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
||||
.replace('int' , 'i').replace('complex', 'c'))
|
||||
|
@ -137,7 +137,7 @@ def asarray(x: ArrayLike) -> Array:
|
||||
if isinstance(x, Array):
|
||||
return x
|
||||
if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)):
|
||||
return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x))
|
||||
return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type]
|
||||
else:
|
||||
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
|
||||
|
||||
@ -520,7 +520,7 @@ def convert_element_type(operand: ArrayLike,
|
||||
Returns:
|
||||
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
|
||||
"""
|
||||
return _convert_element_type(operand, new_dtype, weak_type=False)
|
||||
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
|
||||
|
||||
def _convert_element_type(
|
||||
operand: ArrayLike,
|
||||
@ -530,17 +530,30 @@ def _convert_element_type(
|
||||
if hasattr(operand, '__jax_array__'):
|
||||
operand = operand.__jax_array__()
|
||||
|
||||
if (dtypes.issubdtype(new_dtype, dtypes.extended) or
|
||||
dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
|
||||
return convert_element_type_p.bind(
|
||||
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
||||
sharding=sharding)
|
||||
|
||||
new_dtype = type_cast(DTypeLike | None, new_dtype)
|
||||
|
||||
# Don't canonicalize old_dtype because x64 context might cause
|
||||
# un-canonicalized operands to be passed in.
|
||||
old_dtype = dtypes.dtype(operand, canonicalize=False)
|
||||
|
||||
if (isinstance(new_dtype, dtypes.ExtendedDType) or
|
||||
isinstance(old_dtype, dtypes.ExtendedDType)):
|
||||
if sharding is not None or weak_type: raise NotImplementedError
|
||||
if new_dtype == old_dtype: return operand
|
||||
if (isinstance(new_dtype, dtypes.ExtendedDType) and
|
||||
isinstance(old_dtype, dtypes.ExtendedDType)):
|
||||
old_rep_dtype = core.physical_element_aval(old_dtype).dtype
|
||||
new_rep_dtype = core.physical_element_aval(new_dtype).dtype
|
||||
raise ValueError(
|
||||
"cannot directly convert between extended dtypes: from "
|
||||
f"{dtype_to_string(old_dtype)} to {dtype_to_string(new_dtype)}. "
|
||||
"Instead, convert to and from their representation dtypes, e.g.:\n"
|
||||
f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} "
|
||||
f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}")
|
||||
if isinstance(new_dtype, dtypes.ExtendedDType):
|
||||
return to_edtype_p.bind(operand, edtype=new_dtype)
|
||||
return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype))
|
||||
|
||||
new_dtype = type_cast(DTypeLike | None, new_dtype)
|
||||
|
||||
old_weak_type = dtypes.is_weakly_typed(operand)
|
||||
if new_dtype is None:
|
||||
new_dtype = old_dtype
|
||||
@ -2560,14 +2573,6 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
|
||||
|
||||
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
|
||||
sharding):
|
||||
if (operand.dtype != new_dtype and
|
||||
((dtypes.issubdtype(operand.dtype, dtypes.extended) and
|
||||
not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or
|
||||
(dtypes.issubdtype(new_dtype, dtypes.extended) and
|
||||
not new_dtype._rules.convert_to(operand.dtype, new_dtype)))):
|
||||
raise ValueError(
|
||||
f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} "
|
||||
f"to {dtype_to_string(new_dtype)}")
|
||||
return new_dtype
|
||||
|
||||
def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
|
||||
@ -2587,13 +2592,13 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
|
||||
return [convert_element_type_p.bind(
|
||||
ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]
|
||||
|
||||
def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type,
|
||||
sharding):
|
||||
if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0:
|
||||
tangent_aval = core.raise_to_shaped(core.get_aval(tangent))
|
||||
return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False))
|
||||
def _convert_element_type_jvp_rule(tangent, primal_result, operand, *,
|
||||
new_dtype, weak_type, sharding):
|
||||
new_tangent_dtype = core.primal_dtype_to_tangent_dtype(new_dtype)
|
||||
if new_tangent_dtype == dtypes.float0:
|
||||
return ad_util.Zero.from_primal_value(primal_result)
|
||||
else:
|
||||
return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
|
||||
return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype,
|
||||
weak_type=weak_type, sharding=sharding)
|
||||
|
||||
def _convert_elt_type_folding_rule(consts, eqn):
|
||||
@ -2653,7 +2658,7 @@ convert_element_type_p.def_abstract_eval(
|
||||
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
||||
_convert_element_type_weak_type_rule,
|
||||
_convert_element_type_sharding_rule))
|
||||
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
|
||||
ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule)
|
||||
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
|
||||
batching.defvectorized(convert_element_type_p)
|
||||
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
|
||||
@ -2676,6 +2681,91 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
|
||||
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
||||
|
||||
|
||||
def _to_edtype_abstract_eval(x, *, edtype):
|
||||
assert (isinstance(edtype, dtypes.ExtendedDType) and
|
||||
not isinstance(x.dtype, dtypes.ExtendedDType))
|
||||
# For backward compatibility, if the edtype rules have a `convert_to` method,
|
||||
# use that rather than looking for an `allow_conversion: bool` attribute.
|
||||
if convert_to := getattr(edtype._rules, 'convert_to', None):
|
||||
allow_conversion = convert_to(x.dtype, edtype)
|
||||
else:
|
||||
allow_conversion = edtype._rules.allow_conversion
|
||||
if not allow_conversion:
|
||||
raise ValueError(
|
||||
f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
|
||||
f"to {dtype_to_string(edtype)}")
|
||||
rep_aval = core.physical_element_aval(edtype)
|
||||
if x.dtype != rep_aval.dtype:
|
||||
raise ValueError(
|
||||
"can only convert to extended dtype from its representation dtype, "
|
||||
f"but tried to convert from {dtype_to_string(x.dtype)} to "
|
||||
f"{dtype_to_string(edtype)} which doesn't match the representation type "
|
||||
f"{dtype_to_string(rep_aval.dtype)}.")
|
||||
if x.ndim < rep_aval.ndim:
|
||||
raise ValueError(
|
||||
"can only convert to extended dtype from an array of its "
|
||||
f"representation type, but the extended dtype {dtype_to_string(edtype)}"
|
||||
f" has a representation shape {rep_aval.shape} (rank {rep_aval.ndim}) "
|
||||
f"while the given representation array has shape {x.shape} (rank "
|
||||
f"{x.ndim} < {rep_aval.ndim}).")
|
||||
n = x.ndim - rep_aval.ndim
|
||||
shape_prefix, shape_suffix = x.shape[:n], x.shape[n:]
|
||||
if shape_suffix != rep_aval.shape:
|
||||
raise ValueError(
|
||||
"can only convert to extended dtype from an array of its "
|
||||
f"representation type, but the extended dtype {dtype_to_string(edtype)}"
|
||||
f" has a representation shape {rep_aval.shape} while the given "
|
||||
f"representation array has shape {x.shape}, so the shape suffix "
|
||||
f"does not match: given {shape_suffix} but required {rep_aval.shape}.")
|
||||
return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype)
|
||||
|
||||
to_edtype_p = Primitive('to_edtype')
|
||||
to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p))
|
||||
to_edtype_p.def_abstract_eval(_to_edtype_abstract_eval)
|
||||
ad.defjvp(to_edtype_p,
|
||||
lambda t, x, edtype:
|
||||
convert_element_type(t, core.primal_dtype_to_tangent_dtype(edtype)))
|
||||
ad.primitive_transposes[to_edtype_p] = \
|
||||
lambda ct, x, edtype: [from_edtype_p.bind(ct, dtype=x.aval.dtype)] # type: ignore
|
||||
batching.defvectorized(to_edtype_p)
|
||||
mlir.register_lowering(to_edtype_p, lambda _, x, **__: [x])
|
||||
|
||||
|
||||
def _from_edtype_abstract_eval(x, *, dtype):
|
||||
assert (isinstance(x.dtype, dtypes.ExtendedDType) and
|
||||
not isinstance(dtype, dtypes.ExtendedDType))
|
||||
if convert_from := getattr(x.dtype._rules, 'convert_from', None):
|
||||
allow_conversion = convert_from(x.dtype, dtype)
|
||||
else:
|
||||
allow_conversion = x.dtype._rules.allow_conversion
|
||||
if not allow_conversion:
|
||||
raise ValueError(
|
||||
f"Cannot convert_element_type from {dtype_to_string(x.dtype)} "
|
||||
f"to {dtype_to_string(dtype)}")
|
||||
rep_aval = core.physical_element_aval(x.dtype)
|
||||
if rep_aval.dtype != dtype:
|
||||
raise ValueError(
|
||||
"can only convert from extended dtype to its representation dtype, "
|
||||
f"but tried to convert from {dtype_to_string(x.dtype)} to "
|
||||
f"{dtype_to_string(dtype)} which doesn't match the representation type "
|
||||
f"{dtype_to_string(rep_aval.dtype)}.")
|
||||
if all(isinstance(d, int) for d in x.shape):
|
||||
return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
from_edtype_p = Primitive('from_edtype')
|
||||
from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p))
|
||||
from_edtype_p.def_abstract_eval(_from_edtype_abstract_eval)
|
||||
ad.defjvp(from_edtype_p,
|
||||
lambda t, x, dtype:
|
||||
convert_element_type(t, core.primal_dtype_to_tangent_dtype(dtype)))
|
||||
ad.primitive_transposes[from_edtype_p] = \
|
||||
lambda ct, x, dtype: [to_edtype_p.bind(ct, edtype=x.dtype)]
|
||||
batching.defvectorized(from_edtype_p)
|
||||
mlir.register_lowering(from_edtype_p, lambda _, x, **__: [x])
|
||||
|
||||
|
||||
def _bitcast_convert_type_shape_rule(operand, *, new_dtype):
|
||||
old_dtype = dtypes.canonicalize_dtype(operand.dtype)
|
||||
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
||||
@ -5343,6 +5433,8 @@ batching.defvectorized(tie_p)
|
||||
|
||||
|
||||
class BIntRules:
|
||||
allow_conversion: bool = True
|
||||
|
||||
@staticmethod
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
return core.ShapedArray((), np.dtype('int32'))
|
||||
@ -5369,14 +5461,6 @@ class BIntRules:
|
||||
return core.DArray(aval, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
def convert_from(bint_dtype, other_dtype) -> bool:
|
||||
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
|
||||
|
||||
@staticmethod
|
||||
def convert_to(other_dtype, bint_dtype) -> bool:
|
||||
return other_dtype in (np.dtype('int32'), np.dtype('int64'))
|
||||
|
||||
|
||||
core.bint._rules = BIntRules
|
||||
|
||||
|
@ -149,11 +149,8 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray):
|
||||
raise NotImplementedError
|
||||
|
||||
def str_short(self, short_dtypes=False):
|
||||
dt_str = (
|
||||
jax_core._short_dtype_name(self.dtype)
|
||||
if short_dtypes
|
||||
else self.dtype.name
|
||||
)
|
||||
dt_str = \
|
||||
dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
|
||||
dt_str = dt_str.replace("void", "float0")
|
||||
shapestr = ",".join(map(str, self.shape))
|
||||
if hasattr(self, "sharding"):
|
||||
|
@ -321,6 +321,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape):
|
||||
|
||||
|
||||
class KeyTyRules:
|
||||
allow_conversion: bool = False
|
||||
|
||||
@staticmethod
|
||||
def full(shape, fill_value, dtype):
|
||||
@ -425,14 +426,6 @@ class KeyTyRules:
|
||||
def zero(_):
|
||||
return np.zeros((), dtypes.float0)
|
||||
|
||||
@staticmethod
|
||||
def convert_from(key_dtype, other_dtype) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def convert_to(other_dtype, key_dtype) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class KeyTy(dtypes.ExtendedDType):
|
||||
_impl: PRNGImpl # TODO(mattjj,frostig): protocol really
|
||||
|
@ -1549,6 +1549,8 @@ tf_not_yet_impl = [
|
||||
"ragged_dot",
|
||||
"cholesky_update",
|
||||
"symmetric_update",
|
||||
"from_edtype",
|
||||
"to_edtype",
|
||||
# Pallas TPU primitives
|
||||
"bitcast",
|
||||
"repeat",
|
||||
|
@ -19,6 +19,7 @@ import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
import operator
|
||||
import types
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -300,16 +301,6 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(dtypes.issubdtype(t, category),
|
||||
np.issubdtype(np.dtype(t).type, category))
|
||||
|
||||
def testIsSubdtypeExtended(self):
|
||||
self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended))
|
||||
self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic))
|
||||
self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number))
|
||||
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key))
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended))
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic))
|
||||
self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number))
|
||||
|
||||
@parameterized.product(dtype=custom_float_dtypes)
|
||||
def testIsSubdtypeCustomFloats(self, dtype):
|
||||
for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]:
|
||||
@ -408,6 +399,34 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64)
|
||||
self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128)
|
||||
|
||||
def test_check_dtype_non_hashable(self):
|
||||
# regression test for issue with checking non-hashable custom dtype
|
||||
class MyDtype:
|
||||
__hash__ = None
|
||||
dtype = np.dtype('float32')
|
||||
dtypes.check_user_dtype_supported(MyDtype())
|
||||
|
||||
def test_check_dtype_array(self):
|
||||
x = jnp.arange(4)
|
||||
msg = "Passing an array as a dtype argument is deprecated"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
dtypes.check_user_dtype_supported(x)
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(dtypes.check_user_dtype_supported)(x)
|
||||
|
||||
|
||||
class ExtendedDTypeTest(jtu.JaxTestCase):
|
||||
|
||||
def testIsSubdtypeExtended(self):
|
||||
self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended))
|
||||
self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic))
|
||||
self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number))
|
||||
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key))
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended))
|
||||
self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic))
|
||||
self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number))
|
||||
|
||||
def test_custom_tangent_dtype(self):
|
||||
from jax._src import core
|
||||
|
||||
@ -415,6 +434,8 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
pass
|
||||
|
||||
class ScalesTyRules:
|
||||
allow_conversion: bool = True
|
||||
|
||||
@staticmethod
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
return core.ShapedArray((), dtype.float_dtype)
|
||||
@ -435,14 +456,6 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
else dtypes.finfo(dt.float_dtype).min, dt.float_dtype)
|
||||
return jax.lax.convert_element_type(neginf, dt)
|
||||
|
||||
@staticmethod
|
||||
def convert_from(dtype, other_dtype) -> bool:
|
||||
return dtype.float_dtype == other_dtype
|
||||
|
||||
@staticmethod
|
||||
def convert_to(other_dtype, dtype) -> bool:
|
||||
return dtype.float_dtype == other_dtype
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ScaleTy(dtypes.ExtendedDType):
|
||||
float_dtype: dtypes.DType
|
||||
@ -485,19 +498,13 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
from jax._src import core
|
||||
|
||||
class ScalesTyRules:
|
||||
# tell JAX how to lower this dtype to an HLO dtype
|
||||
# tell JAX how to lower this dtype to an HLO representation dtype
|
||||
@staticmethod
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
return core.ShapedArray((), dtype.float_dtype)
|
||||
|
||||
# allow conversions to and from the corresponding float type
|
||||
@staticmethod
|
||||
def convert_from(scale_dtype, other_dtype) -> bool:
|
||||
return scale_dtype.float_dtype == other_dtype
|
||||
|
||||
@staticmethod
|
||||
def convert_to(other_dtype, scale_dtype) -> bool:
|
||||
return scale_dtype.float_dtype == other_dtype
|
||||
# allow conversions to and from the corresponding representation type
|
||||
allow_conversion: bool = True
|
||||
|
||||
# define how autodiff should accumulate these values
|
||||
@staticmethod
|
||||
@ -563,21 +570,6 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
_, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale)
|
||||
self.assertAllClose(new_scale, jnp.float32(1.0))
|
||||
|
||||
def test_check_dtype_non_hashable(self):
|
||||
# regression test for issue with checking non-hashable custom dtype
|
||||
class MyDtype:
|
||||
__hash__ = None
|
||||
dtype = np.dtype('float32')
|
||||
dtypes.check_user_dtype_supported(MyDtype())
|
||||
|
||||
def test_check_dtype_array(self):
|
||||
x = jnp.arange(4)
|
||||
msg = "Passing an array as a dtype argument is deprecated"
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
dtypes.check_user_dtype_supported(x)
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(dtypes.check_user_dtype_supported)(x)
|
||||
|
||||
@parameterized.parameters([True]) # TODO(mattjj): make jit=False work
|
||||
def test_primal_tangent_dtype(self, jit):
|
||||
dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16)
|
||||
@ -605,6 +597,123 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(result.dtype, jnp.bfloat16)
|
||||
self.assertEqual(bwd_result.dtype, jnp.bfloat16)
|
||||
self.assertAllClose(bwd_result, 2 * g)
|
||||
self.assertEqual(repr(dt), 'PrimalTangentDType{i8/bf16}')
|
||||
|
||||
@parameterized.parameters(itertools.product([(), (2,), (3, 4)], repeat=2))
|
||||
def test_edtype_conversion(self, shape_prefix, shape_suffix):
|
||||
class scalar(dtypes.extended): ...
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DType(dtypes.ExtendedDType):
|
||||
name = 'dt'
|
||||
type = scalar
|
||||
_rules = types.SimpleNamespace(
|
||||
physical_element_aval=
|
||||
lambda _: types.SimpleNamespace(shape=shape_suffix, dtype='int32'),
|
||||
allow_conversion=True)
|
||||
dtype = DType()
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.shape, shape_prefix + shape_suffix)
|
||||
self.assertEqual(x.dtype, jnp.dtype('int32'))
|
||||
x = jax.lax.convert_element_type(x, dtype)
|
||||
self.assertEqual(x.shape, shape_prefix)
|
||||
self.assertEqual(x.dtype, dtype)
|
||||
x = jax.lax.convert_element_type(x, 'int32')
|
||||
self.assertEqual(x.shape, shape_prefix + shape_suffix)
|
||||
self.assertEqual(x.dtype, jnp.dtype('int32'))
|
||||
f(jnp.zeros(shape_prefix + shape_suffix, dtype='int32'))
|
||||
|
||||
def test_edtype_conversion_errors(self):
|
||||
class scalar(dtypes.extended): ...
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DType(dtypes.ExtendedDType):
|
||||
name = 'dt'
|
||||
type = scalar
|
||||
_rules = types.SimpleNamespace(
|
||||
physical_element_aval=
|
||||
lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'),
|
||||
allow_conversion=True)
|
||||
dtype = DType()
|
||||
|
||||
class scalar2(dtypes.extended): ...
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DType2(dtypes.ExtendedDType):
|
||||
name = 'dt2'
|
||||
type = scalar2
|
||||
_rules = types.SimpleNamespace(
|
||||
physical_element_aval=
|
||||
lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'),
|
||||
allow_conversion=True)
|
||||
dtype2 = DType2()
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = jax.lax.convert_element_type(x, dtype)
|
||||
with self.assertRaisesRegex(ValueError, "cannot directly"):
|
||||
jax.lax.convert_element_type(y, dtype2)
|
||||
with self.assertRaisesRegex(ValueError, "can only convert"):
|
||||
jax.lax.convert_element_type(x.astype('float32'), dtype)
|
||||
with self.assertRaisesRegex(ValueError, "can only convert"):
|
||||
jax.lax.convert_element_type(x[:, :2], dtype)
|
||||
with self.assertRaisesRegex(ValueError, "can only convert"):
|
||||
jax.lax.convert_element_type(x[:, 0], dtype)
|
||||
with self.assertRaisesRegex(ValueError, "can only convert"):
|
||||
jax.lax.convert_element_type(y, 'float32')
|
||||
f(jnp.zeros((5, 3), dtype='int32'))
|
||||
|
||||
def test_edtype_conversion_autodiff(self):
|
||||
|
||||
class scalar(dtypes.extended): ...
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DType(dtypes.ExtendedDType):
|
||||
name = 'dt'
|
||||
type = scalar
|
||||
_rules = types.SimpleNamespace(
|
||||
physical_element_aval=
|
||||
lambda _: types.SimpleNamespace(shape=(), dtype='float32'),
|
||||
tangent_dtype=lambda dtype: jnp.dtype('bfloat16'),
|
||||
allow_conversion=True)
|
||||
dtype = DType()
|
||||
|
||||
@jax.jit
|
||||
@jax.grad
|
||||
def f(x):
|
||||
x = jax.lax.convert_element_type(x, dtype)
|
||||
|
||||
@jax.custom_jvp
|
||||
def g(x): return x
|
||||
@g.defjvp
|
||||
def g_jvp(primals, tangents):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
self.assertEqual(x.shape, (5,))
|
||||
self.assertEqual(x.dtype, dtype)
|
||||
self.assertEqual(x_dot.shape, (5,))
|
||||
self.assertEqual(x_dot.dtype, jnp.dtype('bfloat16'))
|
||||
return x, x_dot
|
||||
x = g(x)
|
||||
|
||||
x = jax.lax.convert_element_type(x, 'float32')
|
||||
|
||||
@jax.custom_jvp
|
||||
def h(x): return x
|
||||
@h.defjvp
|
||||
def h_jvp(primals, tangents):
|
||||
(x,), (x_dot,) = primals, tangents
|
||||
self.assertEqual(x.shape, (5,))
|
||||
self.assertEqual(x.dtype, jnp.dtype('float32'))
|
||||
self.assertEqual(x_dot.shape, (5,))
|
||||
self.assertEqual(x_dot.dtype, jnp.dtype('float32'))
|
||||
return x, x_dot
|
||||
x = h(x)
|
||||
|
||||
return 0.
|
||||
|
||||
f(jnp.zeros(5, dtype='float32')) # test assertions in the function
|
||||
|
||||
|
||||
class EArrayTest(jtu.JaxTestCase):
|
||||
@ -618,10 +727,7 @@ class EArrayTest(jtu.JaxTestCase):
|
||||
class foo(dtypes.extended): pass
|
||||
|
||||
class FooTyRules:
|
||||
|
||||
@staticmethod
|
||||
def convert_to(foo_dtype, target_dtype):
|
||||
return True
|
||||
allow_conversion: bool = True
|
||||
|
||||
@staticmethod
|
||||
def physical_element_aval(foo_dtype):
|
||||
|
@ -1486,6 +1486,9 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
|
||||
jax_traceback_filtering='off')
|
||||
class JumbleTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if jax.config.x64_enabled: raise unittest.SkipTest()
|
||||
|
||||
@parameterized.parameters((True,), (False,))
|
||||
def test_internal_jumble(self, disable_jit):
|
||||
with jax.disable_jit(disable_jit):
|
||||
|
Loading…
x
Reference in New Issue
Block a user