mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update pjit_test to skip GDA tests with Array is enabled.
PiperOrigin-RevId: 475684445
This commit is contained in:
parent
310bcd57a2
commit
6183727acc
@ -1090,9 +1090,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
pjit_f(jnp.array([1, 2, 3]))
|
||||
|
||||
|
||||
@jax_array(False)
|
||||
class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
|
||||
def test_pjit_gda_single_output(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user