Merge pull request #6068 from jakevdp:fix-result-type

PiperOrigin-RevId: 363282198
This commit is contained in:
jax authors 2021-03-16 15:20:40 -07:00
commit 0a84db59ed
3 changed files with 65 additions and 27 deletions

View File

@ -272,20 +272,18 @@ def _promote_dtypes(*args):
if len(args) < 2:
return args
else:
to_dtype_raw = dtypes._result_type_raw(*args)
weak_type = to_dtype_raw in set(dtypes._weak_types)
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax.convert_element_type(x, to_dtype, weak_type) for x in args]
def _promote_dtypes_inexact(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
to_dtype_raw = dtypes._result_type_raw(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = _to_inexact_dtype(to_dtype)
weak_type = (to_dtype == to_dtype_inexact
and to_dtype_raw in set(dtypes._weak_types))
weak_type = (weak_type and to_dtype == to_dtype_inexact)
return [lax.convert_element_type(x, to_dtype_inexact, weak_type) for x in args]
def _to_inexact_dtype(dtype):

View File

@ -222,17 +222,13 @@ _jax_types = [
np.dtype('complex128'),
] + _weak_types # type: ignore[operator]
def _jax_type(value):
"""Return the jax type for a value or type."""
# Note: `x in _weak_types` can return false positives due to dtype comparator overloading.
if any(value is typ for typ in _weak_types):
return value
dtype_ = dtype(value)
if is_weakly_typed(value):
pytype = type(dtype_.type(0).item())
if pytype in _weak_types:
return pytype
return dtype_
def _jax_type(dtype, weak_type):
"""Return the jax type for a dtype and weak type."""
return type(dtype.type(0).item()) if (weak_type and dtype != bool) else dtype
def _dtype_and_weaktype(value):
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
def _type_promotion_lattice():
"""
@ -264,6 +260,14 @@ _lattice_upper_bounds = _make_lattice_upper_bounds()
@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
"""Compute the least upper bound of a set of nodes.
Args:
nodes: sequence of entries from _jax_types
Returns:
the _jax_type representing the least upper bound of the input nodes
on the promotion lattice.
"""
# This function computes the least upper bound of a set of nodes N within a partially
# ordered set defined by the lattice generated above.
# Given a partially ordered set S, let the set of upper bounds of n ∈ S be
@ -323,13 +327,23 @@ def dtype(x):
return python_scalar_dtypes[type(x)]
return np.result_type(x)
def _result_type_raw(*args):
if len(args) == 1:
return _jax_type(args[0])
return _least_upper_bound(*{_jax_type(arg) for arg in args})
def _lattice_result_type(*args):
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
return dtypes[0], weak_types[0]
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
if all(weak_types):
result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
return dtype(result_type), True
else:
result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
return dtype(result_type), any(result_type is t for t in _weak_types)
def result_type(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
"""Convenience function to apply JAX argument dtype promotion."""
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
return canonicalize_dtype(_result_type_raw(*args))
return canonicalize_dtype(_lattice_result_type(*args)[0])

View File

@ -24,6 +24,7 @@ import numpy as np
import jax
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.interpreters import xla
@ -34,7 +35,7 @@ config.parse_flags_with_absl()
bool_dtypes = [np.dtype('bool')]
signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
np.dtype('int64'), np.dtype('longlong'), np.dtype('intc')]
np.dtype('int64')]
unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
np.dtype('uint64')]
@ -210,7 +211,7 @@ class TestPromotionTables(jtu.JaxTestCase):
"jaxtype": jaxtype}
for jaxtype in dtypes._jax_types)
def testJaxTypeFromType(self, jaxtype):
self.assertIs(dtypes._jax_type(jaxtype), jaxtype)
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype)
@parameterized.named_parameters(
{"testcase_name": "_jaxtype={}".format(jaxtype),
@ -221,7 +222,7 @@ class TestPromotionTables(jtu.JaxTestCase):
val = jaxtype(0)
except TypeError:
val = jaxtype.type(0)
self.assertIs(dtypes._jax_type(val), jaxtype)
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(val)), jaxtype)
@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype.*")
@ -327,5 +328,30 @@ class TestPromotionTables(jtu.JaxTestCase):
args_maker = lambda: [xtype(1), ytype(1)]
self._CompileAndCheck(f, args_maker, check_dtypes=True)
@parameterized.named_parameters(
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
"dtype": dtype, "weak_type": weak_type}
for dtype in all_dtypes
for weak_type in [True, False]
)
def testUnaryPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
y = jnp.array(0, dtype=dtypes.result_type(x))
assert x.dtype == y.dtype
@parameterized.named_parameters(
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
"dtype": dtype, "weak_type": weak_type}
for dtype in all_dtypes
for weak_type in [True, False]
)
def testBinaryNonPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
y = (x + x)
assert x.dtype == y.dtype
assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())