mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #17515 from jakevdp:fix-can-cast
PiperOrigin-RevId: 563793237
This commit is contained in:
commit
bc1ca0b5df
@ -711,4 +711,6 @@ def safe_to_cast(input_dtype_or_value: Any,
|
||||
output_dtype = dtype(output_dtype_or_value, canonicalize=True)
|
||||
if input_dtype == output_dtype:
|
||||
return True
|
||||
return result_type(input_dtype_or_value, output_dtype_or_value) == output_dtype
|
||||
# We deliberately use output_dtype rather than output_dtype_or_value here:
|
||||
# this effectively treats the output dtype as always strongly-typed.
|
||||
return result_type(input_dtype_or_value, output_dtype) == output_dtype
|
||||
|
@ -611,33 +611,33 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
self.assertEndsWith(rep, f"dtype={val.dtype.name})")
|
||||
|
||||
@jtu.sample_product(
|
||||
lhs_dtype=jtu.dtypes.all + [bool, int, float, complex],
|
||||
rhs_dtype=jtu.dtypes.all,
|
||||
input_dtype=jtu.dtypes.all + [bool, int, float, complex],
|
||||
output_dtype=jtu.dtypes.all + [bool, int, float, complex],
|
||||
numpy_dtype_promotion=['strict', 'standard']
|
||||
)
|
||||
def testSafeToCast(self, lhs_dtype, rhs_dtype, numpy_dtype_promotion):
|
||||
def testSafeToCast(self, input_dtype, output_dtype, numpy_dtype_promotion):
|
||||
with jax.numpy_dtype_promotion(numpy_dtype_promotion):
|
||||
# First the special cases which are always safe:
|
||||
always_safe = (
|
||||
(lhs_dtype == rhs_dtype) or
|
||||
(dtypes.issubdtype(rhs_dtype, np.integer) and rhs_dtype in {bool, int}) or
|
||||
(dtypes.issubdtype(rhs_dtype, np.floating) and rhs_dtype in {bool, int, float}) or
|
||||
(dtypes.issubdtype(rhs_dtype, np.complexfloating) and rhs_dtype in {bool, int, float, complex})
|
||||
(input_dtype == output_dtype) or
|
||||
(dtypes.issubdtype(output_dtype, np.integer) and input_dtype in {int}) or
|
||||
(dtypes.issubdtype(output_dtype, np.floating) and input_dtype in {int, float}) or
|
||||
(dtypes.issubdtype(output_dtype, np.complexfloating) and input_dtype in {int, float, complex})
|
||||
)
|
||||
if always_safe:
|
||||
self.assertTrue(dtypes.safe_to_cast(lhs_dtype, rhs_dtype))
|
||||
self.assertTrue(dtypes.safe_to_cast(input_dtype, output_dtype))
|
||||
|
||||
try:
|
||||
result_dtype = dtypes.result_type(lhs_dtype, rhs_dtype)
|
||||
result_dtype = dtypes.result_type(input_dtype, dtypes.canonicalize_dtype(output_dtype))
|
||||
except dtypes.TypePromotionError:
|
||||
result_dtype = None
|
||||
|
||||
if result_dtype is None and lhs_dtype != rhs_dtype:
|
||||
if result_dtype is None and input_dtype != output_dtype:
|
||||
with self.assertRaises(dtypes.TypePromotionError):
|
||||
dtypes.safe_to_cast(lhs_dtype, rhs_dtype)
|
||||
dtypes.safe_to_cast(input_dtype, output_dtype)
|
||||
else:
|
||||
self.assertEqual(dtypes.result_type(rhs_dtype) == result_dtype,
|
||||
dtypes.safe_to_cast(lhs_dtype, rhs_dtype))
|
||||
self.assertEqual(dtypes.result_type(output_dtype) == result_dtype,
|
||||
dtypes.safe_to_cast(input_dtype, output_dtype))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user