Fix mock_gpu_test on OSS build.

PiperOrigin-RevId: 570436380
This commit is contained in:
Tao Wang 2023-10-03 10:54:40 -07:00 committed by jax authors
parent c3e73c67aa
commit ee8af09851

View File

@ -28,6 +28,7 @@ import numpy as np
config.parse_flags_with_absl()
@jtu.run_on_devices('gpu')
class MockGPUTest(jtu.JaxTestCase):
def setUp(self):
@ -44,10 +45,7 @@ class MockGPUTest(jtu.JaxTestCase):
return
num_shards = 16
jax.config.update('mock_num_gpus', num_shards)
mesh_shape = (num_shards,)
axis_names = ('x',)
mesh_devices = np.array(jax.devices()).reshape(mesh_shape)
mesh = jax.sharding.Mesh(mesh_devices, axis_names)
mesh = jtu.create_global_mesh((num_shards,), ('x',))
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, P('x',)),