From ee8af0985133564982ff74151d56784e7c5e1be1 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Tue, 3 Oct 2023 10:54:40 -0700 Subject: [PATCH] Fix mock_gpu_test on OSS build. PiperOrigin-RevId: 570436380 --- tests/mock_gpu_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/mock_gpu_test.py b/tests/mock_gpu_test.py index dced8882c..24a19a2b1 100644 --- a/tests/mock_gpu_test.py +++ b/tests/mock_gpu_test.py @@ -28,6 +28,7 @@ import numpy as np config.parse_flags_with_absl() +@jtu.run_on_devices('gpu') class MockGPUTest(jtu.JaxTestCase): def setUp(self): @@ -44,10 +45,7 @@ class MockGPUTest(jtu.JaxTestCase): return num_shards = 16 jax.config.update('mock_num_gpus', num_shards) - mesh_shape = (num_shards,) - axis_names = ('x',) - mesh_devices = np.array(jax.devices()).reshape(mesh_shape) - mesh = jax.sharding.Mesh(mesh_devices, axis_names) + mesh = jtu.create_global_mesh((num_shards,), ('x',)) @partial( jax.jit, in_shardings=NamedSharding(mesh, P('x',)),