Update pjit_test to skip GDA tests with Array is enabled.

PiperOrigin-RevId: 475684445
This commit is contained in:
Yash Katariya 2022-09-20 16:38:09 -07:00 committed by jax authors
parent 310bcd57a2
commit 6183727acc

View File

@ -1090,9 +1090,13 @@ class PJitTest(jtu.BufferDonationTestCase):
pjit_f(jnp.array([1, 2, 3]))
@jax_array(False)
class GDAPjitTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
def test_pjit_gda_single_output(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)