Give GDA the ability to fetch the value to host if all devices are addressable and jax_array flag is enabled (just like SDA).

PiperOrigin-RevId: 449790827
This commit is contained in:
Yash Katariya 2022-05-19 11:13:06 -07:00 committed by jax authors
parent 478a95ab74
commit 1313a8456d
2 changed files with 46 additions and 1 deletions

View File

@ -408,6 +408,22 @@ class GlobalDeviceArray:
db.block_until_ready()
return self
def _value(self):
if not config.jax_array:
raise NotImplementedError('Please set `jax_array` config option to True '
'to use this feature.')
if self.mesh.is_multi_process:
raise RuntimeError("Fetching value for GDA that spans non-addressable "
"devices is not possible. You can use "
"`jax.experimental.multihost_utils.process_allgather` "
"for this use case.")
unique_shards = [s.data.copy_to_host_async() or s
for s in self.local_shards if s.replica_id == 0]
npy_value = np.empty(self.shape, self.dtype)
for s in unique_shards:
npy_value[s.index] = s.data.to_py()
return npy_value
@classmethod
def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes, data_callback: Callable[[Index],

View File

@ -32,6 +32,14 @@ from jax.config import config
config.parse_flags_with_absl()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data
class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
@ -354,7 +362,28 @@ class GDATest(jtu.JaxTestCase):
gda = GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
self.assertTrue(gda.block_until_ready() is gda)
self.assertIs(gda.block_until_ready(), gda)
class JaxArrayTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
jax._src.config.config.update('jax_array', True)
@parameterized.named_parameters(
("mesh_x_y", P("x", "y")),
("mesh_x", P("x")),
("mesh_y", P("y")),
("mesh_none_y", P(None, "y")),
("mesh_xy", P(("x", "y"))),
("mesh_fully_replicated", P()),
)
def test_jax_array_gda_value(self, mesh_axes):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
gda, global_data = create_gda(input_shape, global_mesh, mesh_axes)
self.assertArraysEqual(gda._value(), global_data)
if __name__ == '__main__':