jnp.result_type: respect default types

This commit is contained in:
Jake VanderPlas 2021-11-01 11:44:14 -07:00
parent 1cfb54bdec
commit c67e3d86c9
3 changed files with 26 additions and 3 deletions

View File

@ -51,6 +51,7 @@ complex_: type = np.complex128
# uint = np.uint32
# float_ = np.float32
# complex_ = np.complex64
_default_types = {'b': bool_, 'i': int_, 'u': uint, 'f': float_, 'c': complex_}
# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0 = np.dtype([('float0', np.void, 0)])
@ -368,4 +369,7 @@ def result_type(*args):
"""Convenience function to apply JAX argument dtype promotion."""
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
return canonicalize_dtype(_lattice_result_type(*args)[0])
dtype, weak_type = _lattice_result_type(*args)
if weak_type:
dtype = _default_types['f' if dtype == _bfloat16_dtype else dtype.kind]
return canonicalize_dtype(dtype)

View File

@ -342,8 +342,12 @@ class TestPromotionTables(jtu.JaxTestCase):
def testUnaryPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax._convert_element_type(0, dtype, weak_type=weak_type)
y = jnp.array(0, dtype=dtypes.result_type(x))
assert x.dtype == y.dtype
if weak_type:
expected = dtypes.canonicalize_dtype(
dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind])
else:
expected = x.dtype
self.assertEqual(dtypes.result_type(x), expected)
@parameterized.named_parameters(
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),

View File

@ -4016,6 +4016,21 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(np.arange(2.5, dtype=jnp.float_),
jnp.arange(2.5))
def testArangeTypes(self):
# Test that arange() output type is equal to the default types.
int_ = dtypes.canonicalize_dtype(jnp.int_)
float_ = dtypes.canonicalize_dtype(jnp.float_)
self.assertEqual(jnp.arange(10).dtype, int_)
self.assertEqual(jnp.arange(10.).dtype, float_)
self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16)
self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16)
self.assertEqual(jnp.arange(0, 10, 1).dtype, int_)
self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_)
self.assertEqual(jnp.arange(0., 10, 1).dtype, float_)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),