Split aval_to_result_handler into local_aval_to_result_handler and global_aval_to_result_handler because aval_to_result_handler was shared by pmap, pjit and xmap. This led to parameters having optional values (None) and hacks were introduced to get around that.

For example, before this change, if the `config.jax_gsda_out` flag was true and you executed a computation via pmap, it would try to create a GDA (which is not supported there) and error out. So hacks were introduced to prevent that.

Why would you enable the gsda output flag for pmap? Because bf-jax enables that flag at the top level, and it becomes inconvenient to turn the flag off (via a context manager) around every API that GDA does not support.

This change removes choosing the handler based on a flag and also removes the handling of `None` values.

Also replace `gsda` -> `gda` in test files.

PiperOrigin-RevId: 414400644
This commit is contained in:
Yash Katariya 2021-12-06 04:01:02 -08:00 committed by jax authors
parent e7c29eeb3b
commit 0e3653f347
5 changed files with 98 additions and 57 deletions

View File

@ -225,5 +225,11 @@ def _gsda_shard_arg(x, devices, indices):
raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
f"mesh: {pjit_mesh},\n GDA mesh: {x._global_mesh}")
return [s.data for s in x.local_shards]
pxla.shard_arg_handlers[GlobalDeviceArray] = _gsda_shard_arg
def _gsda_array_result_handler(global_aval, out_axis_resources, global_mesh):
  """Returns a result handler that wraps output buffers in a GlobalDeviceArray.

  Args:
    global_aval: Abstract value of the output; only its `shape` is read.
    out_axis_resources: Passed through unchanged to `GlobalDeviceArray`.
    global_mesh: Passed through unchanged to `GlobalDeviceArray`.

  Returns:
    A one-argument callable that takes the output buffers and returns a
    `GlobalDeviceArray` built from them.
  """
  def handler(bufs):
    # `global_aval.shape` is read when the handler runs, not when it is built.
    return GlobalDeviceArray(
        global_aval.shape, global_mesh, out_axis_resources, bufs)

  return handler
# Register the GDA-producing handler in pxla's global result-handler registry
# for both array aval types.
pxla.global_result_handlers[core.ShapedArray] = _gsda_array_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _gsda_array_result_handler

View File

@ -412,22 +412,49 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
)
def aval_to_result_handler(
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding_spec: Optional[ShardingSpec] = None,
indices: Optional[Tuple[Index]] = None,
out_axis_resources: Optional[AxisResource] = None,
global_mesh = None,
sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index]],
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The local output AbstractValue.
sharding_spec: Indicates how the output is sharded across devices, or None
for non-array avals.
indices: The pre-computed result of spec_to_indices, or None for non-array
avals.
aval: The output AbstractValue. Can be global or local depending on whether
`jax_gsda_out` flag is enabled or not.
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
"""
try:
return local_result_handlers[type(aval)](aval, sharding_spec, indices)
except KeyError as err:
raise TypeError(
"No pxla_result_handler for type: {}".format(type(aval))) from err
# A handler factory: called with handler-specific arguments, it returns a
# callable that maps a list of output buffers to the user-facing result.
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
# Registry of handler factories keyed by aval type; local_aval_to_result_handler
# looks up `type(aval)` here.
local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
# Unit outputs ignore both the factory arguments and the buffers.
local_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
  """Returns a handler that turns output buffers into a sharded device array.

  Args:
    aval: Abstract value of the output.
    sharding_spec: Passed through to `make_sharded_device_array`.
    indices: Passed through to `make_sharded_device_array`.

  Returns:
    A one-argument callable that takes the output buffers and returns the
    result of `make_sharded_device_array(aval, sharding_spec, bufs, indices)`.
  """
  def handler(bufs):
    return make_sharded_device_array(aval, sharding_spec, bufs, indices)

  return handler
# Both array aval types share the sharded-device-array handler factory.
local_result_handlers[ShapedArray] = sda_array_result_handler
local_result_handlers[ConcreteArray] = sda_array_result_handler
def global_aval_to_result_handler(
aval: core.AbstractValue,
out_axis_resources: Optional[AxisResource], global_mesh,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
aval: The global output AbstractValue.
out_axis_resources: A tuple specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
@ -439,35 +466,14 @@ def aval_to_result_handler(
to the user, e.g. a ShardedDeviceArray.
"""
try:
return pxla_result_handlers[type(aval)](sharding_spec, indices, aval,
out_axis_resources, global_mesh)
return global_result_handlers[type(aval)](aval, out_axis_resources,
global_mesh)
except KeyError as err:
raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
) from err
raise TypeError(
"No pxla_result_handler for type: {}".format(type(aval))) from err
# A handler factory: called with handler-specific arguments, it returns a
# callable that maps a list of output buffers to the user-facing result.
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
# Single registry of handler factories keyed by aval type; the lookup is done
# with `type(aval)`.
pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
# Unit outputs ignore both the factory arguments and the buffers.
pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def array_result_handler(sharding_spec, indices, aval: ShapedArray,
                         out_axis_resources, global_mesh):
  """Picks the global (GSDA) or local (SDA) result handler for an array aval.

  The global handler is used only when the `jax_gsda_out` config flag is set
  and a `global_mesh` was supplied; every other combination falls back to the
  sharded-device-array handler.

  Args:
    sharding_spec: Forwarded to `sda_array_result_handler`.
    indices: Forwarded to `sda_array_result_handler`.
    aval: Abstract value of the output, forwarded to whichever handler is used.
    out_axis_resources: Forwarded to `gsda_array_result_handler`.
    global_mesh: Forwarded to `gsda_array_result_handler`; `None` forces the
      local path.

  Returns:
    The handler produced by the selected factory.
  """
  use_global = config.jax_gsda_out and global_mesh is not None
  if not use_global:
    return sda_array_result_handler(sharding_spec, indices, aval)
  return gsda_array_result_handler(aval, global_mesh, out_axis_resources)
