Make jit work with custom float inputs

This commit is contained in:
Jake VanderPlas 2023-07-12 13:06:03 -07:00
parent 640488883e
commit 31c5044c1d
3 changed files with 16 additions and 0 deletions

View File

@ -254,6 +254,8 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
elif x.dtype == dtypes.bfloat16:
x = x.view(np.uint16)
elif x.dtype in [dtypes.float8_e4m3b11fnuz, dtypes.float8_e4m3fn, dtypes.float8_e5m2]:
x = x.view(np.uint8)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (hlo.ConstantOp(attr).result,)

View File

@ -1136,6 +1136,12 @@ class _LazyDtypes:
supported = supported_dtypes()
return type(dtypes)(d for d in dtypes if d in supported)
@_cached_property
def custom_floats(self):
return [np.dtype(t) for t in [
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn, _dtypes.float8_e5m2]]
@_cached_property
def floating(self):
return self.supported([np.float32, np.float64])

View File

@ -4264,6 +4264,14 @@ class APITest(jtu.JaxTestCase):
tracing_add_count += 1
self.assertEqual(tracing_add_count, 2)
@parameterized.named_parameters([
{"testcase_name": f"{dtype}", "dtype": dtype}
for dtype in jtu.dtypes.custom_floats])
def test_jit_custom_floats(self, dtype):
f = lambda x: x + 1
args_maker = lambda: [jnp.ones((), dtype=dtype)]
self._CompileAndCheck(f, args_maker)
class RematTest(jtu.JaxTestCase):