mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jnp.result_type: respect default types
This commit is contained in:
parent
1cfb54bdec
commit
c67e3d86c9
@ -51,6 +51,7 @@ complex_: type = np.complex128
|
||||
# uint = np.uint32
|
||||
# float_ = np.float32
|
||||
# complex_ = np.complex64
|
||||
_default_types = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0 = np.dtype([('float0', np.void, 0)])
|
||||
@ -368,4 +369,7 @@ def result_type(*args):
|
||||
"""Convenience function to apply JAX argument dtype promotion."""
|
||||
if len(args) == 0:
|
||||
raise ValueError("at least one array or dtype is required")
|
||||
return canonicalize_dtype(_lattice_result_type(*args)[0])
|
||||
dtype, weak_type = _lattice_result_type(*args)
|
||||
if weak_type:
|
||||
dtype = _default_types['f' if dtype == _bfloat16_dtype else dtype.kind]
|
||||
return canonicalize_dtype(dtype)
|
||||
|
@ -342,8 +342,12 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
def testUnaryPromotion(self, dtype, weak_type):
|
||||
# Regression test for https://github.com/google/jax/issues/6051
|
||||
x = lax._convert_element_type(0, dtype, weak_type=weak_type)
|
||||
y = jnp.array(0, dtype=dtypes.result_type(x))
|
||||
assert x.dtype == y.dtype
|
||||
if weak_type:
|
||||
expected = dtypes.canonicalize_dtype(
|
||||
dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind])
|
||||
else:
|
||||
expected = x.dtype
|
||||
self.assertEqual(dtypes.result_type(x), expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
|
||||
|
@ -4016,6 +4016,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.arange(2.5, dtype=jnp.float_),
|
||||
jnp.arange(2.5))
|
||||
|
||||
def testArangeTypes(self):
|
||||
# Test that arange() output type is equal to the default types.
|
||||
int_ = dtypes.canonicalize_dtype(jnp.int_)
|
||||
float_ = dtypes.canonicalize_dtype(jnp.float_)
|
||||
|
||||
self.assertEqual(jnp.arange(10).dtype, int_)
|
||||
self.assertEqual(jnp.arange(10.).dtype, float_)
|
||||
self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16)
|
||||
self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16)
|
||||
|
||||
self.assertEqual(jnp.arange(0, 10, 1).dtype, int_)
|
||||
self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_)
|
||||
self.assertEqual(jnp.arange(0., 10, 1).dtype, float_)
|
||||
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
|
Loading…
x
Reference in New Issue
Block a user