# Both array aval types share the flag-dispatching handler factory.
pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler
def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
  """Returns a handler that turns output buffers into a sharded device array.

  Args:
    sharding_spec: Passed through to `make_sharded_device_array`.
    indices: Passed through to `make_sharded_device_array`.
    aval: Abstract value of the output.

  Returns:
    A one-argument callable that takes the output buffers and returns the
    result of `make_sharded_device_array(aval, sharding_spec, bufs, indices)`.
  """
  def handler(bufs):
    return make_sharded_device_array(aval, sharding_spec, bufs, indices)

  return handler
def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
  """Returns a result handler that wraps output buffers in a GlobalDeviceArray.

  Args:
    global_aval: Abstract value of the output; only its `shape` is read.
    global_mesh: Passed through unchanged to `GlobalDeviceArray`.
    out_axis_resources: Passed through unchanged to `GlobalDeviceArray`.

  Returns:
    A one-argument callable that takes the output buffers and returns a
    `GlobalDeviceArray` built from them.
  """
  # NOTE(review): imported at call time rather than module level, presumably
  # to avoid a circular import; confirm before hoisting.
  from jax.experimental.global_device_array import GlobalDeviceArray

  def handler(bufs):
    return GlobalDeviceArray(
        global_aval.shape, global_mesh, out_axis_resources, bufs)

  return handler
# Registry of handler factories keyed by aval type; global_aval_to_result_handler
# looks up `type(aval)` here.
global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
# Unit outputs ignore both the factory arguments and the buffers.
global_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
### lazy device-memory persistence and result handling
@ -1356,7 +1362,7 @@ def local_avals_to_results_handler(
if aval is not core.abstract_unit else None
for aval, spec in safe_zip(unmapped_local_out_avals, local_out_specs)] # pytype: disable=attribute-error
handlers = [
aval_to_result_handler(aval, sharding_spec=spec, indices=idcs)
local_aval_to_result_handler(aval, spec, idcs)
for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
]
return ResultsHandler(handlers, local_out_specs, out_indices, unmapped_local_out_avals)
@ -1374,7 +1380,7 @@ def global_avals_to_results_handler(global_out_avals: Sequence[ShapedArray],
for aval, spec in safe_zip(global_out_avals, global_out_specs)]
out_axis_resources = [array_mapping_to_axis_resources(o) for o in out_axes]
handlers = [
aval_to_result_handler(global_aval, out_axis_resources=out_axis, global_mesh=global_mesh)
global_aval_to_result_handler(global_aval, out_axis, global_mesh)
for global_aval, out_axis in safe_zip(global_out_avals, out_axis_resources)
]
return ResultsHandler(handlers, global_out_specs, out_indices, global_out_avals)

View File

@ -65,7 +65,7 @@ def _aval_to_result_handler(npart, parts, aval):
indices = pxla.spec_to_indices(aval.shape, spec)
else:
spec = indices = None
return pxla.aval_to_result_handler(aval, spec, indices)
return pxla.local_aval_to_result_handler(aval, spec, indices)
@lu.cache

View File

