Merge pull request #17515 from jakevdp:fix-can-cast

PiperOrigin-RevId: 563793237
This commit is contained in:
jax authors 2023-09-08 10:44:23 -07:00
commit bc1ca0b5df
2 changed files with 16 additions and 14 deletions

View File

@ -711,4 +711,6 @@ def safe_to_cast(input_dtype_or_value: Any,
output_dtype = dtype(output_dtype_or_value, canonicalize=True)
if input_dtype == output_dtype:
return True
return result_type(input_dtype_or_value, output_dtype_or_value) == output_dtype
# We deliberately use output_dtype rather than output_dtype_or_value here:
# this effectively treats the output dtype as always strongly-typed.
return result_type(input_dtype_or_value, output_dtype) == output_dtype

View File

@ -611,33 +611,33 @@ class TestPromotionTables(jtu.JaxTestCase):
self.assertEndsWith(rep, f"dtype={val.dtype.name})")
@jtu.sample_product(
lhs_dtype=jtu.dtypes.all + [bool, int, float, complex],
rhs_dtype=jtu.dtypes.all,
input_dtype=jtu.dtypes.all + [bool, int, float, complex],
output_dtype=jtu.dtypes.all + [bool, int, float, complex],
numpy_dtype_promotion=['strict', 'standard']
)
def testSafeToCast(self, lhs_dtype, rhs_dtype, numpy_dtype_promotion):
def testSafeToCast(self, input_dtype, output_dtype, numpy_dtype_promotion):
with jax.numpy_dtype_promotion(numpy_dtype_promotion):
# First the special cases which are always safe:
always_safe = (
(lhs_dtype == rhs_dtype) or
(dtypes.issubdtype(rhs_dtype, np.integer) and rhs_dtype in {bool, int}) or
(dtypes.issubdtype(rhs_dtype, np.floating) and rhs_dtype in {bool, int, float}) or
(dtypes.issubdtype(rhs_dtype, np.complexfloating) and rhs_dtype in {bool, int, float, complex})
(input_dtype == output_dtype) or
(dtypes.issubdtype(output_dtype, np.integer) and input_dtype in {int}) or
(dtypes.issubdtype(output_dtype, np.floating) and input_dtype in {int, float}) or
(dtypes.issubdtype(output_dtype, np.complexfloating) and input_dtype in {int, float, complex})
)
if always_safe:
self.assertTrue(dtypes.safe_to_cast(lhs_dtype, rhs_dtype))
self.assertTrue(dtypes.safe_to_cast(input_dtype, output_dtype))
try:
result_dtype = dtypes.result_type(lhs_dtype, rhs_dtype)
result_dtype = dtypes.result_type(input_dtype, dtypes.canonicalize_dtype(output_dtype))
except dtypes.TypePromotionError:
result_dtype = None
if result_dtype is None and lhs_dtype != rhs_dtype:
if result_dtype is None and input_dtype != output_dtype:
with self.assertRaises(dtypes.TypePromotionError):
dtypes.safe_to_cast(lhs_dtype, rhs_dtype)
dtypes.safe_to_cast(input_dtype, output_dtype)
else:
self.assertEqual(dtypes.result_type(rhs_dtype) == result_dtype,
dtypes.safe_to_cast(lhs_dtype, rhs_dtype))
self.assertEqual(dtypes.result_type(output_dtype) == result_dtype,
dtypes.safe_to_cast(input_dtype, output_dtype))
if __name__ == "__main__":