mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Fix lint
This commit is contained in:
parent
c099e8081d
commit
ccbe9f7cd6
@ -730,6 +730,12 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
|
||||
"promotion path. To avoid unintended promotion, 8-bit floats do not support "
|
||||
"implicit promotion. If you'd like your inputs to be promoted to another type, "
|
||||
"you can do so explicitly using e.g. x.astype('float32')")
|
||||
elif any(n in _float4_dtypes for n in nodes):
|
||||
msg = (
|
||||
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
|
||||
"promotion path. To avoid unintended promotion, 4-bit floats do not support "
|
||||
"implicit promotion. If you'd like your inputs to be promoted to another type, "
|
||||
"you can do so explicitly using e.g. x.astype('float32')")
|
||||
elif any(n in _intn_dtypes for n in nodes):
|
||||
msg = (
|
||||
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
|
||||
|
@ -989,7 +989,8 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
def testFloat4PromotionError(self):
|
||||
for dtype in fp4_dtypes:
|
||||
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
|
||||
self.skipTest("TPU does not support float4_e2m1fn.")
|
||||
# TPU does not support float4_e2m1fn.
|
||||
continue
|
||||
x = jnp.array(1, dtype=dtype)
|
||||
y = jnp.array(1, dtype='float32')
|
||||
with self.assertRaisesRegex(dtypes.TypePromotionError,
|
||||
@ -1055,7 +1056,7 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
if dtype == dtypes.float8_e8m0fnu and jtu.test_device_matches(['tpu']):
|
||||
self.skipTest('TPU does not support float8_e8m0fnu.')
|
||||
if dtype == dtypes.float4_e2m1fn and jtu.test_device_matches(['tpu']):
|
||||
self.skipTest('TPU does not support float4_e2m1fn.')
|
||||
self.skipTest('TPU does not support float4_e2m1fn.')
|
||||
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
|
||||
rep = repr(val)
|
||||
self.assertStartsWith(rep, 'Array(')
|
||||
|
Loading…
x
Reference in New Issue
Block a user