pjit allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test(). This is because the inner backend will take priority.

PiperOrigin-RevId: 502721834
This commit is contained in:
Yash Katariya 2023-01-17 16:38:53 -08:00 committed by jax authors
parent 8da6c89c7b
commit 53fceab17c

View File

@ -87,6 +87,10 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
if outer is None and inner == jtu.device_under_test():
raise SkipTest("(None, device) is allowed")
if jax.config.jax_jit_pjit_api_merge and outer is None:
raise SkipTest("The inner device will dictate the device assignment for "
"the entire computation. So if inner is CPU and outer is "
"None, then the computation will be execute on CPU.")
@partial(jax.jit, backend=outer)
def fun(x, y):
@ -107,7 +111,7 @@ class MultiBackendTest(jtu.JaxTestCase):
raise SkipTest("Backend is not CPU or the device under test")
@partial(jax.jit, backend=backend)
def fun(x, y):
return jnp.matmul(x, y)
return jnp.matmul(x, y)
x = npr.uniform(size=(10,10))
y = npr.uniform(size=(10,10))
z = fun(x, y)