Return PyBuffer directly from the C++ jax.jit.

PiperOrigin-RevId: 346080315
This commit is contained in:
Jean-Baptiste Lespiau 2020-12-07 06:36:02 -08:00 committed by jax authors
parent f969a74be2
commit 57f1bab09b
2 changed files with 23 additions and 1 deletions

View File

@ -308,7 +308,18 @@ def _cpp_jit(
### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
xla_executable, _, result_handlers = execute.args
fastpath_data = (xla_executable, result_handlers, out_pytree_def)
sticky_device = None
avals = []
lazy_exprs = []
for result_handler in result_handlers:
aval, sticky_device, lazy_expr = result_handler.args
avals.append(aval)
lazy_exprs.append(None if xla.lazy.is_trivial(lazy_expr) else lazy_expr)
assert len(avals) == len(out_flat)
if version >= (0, 1, 58):
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
else:
fastpath_data = (xla_executable, result_handlers, out_pytree_def)
else:
fastpath_data = None

View File

@ -462,6 +462,17 @@ class CPPJitTest(jtu.JaxTestCase):
re.escape("static arguments should be comparable using __eq__")):
jitted_f(1, HashableWithoutEq())
def test_cpp_jitted_function_returns_PyBuffer(self):
if version < (0, 1, 58):
raise unittest.SkipTest("Disabled because it depends on some future "
"release of jax_jit.cc within jaxlib.")
if self.jit != jax.api._cpp_jit:
raise unittest.SkipTest("this test only applies to _cpp_jit")
jitted_f = self.jit(lambda a: a + 1)
jitted_f(1)
self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)
class PythonJitTest(CPPJitTest):