mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make jit work with custom float inputs
This commit is contained in:
parent
640488883e
commit
31c5044c1d
@ -254,6 +254,8 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
|
||||
x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
|
||||
elif x.dtype == dtypes.bfloat16:
|
||||
x = x.view(np.uint16)
|
||||
elif x.dtype in [dtypes.float8_e4m3b11fnuz, dtypes.float8_e4m3fn, dtypes.float8_e5m2]:
|
||||
x = x.view(np.uint8)
|
||||
x = np.ascontiguousarray(x)
|
||||
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
|
||||
return (hlo.ConstantOp(attr).result,)
|
||||
|
@ -1136,6 +1136,12 @@ class _LazyDtypes:
|
||||
supported = supported_dtypes()
|
||||
return type(dtypes)(d for d in dtypes if d in supported)
|
||||
|
||||
@_cached_property
|
||||
def custom_floats(self):
|
||||
return [np.dtype(t) for t in [
|
||||
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
|
||||
_dtypes.float8_e4m3fn, _dtypes.float8_e5m2]]
|
||||
|
||||
@_cached_property
|
||||
def floating(self):
|
||||
return self.supported([np.float32, np.float64])
|
||||
|
@ -4264,6 +4264,14 @@ class APITest(jtu.JaxTestCase):
|
||||
tracing_add_count += 1
|
||||
self.assertEqual(tracing_add_count, 2)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{"testcase_name": f"{dtype}", "dtype": dtype}
|
||||
for dtype in jtu.dtypes.custom_floats])
|
||||
def test_jit_custom_floats(self, dtype):
|
||||
f = lambda x: x + 1
|
||||
args_maker = lambda: [jnp.ones((), dtype=dtype)]
|
||||
self._CompileAndCheck(f, args_maker)
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user