[x64] use explicit casting rules for promote_dtypes_inexact

Jake VanderPlas 2022-05-24 15:51:44 -07:00
parent d849f49519
commit 5f7cd72130
4 changed files with 41 additions and 20 deletions

@@ -55,6 +55,26 @@ _dtype_to_32bit_dtype = {
    np.dtype('complex128'): np.dtype('complex64'),
}

# Note: we promote narrow types to float32 here for backward compatibility
# with earlier approaches. We might consider revisiting this, or perhaps
# tying the logic more closely to the type promotion lattice.
_dtype_to_inexact = {
    np.dtype(k): np.dtype(v) for k, v in [
        ('bool', 'float32'),
        ('uint8', 'float32'), ('int8', 'float32'),
        ('uint16', 'float32'), ('int16', 'float32'),
        ('uint32', 'float32'), ('int32', 'float32'),
        ('uint64', 'float64'), ('int64', 'float64')
    ]
}

def _to_inexact_dtype(dtype):
  """Promotes a dtype into an inexact dtype, if it is not already one."""
  dtype = np.dtype(dtype)
  return _dtype_to_inexact.get(dtype, dtype)

@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled, dtype):
  """Convert from a dtype to a canonical dtype based on config.x64_enabled."""

@@ -270,17 +270,12 @@ def _promote_dtypes_inexact(*args):
  Promotes arguments to an inexact type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_inexact = _to_inexact_dtype(to_dtype)
  to_dtype_inexact = dtypes._to_inexact_dtype(to_dtype)
  weak_type = (weak_type and to_dtype == to_dtype_inexact)
  return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
          for x in args]

def _to_inexact_dtype(dtype):
  """Promotes a dtype into an inexact dtype, if it is not already one."""
  return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)

def _complex_elem_type(dtype):
  """Returns the float type of the real/imaginary parts of a complex dtype."""
  return np.abs(np.zeros((), dtype)).dtype
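
To make the behavioral change concrete: a hedged, standalone sketch contrasting the removed rule with the new one, assuming x64 mode (default float type float64) and using NumPy's promote_types as a stand-in for the JAX promotion lattice; the function names here are illustrative.

import numpy as np

def old_to_inexact_dtype(dtype, default_float=np.dtype('float64')):
  # Removed rule above: promote non-inexact dtypes against the default float
  # type, so int32 ended up as float64 whenever x64 was enabled.
  dtype = np.dtype(dtype)
  return dtype if np.issubdtype(dtype, np.inexact) else np.promote_types(dtype, default_float)

def new_to_inexact_dtype(dtype):
  # New rule: explicit table lookup (abbreviated), independent of the x64 flag.
  table = {np.dtype('int32'): np.dtype('float32'),
           np.dtype('int64'): np.dtype('float64')}
  dtype = np.dtype(dtype)
  return table.get(dtype, dtype)

print(old_to_inexact_dtype('int32'), new_to_inexact_dtype('int32'))  # float64 float32
print(old_to_inexact_dtype('int64'), new_to_inexact_dtype('int64'))  # float64 float64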

@@ -789,7 +789,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
      x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32)
      t = out_dtype if out_dtype != jnp.bfloat16 else np.float32
      return np_op(x_cast, axis, dtype=t, keepdims=keepdims)
    np_fun = _promote_like_jnp(np_fun, inexact)
    jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
@@ -828,7 +828,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
      res = np_op(x_cast, axis, keepdims=keepdims)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res
    np_fun = _promote_like_jnp(np_fun, inexact)
    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    tol = {np.float16: 0.002}
@@ -853,13 +853,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res
    np_fun = _promote_like_jnp(np_fun, inexact)
    np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
@@ -893,13 +893,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
    where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Degrees of freedom <= 0 for slice.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res
    np_fun = _promote_like_jnp(np_fun, inexact)
    np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
@@ -932,15 +932,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="Mean of empty slice.*")
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in true_divide*")
                        message="invalid value encountered.*")
    @jtu.ignore_warning(category=np.ComplexWarning)
    def np_fun(x):
      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
      res = np_op(x_cast, axis, keepdims=keepdims, where=where)
      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
      return res
    np_fun = _promote_like_jnp(np_fun, inexact)
    np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
    args_maker = lambda: [rng(shape, dtype)]
@@ -3744,7 +3743,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
    np_fun = _promote_like_jnp(np_fun, inexact=True)
    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5,
           np.float64: 1e-12, np.complex64: 1e-5}
    check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE
    check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE and numpy_version >= (1, 22)
    try:
      self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
                              check_dtypes=check_dtypes, tol=tol)
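
The scipy test changes below encode the user-visible effect of this commit: with an int32 input, JAX's logsumexp is now expected to return float32 (per the explicit table), while SciPy returns float64, so the reference result is cast before comparison. A hedged sketch of that expectation, assuming x64 mode:

import numpy as np
import jax
from jax.scipy import special as jsp_special
from scipy import special as osp_special

# With x64 enabled; under the old rule this int32 input would have promoted to float64.
jax.config.update('jax_enable_x64', True)

x = np.arange(6, dtype='int32').reshape(2, 3)
jax_res = jsp_special.logsumexp(x, axis=0)
ref_res = osp_special.logsumexp(x, axis=0)
print(jax_res.dtype)                    # expected: float32 after this commit
print(ref_res.dtype)                    # float64 (NumPy/SciPy promotion)
print(ref_res.astype('float32').dtype)  # float32, matching the cast added in the tests below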

@@ -30,6 +30,7 @@ import jax
from jax import numpy as jnp
from jax import lax
from jax import scipy as jsp
from jax.tree_util import tree_map
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.scipy import cluster as lsp_cluster
@@ -181,8 +182,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)
        res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                    return_sign=return_sign, b=scale_array)
        if dtype == np.int32:
          res = tree_map(lambda x: x.astype('float32'), res)
        return res

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
@@ -191,8 +195,11 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)
        res = osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                    return_sign=return_sign)
        if dtype == np.int32:
          res = tree_map(lambda x: x.astype('float32'), res)
        return res

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,