mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix mock_gpu_test on OSS build.
PiperOrigin-RevId: 570436380
This commit is contained in:
parent
c3e73c67aa
commit
ee8af09851
@ -28,6 +28,7 @@ import numpy as np
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.run_on_devices('gpu')
|
||||
class MockGPUTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -44,10 +45,7 @@ class MockGPUTest(jtu.JaxTestCase):
|
||||
return
|
||||
num_shards = 16
|
||||
jax.config.update('mock_num_gpus', num_shards)
|
||||
mesh_shape = (num_shards,)
|
||||
axis_names = ('x',)
|
||||
mesh_devices = np.array(jax.devices()).reshape(mesh_shape)
|
||||
mesh = jax.sharding.Mesh(mesh_devices, axis_names)
|
||||
mesh = jtu.create_global_mesh((num_shards,), ('x',))
|
||||
@partial(
|
||||
jax.jit,
|
||||
in_shardings=NamedSharding(mesh, P('x',)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user