mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[x64] use explicit casting rules for promote_dtypes_inexact
This commit is contained in:
parent
d849f49519
commit
5f7cd72130
@ -55,6 +55,26 @@ _dtype_to_32bit_dtype = {
|
||||
np.dtype('complex128'): np.dtype('complex64'),
|
||||
}
|
||||
|
||||
# Note: we promote narrow types to float32 here for backward compatibility
|
||||
# with earlier approaches. We might consider revisiting this, or perhaps
|
||||
# tying the logic more closely to the type promotion lattice.
|
||||
_dtype_to_inexact = {
|
||||
np.dtype(k): np.dtype(v) for k, v in [
|
||||
('bool', 'float32'),
|
||||
('uint8', 'float32'), ('int8', 'float32'),
|
||||
('uint16', 'float32'), ('int16', 'float32'),
|
||||
('uint32', 'float32'), ('int32', 'float32'),
|
||||
('uint64', 'float64'), ('int64', 'float64')
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
dtype = np.dtype(dtype)
|
||||
return _dtype_to_inexact.get(dtype, dtype)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _canonicalize_dtype(x64_enabled, dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
|
@ -270,17 +270,12 @@ def _promote_dtypes_inexact(*args):
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = _to_inexact_dtype(to_dtype)
|
||||
to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype)
|
||||
weak_type = (weak_type and to_dtype == to_dtype_inexact)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
|
||||
|
||||
|
||||
def _complex_elem_type(dtype):
|
||||
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
||||
return np.abs(np.zeros((), dtype)).dtype
|
||||
|
@ -789,7 +789,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
|
||||
t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
|
||||
return np_op(x_cast, axis, dtype=t, keepdims=keepdims)
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
|
||||
jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -828,7 +828,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
res = np_op(x_cast, axis, keepdims=keepdims)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
return res
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
tol = {np.float16: 0.002}
|
||||
@ -853,13 +853,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
def np_fun(x):
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
return res
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -893,13 +893,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Degrees of freedom <= 0 for slice.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
def np_fun(x):
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
return res
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -932,15 +932,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="Mean of empty slice.*")
|
||||
@jtu.ignore_warning(category=RuntimeWarning,
|
||||
message="invalid value encountered in true_divide*")
|
||||
message="invalid value encountered.*")
|
||||
@jtu.ignore_warning(category=np.ComplexWarning)
|
||||
def np_fun(x):
|
||||
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
|
||||
res = np_op(x_cast, axis, keepdims=keepdims, where=where)
|
||||
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
|
||||
return res
|
||||
|
||||
np_fun = _promote_like_jnp(np_fun, inexact)
|
||||
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
|
||||
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
|
||||
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
@ -3744,7 +3743,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
np_fun = _promote_like_jnp(np_fun, inexact=True)
|
||||
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
|
||||
np.float64: 1e-12, np.complex64: 1e-5}
|
||||
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
|
||||
check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22)
|
||||
try:
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
|
||||
check_dtypes=check_dtypes, tol=tol)
|
||||
|
@ -30,6 +30,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
from jax import scipy as jsp
|
||||
from jax.tree_util import tree_map
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import cluster as lsp_cluster
|
||||
@ -181,8 +182,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
# TODO(mattjj): test autodiff
|
||||
if use_b:
|
||||
def scipy_fun(array_to_reduce, scale_array):
|
||||
return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
return_sign=return_sign, b=scale_array)
|
||||
res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
return_sign=return_sign, b=scale_array)
|
||||
if dtype == np.int32:
|
||||
res = tree_map(lambda x: x.astype('float32'), res)
|
||||
return res
|
||||
|
||||
def lax_fun(array_to_reduce, scale_array):
|
||||
return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
@ -191,8 +195,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
|
||||
else:
|
||||
def scipy_fun(array_to_reduce):
|
||||
return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
return_sign=return_sign)
|
||||
res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
return_sign=return_sign)
|
||||
if dtype == np.int32:
|
||||
res = tree_map(lambda x: x.astype('float32'), res)
|
||||
return res
|
||||
|
||||
def lax_fun(array_to_reduce):
|
||||
return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
|
||||
|
Loading…
x
Reference in New Issue
Block a user