mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Return PyBuffer directly from the C++ jax.jit.
PiperOrigin-RevId: 346080315
This commit is contained in:
parent
f969a74be2
commit
57f1bab09b
13
jax/api.py
13
jax/api.py
@ -308,7 +308,18 @@ def _cpp_jit(
|
||||
### If we can use the fastpath, we return required info to the caller.
|
||||
if use_fastpath:
|
||||
xla_executable, _, result_handlers = execute.args
|
||||
fastpath_data = (xla_executable, result_handlers, out_pytree_def)
|
||||
sticky_device = None
|
||||
avals = []
|
||||
lazy_exprs = []
|
||||
for result_handler in result_handlers:
|
||||
aval, sticky_device, lazy_expr = result_handler.args
|
||||
avals.append(aval)
|
||||
lazy_exprs.append(None if xla.lazy.is_trivial(lazy_expr) else lazy_expr)
|
||||
assert len(avals) == len(out_flat)
|
||||
if version >= (0, 1, 58):
|
||||
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
|
||||
else:
|
||||
fastpath_data = (xla_executable, result_handlers, out_pytree_def)
|
||||
else:
|
||||
fastpath_data = None
|
||||
|
||||
|
@ -462,6 +462,17 @@ class CPPJitTest(jtu.JaxTestCase):
|
||||
re.escape("static arguments should be comparable using __eq__")):
|
||||
jitted_f(1, HashableWithoutEq())
|
||||
|
||||
def test_cpp_jitted_function_returns_PyBuffer(self):
|
||||
if version < (0, 1, 58):
|
||||
raise unittest.SkipTest("Disabled because it depends on some future "
|
||||
"release of jax_jit.cc within jaxlib.")
|
||||
if self.jit != jax.api._cpp_jit:
|
||||
raise unittest.SkipTest("this test only applies to _cpp_jit")
|
||||
|
||||
jitted_f = self.jit(lambda a: a + 1)
|
||||
jitted_f(1)
|
||||
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
|
||||
|
||||
|
||||
class PythonJitTest(CPPJitTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user