mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6068 from jakevdp:fix-result-type
PiperOrigin-RevId: 363282198
This commit is contained in:
commit
0a84db59ed
@ -272,20 +272,18 @@ def _promote_dtypes(*args):
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
to_dtype_raw = dtypes._result_type_raw(*args)
|
||||
weak_type = to_dtype_raw in set(dtypes._weak_types)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax.convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
def _promote_dtypes_inexact(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype_raw = dtypes._result_type_raw(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = _to_inexact_dtype(to_dtype)
|
||||
weak_type = (to_dtype == to_dtype_inexact
|
||||
and to_dtype_raw in set(dtypes._weak_types))
|
||||
weak_type = (weak_type and to_dtype == to_dtype_inexact)
|
||||
return [lax.convert_element_type(x, to_dtype_inexact, weak_type) for x in args]
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
|
@ -222,17 +222,13 @@ _jax_types = [
|
||||
np.dtype('complex128'),
|
||||
] + _weak_types # type: ignore[operator]
|
||||
|
||||
def _jax_type(value):
|
||||
"""Return the jax type for a value or type."""
|
||||
# Note: `x in _weak_types` can return false positives due to dtype comparator overloading.
|
||||
if any(value is typ for typ in _weak_types):
|
||||
return value
|
||||
dtype_ = dtype(value)
|
||||
if is_weakly_typed(value):
|
||||
pytype = type(dtype_.type(0).item())
|
||||
if pytype in _weak_types:
|
||||
return pytype
|
||||
return dtype_
|
||||
def _jax_type(dtype, weak_type):
|
||||
"""Return the jax type for a dtype and weak type."""
|
||||
return type(dtype.type(0).item()) if (weak_type and dtype != bool) else dtype
|
||||
|
||||
def _dtype_and_weaktype(value):
|
||||
"""Return a (dtype, weak_type) tuple for the given input."""
|
||||
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
|
||||
|
||||
def _type_promotion_lattice():
|
||||
"""
|
||||
@ -264,6 +260,14 @@ _lattice_upper_bounds = _make_lattice_upper_bounds()
|
||||
|
||||
@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
|
||||
def _least_upper_bound(*nodes):
|
||||
"""Compute the least upper bound of a set of nodes.
|
||||
|
||||
Args:
|
||||
nodes: sequence of entries from _jax_types
|
||||
Returns:
|
||||
the _jax_type representing the least upper bound of the input nodes
|
||||
on the promotion lattice.
|
||||
"""
|
||||
# This function computes the least upper bound of a set of nodes N within a partially
|
||||
# ordered set defined by the lattice generated above.
|
||||
# Given a partially ordered set S, let the set of upper bounds of n ∈ S be
|
||||
@ -323,13 +327,23 @@ def dtype(x):
|
||||
return python_scalar_dtypes[type(x)]
|
||||
return np.result_type(x)
|
||||
|
||||
def _result_type_raw(*args):
|
||||
if len(args) == 1:
|
||||
return _jax_type(args[0])
|
||||
return _least_upper_bound(*{_jax_type(arg) for arg in args})
|
||||
def _lattice_result_type(*args):
|
||||
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
|
||||
if len(dtypes) == 1:
|
||||
return dtypes[0], weak_types[0]
|
||||
|
||||
# If all inputs are weakly typed, we compute the bound of the strongly-typed
|
||||
# counterparts and apply the weak type at the end. This avoids returning the
|
||||
# incorrect result with non-canonical weak types (e.g. weak int16).
|
||||
if all(weak_types):
|
||||
result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
|
||||
return dtype(result_type), True
|
||||
else:
|
||||
result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
|
||||
return dtype(result_type), any(result_type is t for t in _weak_types)
|
||||
|
||||
def result_type(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
"""Convenience function to apply JAX argument dtype promotion."""
|
||||
if len(args) == 0:
|
||||
raise ValueError("at least one array or dtype is required")
|
||||
return canonicalize_dtype(_result_type_raw(*args))
|
||||
return canonicalize_dtype(_lattice_result_type(*args)[0])
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.interpreters import xla
|
||||
@ -34,7 +35,7 @@ config.parse_flags_with_absl()
|
||||
bool_dtypes = [np.dtype('bool')]
|
||||
|
||||
signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
|
||||
np.dtype('int64'), np.dtype('longlong'), np.dtype('intc')]
|
||||
np.dtype('int64')]
|
||||
|
||||
unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
|
||||
np.dtype('uint64')]
|
||||
@ -210,7 +211,7 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
"jaxtype": jaxtype}
|
||||
for jaxtype in dtypes._jax_types)
|
||||
def testJaxTypeFromType(self, jaxtype):
|
||||
self.assertIs(dtypes._jax_type(jaxtype), jaxtype)
|
||||
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_jaxtype={}".format(jaxtype),
|
||||
@ -221,7 +222,7 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
val = jaxtype(0)
|
||||
except TypeError:
|
||||
val = jaxtype.type(0)
|
||||
self.assertIs(dtypes._jax_type(val), jaxtype)
|
||||
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(val)), jaxtype)
|
||||
|
||||
@jtu.ignore_warning(category=UserWarning,
|
||||
message="Explicitly requested dtype.*")
|
||||
@ -327,5 +328,30 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
args_maker = lambda: [xtype(1), ytype(1)]
|
||||
self._CompileAndCheck(f, args_maker, check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
|
||||
"dtype": dtype, "weak_type": weak_type}
|
||||
for dtype in all_dtypes
|
||||
for weak_type in [True, False]
|
||||
)
|
||||
def testUnaryPromotion(self, dtype, weak_type):
|
||||
# Regression test for https://github.com/google/jax/issues/6051
|
||||
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
|
||||
y = jnp.array(0, dtype=dtypes.result_type(x))
|
||||
assert x.dtype == y.dtype
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
|
||||
"dtype": dtype, "weak_type": weak_type}
|
||||
for dtype in all_dtypes
|
||||
for weak_type in [True, False]
|
||||
)
|
||||
def testBinaryNonPromotion(self, dtype, weak_type):
|
||||
# Regression test for https://github.com/google/jax/issues/6051
|
||||
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
|
||||
y = (x + x)
|
||||
assert x.dtype == y.dtype
|
||||
assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user