@ -40,7 +40,7 @@ def create_global_mesh(mesh_shape, axis_names):
return global_mesh
class GSDATest(jtu.JaxTestCase):
class GDATest(jtu.JaxTestCase):
@parameterized.named_parameters(
("mesh_x_y", ["x", "y"],
@ -74,7 +74,7 @@ class GSDATest(jtu.JaxTestCase):
(8, 2),
[0, 1, 2, 3, 4, 5, 6, 7]),
)
def test_gsda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
@ -123,7 +123,7 @@ class GSDATest(jtu.JaxTestCase):
(4, 4, 2),
[0, 0, 1, 1, 2, 2, 3, 3]),
)
def test_gsda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
global_input_shape = (8, 4, 2)
@ -157,7 +157,7 @@ class GSDATest(jtu.JaxTestCase):
(16,),
[0, 1, 2, 3, 4, 5, 6, 7]),
)
def test_gsda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids):
global_mesh = create_global_mesh((8,), ('x'))
global_input_shape = (16,)
@ -185,7 +185,7 @@ class GSDATest(jtu.JaxTestCase):
(4, 1),
[0, 0, 0, 0]),
)
def test_gsda_subset_devices(self, mesh_axes, expected_index,
def test_gda_subset_devices(self, mesh_axes, expected_index,
expected_shard_shape, expected_replica_ids):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
@ -211,7 +211,7 @@ class GSDATest(jtu.JaxTestCase):
self.assertEqual(g.replica_id, l.replica_id)
self.assertArraysEqual(g.data, l.data)
def test_gsda_batched_callback(self):
def test_gda_batched_callback(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]
@ -231,7 +231,7 @@ class GSDATest(jtu.JaxTestCase):
self.assertArraysEqual(gda.local_data(1).to_py(),
expected_second_shard_value)
def test_gsda_batched_callback_with_devices(self):
def test_gda_batched_callback_with_devices(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']
@ -257,7 +257,7 @@ class GSDATest(jtu.JaxTestCase):
self.assertArraysEqual(gda.local_data(1).to_py(),
expected_second_shard_value)
def test_gsda_str_repr(self):
def test_gda_str_repr(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = [('x', 'y')]

View File

@ -594,10 +594,10 @@ class PJitTest(jtu.BufferDonationTestCase):
lambda: exe(x_i32, x_i32))
class GSDAPjitTest(jtu.JaxTestCase):
class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_single_output(self):
def test_pjit_gda_single_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
@ -629,7 +629,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
f(input_data)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_multi_input_multi_output(self):
def test_pjit_gda_multi_input_multi_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
@ -638,16 +638,16 @@ class GSDAPjitTest(jtu.JaxTestCase):
return input_data[index]
mesh_axes1 = P('x', 'y')
gsda1 = global_device_array.GlobalDeviceArray.from_callback(
gda1 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
mesh_axes2 = P('x')
gsda2 = global_device_array.GlobalDeviceArray.from_callback(
gda2 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
mesh_axes3 = P(('x', 'y'))
gsda3 = global_device_array.GlobalDeviceArray.from_callback(
gda3 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
mesh_axes4 = P(None)
gsda4 = global_device_array.GlobalDeviceArray.from_callback(
gda4 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes4, cb)
with jax._src.config.gsda_out(True):
@ -658,7 +658,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
def f(x, y, z, a):
return x @ x.T, y, z, a
out1, out2, out3, out4 = f(gsda1, gsda2, gsda3, gsda4)
out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)
self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
self.assertEqual(out1.shape, (8, 8))
@ -702,7 +702,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data, input_data[s.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_mixed_inputs(self):
def test_pjit_gda_mixed_inputs(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
@ -737,8 +737,37 @@ class GSDAPjitTest(jtu.JaxTestCase):
for s in out2.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_non_gda_inputs(self):
  """Checks pjit outputs are GlobalDeviceArrays when given NumPy inputs.

  Runs with the `gsda_out` config enabled, feeds two plain NumPy arrays to a
  pjit-ed function, and verifies type, global shape, first local shard shape,
  mesh shape and per-shard contents of each output.
  """
  input_shape = (8, 2)
  input_data = np.arange(prod(input_shape)).reshape(input_shape)
  expected_matrix_mul = input_data @ input_data.T

  with jax._src.config.gsda_out(True):
    @partial(pjit,
             in_axis_resources=(None, P('x', 'y')),
             out_axis_resources=(P('x', 'y'), P(('x', 'y'))))
    def f(x, y):
      return x @ x.T, y @ y.T

    out1, out2 = f(input_data, input_data)

    # Each output is paired with the expected shape of its first local shard.
    for out, first_shard_shape in ((out1, (2, 4)), (out2, (1, 8))):
      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
      self.assertEqual(out.shape, (8, 8))
      self.assertEqual(out.local_shards[0].data.shape, first_shard_shape)
      self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
      for s in out.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gsda_mesh_mismatch(self):
def test_pjit_gda_mesh_mismatch(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x', 'y']
@ -759,7 +788,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
f(gda_obj)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_wrong_resource_for_gsda_input(self):
def test_pjit_gda_wrong_resource_for_gda_input(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = ['x']