Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Split aval_to_result_handler into local_aval_to_result_handler and global_aval_to_result_handler.

aval_to_result_handler was shared by pmap, pjit, and xmap. This led to parameters having optional (None) values, and hacks were introduced to work around that. For example (before this change): if the `config.jax_gsda_out` flag was true and you executed a computation via pmap, it would try to create a GDA (which pmap does not support) and error out, so hacks were introduced to prevent that. Why would you enable the GDA output flag for pmap at all? Because bf-jax enables the flag at the top level, and it becomes inconvenient to untoggle it (via a context manager) around every API that GDA does not support. This change removes the flag-based choice of handler and also removes the handling of `None`s. Also replace `gsda` with `gda` in test files.

PiperOrigin-RevId: 414400644
This commit is contained in:
parent e7c29eeb3b
commit 0e3653f347
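To make the refactor described in the commit message concrete, below is a minimal, self-contained sketch of the dispatch pattern after the split: two handler registries keyed by aval type, and two entry points that each take only the parameters they need. This is an illustrative toy, not the real pxla code; the ShapedArray class and the registered lambdas are stand-ins for JAX's actual handlers.

# Toy sketch of the post-split dispatch (illustrative names, not JAX's code).
from typing import Any, Callable, Dict, Type


class ShapedArray:  # stand-in for core.ShapedArray
  def __init__(self, shape):
    self.shape = shape


ResultHandler = Callable[..., Callable[[list], Any]]

# One registry per "world": local (pmap-style) outputs and global (GDA-style) outputs.
local_result_handlers: Dict[Type, ResultHandler] = {}
global_result_handlers: Dict[Type, ResultHandler] = {}


def local_aval_to_result_handler(aval, sharding_spec, indices):
  # Only the parameters a local (per-host) output actually needs.
  try:
    return local_result_handlers[type(aval)](aval, sharding_spec, indices)
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err


def global_aval_to_result_handler(aval, out_axis_resources, global_mesh):
  # Only the parameters a global (GDA) output actually needs.
  try:
    return global_result_handlers[type(aval)](aval, out_axis_resources,
                                              global_mesh)
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err


# Registrations mirror the way the real code registers its array result handlers.
local_result_handlers[ShapedArray] = (
    lambda aval, spec, idx: lambda bufs: ("local", aval.shape, len(bufs)))
global_result_handlers[ShapedArray] = (
    lambda aval, res, mesh: lambda bufs: ("global", aval.shape, len(bufs)))

if __name__ == "__main__":
  handler = local_aval_to_result_handler(ShapedArray((8, 2)), None, None)
  print(handler([object()] * 4))  # ('local', (8, 2), 4)

With this shape, the local path (pmap, sharded_jit) calls local_aval_to_result_handler directly and the global output path calls global_aval_to_result_handler, so neither needs Optional parameters or a flag check.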
@@ -225,5 +225,11 @@ def _gsda_shard_arg(x, devices, indices):
     raise ValueError("Pjit's mesh and GDA's mesh should be equal. Got Pjit "
                      f"mesh: {pjit_mesh},\n GDA mesh: {x._global_mesh}")
   return [s.data for s in x.local_shards]

 pxla.shard_arg_handlers[GlobalDeviceArray] = _gsda_shard_arg
+
+
+def _gsda_array_result_handler(global_aval, out_axis_resources, global_mesh):
+  return lambda bufs: GlobalDeviceArray(global_aval.shape, global_mesh,
+                                        out_axis_resources, bufs)
+pxla.global_result_handlers[core.ShapedArray] = _gsda_array_result_handler
+pxla.global_result_handlers[core.ConcreteArray] = _gsda_array_result_handler
@@ -412,22 +412,49 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource
       tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
   )

-def aval_to_result_handler(
+def local_aval_to_result_handler(
     aval: core.AbstractValue,
-    sharding_spec: Optional[ShardingSpec] = None,
-    indices: Optional[Tuple[Index]] = None,
-    out_axis_resources: Optional[AxisResource] = None,
-    global_mesh = None,
+    sharding_spec: Optional[ShardingSpec],
+    indices: Optional[Tuple[Index]],
 ) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.

   Args:
+    aval: The local output AbstractValue.
     sharding_spec: Indicates how the output is sharded across devices, or None
       for non-array avals.
     indices: The pre-computed result of spec_to_indices, or None for non-array
       avals.
-    aval: The output AbstractValue. Can be global or local depending on whether
-      `jax_gsda_out` flag is enabled or not.

   Returns:
     A function for handling the Buffers that will eventually be produced
+    for this output. The function will return an object suitable for returning
+    to the user, e.g. a ShardedDeviceArray.
+  """
+  try:
+    return local_result_handlers[type(aval)](aval, sharding_spec, indices)
+  except KeyError as err:
+    raise TypeError(
+        "No pxla_result_handler for type: {}".format(type(aval))) from err
+
+PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
+local_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
+local_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
+def sda_array_result_handler(aval: ShapedArray, sharding_spec, indices):
+  return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
+                                                indices)
+local_result_handlers[ShapedArray] = sda_array_result_handler
+local_result_handlers[ConcreteArray] = sda_array_result_handler
+
+
+def global_aval_to_result_handler(
+    aval: core.AbstractValue,
+    out_axis_resources: Optional[AxisResource], global_mesh,
+) -> Callable[[List[xb.xla_client.Buffer]], Any]:
+  """Returns a function for handling the raw buffers of a single output aval.
+
+  Args:
+    aval: The global output AbstractValue.
+    out_axis_resources: A tuple specifying the sharding of outputs.
+      Used for creating GSDAs.
+    global_mesh: The global device mesh that generated this output. Used
@@ -439,35 +466,14 @@ def aval_to_result_handler(
     to the user, e.g. a ShardedDeviceArray.
   """
   try:
-    return pxla_result_handlers[type(aval)](sharding_spec, indices, aval,
-                                            out_axis_resources, global_mesh)
+    return global_result_handlers[type(aval)](aval, out_axis_resources,
+                                              global_mesh)
   except KeyError as err:
-    raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
-                    ) from err
+    raise TypeError(
+        "No pxla_result_handler for type: {}".format(type(aval))) from err

-PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
-pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
-pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
-
-def array_result_handler(sharding_spec, indices, aval: ShapedArray,
-                         out_axis_resources, global_mesh):
-  if config.jax_gsda_out and global_mesh is not None:
-    return gsda_array_result_handler(aval, global_mesh, out_axis_resources)
-  else:
-    return sda_array_result_handler(sharding_spec, indices, aval)
-
-pxla_result_handlers[ShapedArray] = array_result_handler
-pxla_result_handlers[ConcreteArray] = array_result_handler
-
-def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
-  return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
-                                                indices)
-
-def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
-  from jax.experimental.global_device_array import GlobalDeviceArray
-
-  return lambda bufs: GlobalDeviceArray(global_aval.shape, global_mesh,
-                                        out_axis_resources, bufs)
+global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
+global_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit

 ### lazy device-memory persistence and result handling

@@ -1356,7 +1362,7 @@ def local_avals_to_results_handler(
                    if aval is not core.abstract_unit else None
                    for aval, spec in safe_zip(unmapped_local_out_avals, local_out_specs)]  # pytype: disable=attribute-error
   handlers = [
-      aval_to_result_handler(aval, sharding_spec=spec, indices=idcs)
+      local_aval_to_result_handler(aval, spec, idcs)
       for aval, spec, idcs in safe_zip(unmapped_local_out_avals, local_out_specs, out_indices)
   ]
   return ResultsHandler(handlers, local_out_specs, out_indices, unmapped_local_out_avals)
@@ -1374,7 +1380,7 @@ def global_avals_to_results_handler(global_out_avals: Sequence[ShapedArray],
                  for aval, spec in safe_zip(global_out_avals, global_out_specs)]
   out_axis_resources = [array_mapping_to_axis_resources(o) for o in out_axes]
   handlers = [
-      aval_to_result_handler(global_aval, out_axis_resources=out_axis, global_mesh=global_mesh)
+      global_aval_to_result_handler(global_aval, out_axis, global_mesh)
      for global_aval, out_axis in safe_zip(global_out_avals, out_axis_resources)
   ]
   return ResultsHandler(handlers, global_out_specs, out_indices, global_out_avals)
@@ -65,7 +65,7 @@ def _aval_to_result_handler(npart, parts, aval):
     indices = pxla.spec_to_indices(aval.shape, spec)
   else:
     spec = indices = None
-  return pxla.aval_to_result_handler(aval, spec, indices)
+  return pxla.local_aval_to_result_handler(aval, spec, indices)


 @lu.cache
@@ -40,7 +40,7 @@ def create_global_mesh(mesh_shape, axis_names):
   return global_mesh


-class GSDATest(jtu.JaxTestCase):
+class GDATest(jtu.JaxTestCase):

   @parameterized.named_parameters(
       ("mesh_x_y", ["x", "y"],
@@ -74,7 +74,7 @@ class GSDATest(jtu.JaxTestCase):
           (8, 2),
           [0, 1, 2, 3, 4, 5, 6, 7]),
   )
-  def test_gsda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
+  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
@@ -123,7 +123,7 @@ class GSDATest(jtu.JaxTestCase):
           (4, 4, 2),
           [0, 0, 1, 1, 2, 2, 3, 3]),
   )
-  def test_gsda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
+  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids):
     global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
     global_input_shape = (8, 4, 2)
@@ -157,7 +157,7 @@ class GSDATest(jtu.JaxTestCase):
           (16,),
           [0, 1, 2, 3, 4, 5, 6, 7]),
   )
-  def test_gsda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
+  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids):
     global_mesh = create_global_mesh((8,), ('x'))
     global_input_shape = (16,)
@@ -185,7 +185,7 @@ class GSDATest(jtu.JaxTestCase):
           (4, 1),
           [0, 0, 0, 0]),
   )
-  def test_gsda_subset_devices(self, mesh_axes, expected_index,
+  def test_gda_subset_devices(self, mesh_axes, expected_index,
                                expected_shard_shape, expected_replica_ids):
     global_mesh = create_global_mesh((2, 2), ('x', 'y'))
     global_input_shape = (8, 2)
@@ -211,7 +211,7 @@ class GSDATest(jtu.JaxTestCase):
       self.assertEqual(g.replica_id, l.replica_id)
       self.assertArraysEqual(g.data, l.data)

-  def test_gsda_batched_callback(self):
+  def test_gda_batched_callback(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = [('x', 'y')]
@@ -231,7 +231,7 @@ class GSDATest(jtu.JaxTestCase):
     self.assertArraysEqual(gda.local_data(1).to_py(),
                            expected_second_shard_value)

-  def test_gsda_batched_callback_with_devices(self):
+  def test_gda_batched_callback_with_devices(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = ['x']
@@ -257,7 +257,7 @@ class GSDATest(jtu.JaxTestCase):
     self.assertArraysEqual(gda.local_data(1).to_py(),
                            expected_second_shard_value)

-  def test_gsda_str_repr(self):
+  def test_gda_str_repr(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = [('x', 'y')]
@@ -594,10 +594,10 @@ class PJitTest(jtu.BufferDonationTestCase):
                         lambda: exe(x_i32, x_i32))


-class GSDAPjitTest(jtu.JaxTestCase):
+class GDAPjitTest(jtu.JaxTestCase):

   @jtu.with_mesh([('x', 4), ('y', 2)])
-  def test_pjit_gsda_single_output(self):
+  def test_pjit_gda_single_output(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = P('x', 'y')
@@ -629,7 +629,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
       f(input_data)

   @jtu.with_mesh([('x', 4), ('y', 2)])
-  def test_pjit_gsda_multi_input_multi_output(self):
+  def test_pjit_gda_multi_input_multi_output(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     input_data = np.arange(
@@ -638,16 +638,16 @@ class GSDAPjitTest(jtu.JaxTestCase):
       return input_data[index]

     mesh_axes1 = P('x', 'y')
-    gsda1 = global_device_array.GlobalDeviceArray.from_callback(
+    gda1 = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes1, cb)
     mesh_axes2 = P('x')
-    gsda2 = global_device_array.GlobalDeviceArray.from_callback(
+    gda2 = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes2, cb)
     mesh_axes3 = P(('x', 'y'))
-    gsda3 = global_device_array.GlobalDeviceArray.from_callback(
+    gda3 = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes3, cb)
     mesh_axes4 = P(None)
-    gsda4 = global_device_array.GlobalDeviceArray.from_callback(
+    gda4 = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes4, cb)

     with jax._src.config.gsda_out(True):
@@ -658,7 +658,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
                out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
       def f(x, y, z, a):
         return x @ x.T, y, z, a
-      out1, out2, out3, out4 = f(gsda1, gsda2, gsda3, gsda4)
+      out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)

       self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
       self.assertEqual(out1.shape, (8, 8))
@@ -702,7 +702,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
         self.assertArraysEqual(s.data, input_data[s.index])

   @jtu.with_mesh([('x', 4), ('y', 2)])
-  def test_pjit_gsda_mixed_inputs(self):
+  def test_pjit_gda_mixed_inputs(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = P('x', 'y')
@@ -737,8 +737,37 @@ class GSDAPjitTest(jtu.JaxTestCase):
       for s in out2.local_shards:
         self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

+  @jtu.with_mesh([('x', 4), ('y', 2)])
+  def test_pjit_gda_non_gda_inputs(self):
+    input_shape = (8, 2)
+    input_data = np.arange(prod(input_shape)).reshape(input_shape)
+
+    with jax._src.config.gsda_out(True):
+      @partial(pjit,
+               in_axis_resources=(None, P('x', 'y')),
+               out_axis_resources=(P('x', 'y'), P(('x', 'y'))))
+      def f(x, y):
+        return x @ x.T, y @ y.T
+
+      expected_matrix_mul = input_data @ input_data.T
+      out1, out2 = f(input_data, input_data)
+
+      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
+      self.assertEqual(out1.shape, (8, 8))
+      self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
+      self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
+      for s in out1.local_shards:
+        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
+
+      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
+      self.assertEqual(out2.shape, (8, 8))
+      self.assertEqual(out2.local_shards[0].data.shape, (1, 8))
+      self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
+      for s in out2.local_shards:
+        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
+
   @jtu.with_mesh([('x', 2), ('y', 2)])
-  def test_pjit_gsda_mesh_mismatch(self):
+  def test_pjit_gda_mesh_mismatch(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = ['x', 'y']
@@ -759,7 +788,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
       f(gda_obj)

   @jtu.with_mesh([('x', 4), ('y', 2)])
-  def test_pjit_gsda_wrong_resource_for_gsda_input(self):
+  def test_pjit_gda_wrong_resource_for_gda_input(self):
     global_mesh = create_global_mesh((4, 2), ('x', 'y'))
     global_input_shape = (8, 2)
     mesh_axes = ['x']