mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
pjit
allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test()
. This is because the inner backend will take priority.
PiperOrigin-RevId: 502721834
This commit is contained in:
parent
8da6c89c7b
commit
53fceab17c
@ -87,6 +87,10 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
if outer is None and inner == jtu.device_under_test():
|
||||
raise SkipTest("(None, device) is allowed")
|
||||
if jax.config.jax_jit_pjit_api_merge and outer is None:
|
||||
raise SkipTest("The inner device will dictate the device assignment for "
|
||||
"the entire computation. So if inner is CPU and outer is "
|
||||
"None, then the computation will be execute on CPU.")
|
||||
|
||||
@partial(jax.jit, backend=outer)
|
||||
def fun(x, y):
|
||||
@ -107,7 +111,7 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
@partial(jax.jit, backend=backend)
|
||||
def fun(x, y):
|
||||
return jnp.matmul(x, y)
|
||||
return jnp.matmul(x, y)
|
||||
x = npr.uniform(size=(10,10))
|
||||
y = npr.uniform(size=(10,10))
|
||||
z = fun(x, y)
|
||||
|
Loading…
x
Reference in New Issue
Block a user