mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Give GDA the ability to fetch the value to host if all devices are addressable and jax_array flag is enabled (just like SDA).
PiperOrigin-RevId: 449790827
This commit is contained in:
parent
478a95ab74
commit
1313a8456d
@ -408,6 +408,22 @@ class GlobalDeviceArray:
|
||||
db.block_until_ready()
|
||||
return self
|
||||
|
||||
def _value(self):
|
||||
if not config.jax_array:
|
||||
raise NotImplementedError('Please set `jax_array` config option to True '
|
||||
'to use this feature.')
|
||||
if self.mesh.is_multi_process:
|
||||
raise RuntimeError("Fetching value for GDA that spans non-addressable "
|
||||
"devices is not possible. You can use "
|
||||
"`jax.experimental.multihost_utils.process_allgather` "
|
||||
"for this use case.")
|
||||
unique_shards = [s.data.copy_to_host_async() or s
|
||||
for s in self.local_shards if s.replica_id == 0]
|
||||
npy_value = np.empty(self.shape, self.dtype)
|
||||
for s in unique_shards:
|
||||
npy_value[s.index] = s.data.to_py()
|
||||
return npy_value
|
||||
|
||||
@classmethod
|
||||
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, data_callback: Callable[[Index],
|
||||
|
@ -32,6 +32,14 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
|
||||
if global_data is None:
|
||||
global_data = np.arange(prod(global_shape)).reshape(global_shape)
|
||||
|
||||
return GlobalDeviceArray.from_callback(
|
||||
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data
|
||||
|
||||
|
||||
class GDATest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -354,7 +362,28 @@ class GDATest(jtu.JaxTestCase):
|
||||
gda = GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
self.assertTrue(gda.block_until_ready() is gda)
|
||||
self.assertIs(gda.block_until_ready(), gda)
|
||||
|
||||
|
||||
class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
jax._src.config.config.update('jax_array', True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y")),
|
||||
("mesh_x", P("x")),
|
||||
("mesh_y", P("y")),
|
||||
("mesh_none_y", P(None, "y")),
|
||||
("mesh_xy", P(("x", "y"))),
|
||||
("mesh_fully_replicated", P()),
|
||||
)
|
||||
def test_jax_array_gda_value(self, mesh_axes):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
input_shape = (8, 2)
|
||||
gda, global_data = create_gda(input_shape, global_mesh, mesh_axes)
|
||||
self.assertArraysEqual(gda._value(), global_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user