Mirror of https://github.com/ROCm/jax.git
Output GSDAs from pjit if jax_gsda_out flag is enabled.

PiperOrigin-RevId: 409585439
parent 155475de6f
commit e94cc97d70
@@ -498,6 +498,12 @@ log_compiles = config.define_bool_state(
         'option is set, the log level is WARNING; otherwise the level is '
         'DEBUG.'))
 
+gsda_out = config.define_bool_state(
+    name='jax_gsda_out',
+    default=False,
+    help='If True, pjit will output GSDAs.')
+
+
 distributed_debug = config.define_bool_state(
     name='jax_distributed_debug',
     default=False,
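Note that `define_bool_state` also generates a context manager for the flag, which is how the tests in this commit enable it per scope. A minimal usage sketch (assuming a pjit-ready multi-device mesh is already set up; `f` and `gsda_obj` stand in for the pjit-decorated function and GSDA input used in the tests below):

    import jax

    # Enable GSDA outputs only within this scope; outside it, pjit keeps
    # returning ShardedDeviceArrays as before.
    with jax._src.config.gsda_out(True):
      out = f(gsda_obj)  # out is a GlobalShardedDeviceArray here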
@@ -412,18 +412,58 @@ def _shard_abstract_array(size, axis: int, x):
   return x.update(shape=tuple_delete(x.shape, axis))
 shard_aval_handlers[ShapedArray] = _shard_abstract_array
 
+MeshAxisName = Any
+"""
+ArrayMapping specifies how an ndarray should map to mesh axes.
+
+Note that the ordering is crucial for the cases when this mapping is non-injective
+(i.e. when multiple mesh axes map to the same positional axis). Then, the
+order of entries of the mapping determines a major-to-minor order on mesh axes,
+according to which chunks of the value along the repeated dimension will be assigned.
+
+For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
+The second dimension of the value would get chunked into 6 pieces, and assigned to the
+mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
+that would mean that a flat list of chunks would get assigned to a flattened list of
+mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
+mesh devices ndarray would have to be transposed before flattening and assignment.
+"""
+ArrayMapping = OrderedDictType[MeshAxisName, int]
+
+AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]
+
+def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
+  if not array_mapping:
+    return tuple()
+  max_index = array_mapping[max(array_mapping, key=array_mapping.get)]  # type: ignore
+  reverse_map = defaultdict(list)
+  for axis, index in array_mapping.items():
+    reverse_map[index].append(axis)
+  return tuple(
+      tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
+  )
+
-def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
-                           indices: Optional[Tuple[Index]],
-                           aval: core.AbstractValue) -> Callable[
-                               [List[xb.xla_client.Buffer]], Any]:
+def aval_to_result_handler(
+    sharding_spec: Optional[ShardingSpec],
+    indices: Optional[Tuple[Index]],
+    aval: core.AbstractValue,
+    global_aval: Optional[ShapedArray] = None,
+    out_axis_resources: Optional[AxisResource] = None,
+    global_mesh = None,
+) -> Callable[[List[xb.xla_client.Buffer]], Any]:
   """Returns a function for handling the raw buffers of a single output aval.
 
   Args:
-    sharding_spec: indicates how the output is sharded across devices, or None
+    sharding_spec: Indicates how the output is sharded across devices, or None
       for non-array avals.
-    indices: the pre-computed result of spec_to_indices, or None for non-array
+    indices: The pre-computed result of spec_to_indices, or None for non-array
       avals.
-    aval: the output AbstractValue.
+    aval: The output AbstractValue.
+    global_aval: Global output AbstractValue. Used for creating GSDAs.
+    out_axis_resources: A tuple specifying the sharding of outputs.
+      Used for creating GSDAs.
+    global_mesh: The global device mesh that generated this output. Used
+      for creating GSDAs.
 
   Returns:
     A function for handling the Buffers that will eventually be produced
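The major-to-minor rule described in the ArrayMapping docstring can be checked with plain numpy. This is an illustrative sketch, not code from the commit; `devices` stands in for the mesh's device ndarray:

    import numpy as np

    devices = np.arange(6).reshape(2, 3)  # mesh shape {'x': 2, 'y': 3}
    value = np.arange(12).reshape(2, 6)   # positional axis 1 gets chunked

    chunks = np.split(value, 6, axis=1)   # 6 = 2 * 3 chunks along axis 1

    # Mapping {'x': 1, 'y': 1}: 'x' is major, 'y' minor, which matches the
    # layout of `devices`, so the flat chunk list pairs with devices.ravel().
    pairs_xy = list(zip(devices.ravel(), chunks))

    # Mapping {'y': 1, 'x': 1}: 'y' becomes major, so the device ndarray must
    # be transposed before flattening, exactly as the docstring says.
    pairs_yx = list(zip(devices.T.ravel(), chunks))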
@@ -431,7 +471,8 @@ def aval_to_result_handler(
     to the user, e.g. a ShardedDeviceArray.
   """
   try:
-    return pxla_result_handlers[type(aval)](sharding_spec, indices, aval)
+    return pxla_result_handlers[type(aval)](sharding_spec, indices, aval,
+                                            global_aval, out_axis_resources, global_mesh)
   except KeyError as err:
     raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
                     ) from err
@@ -439,14 +480,26 @@ def aval_to_result_handler(
 PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
 pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
 pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
-def array_result_handler(sharding_spec, indices, aval: ShapedArray):
-  return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
-                                                indices)
+
+def array_result_handler(sharding_spec, indices, aval: ShapedArray, global_aval,
+                         out_axis_resources, global_mesh):
+  if config.jax_gsda_out:
+    return gsda_array_result_handler(global_aval, global_mesh, out_axis_resources)
+  else:
+    return sda_array_result_handler(sharding_spec, indices, aval)
+
 pxla_result_handlers[ShapedArray] = array_result_handler
 pxla_result_handlers[ConcreteArray] = array_result_handler
 
+def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
+  return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
+                                                indices)
+
+def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
+  from ..experimental.gsda import GlobalShardedDeviceArray
+
+  return lambda bufs: GlobalShardedDeviceArray(
+      global_aval.shape, global_mesh, out_axis_resources, bufs)
+
 ### lazy device-memory persistence and result handling
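Two details of this hunk are easy to miss: `array_result_handler` reads the flag once, when the handler is built, not each time buffers arrive, and `GlobalShardedDeviceArray` is imported inside `gsda_array_result_handler`, presumably to avoid a circular import between pxla and the experimental gsda module. A self-contained toy of the construction-time dispatch (none of these names are JAX's):

    FLAG = False

    def make_handler():
      # The branch is taken when the handler is built...
      if FLAG:
        return lambda bufs: ('global', bufs)
      return lambda bufs: ('sharded', bufs)

    h = make_handler()
    FLAG = True
    assert h([]) == ('sharded', [])              # ...so later flips don't affect it
    assert make_handler()([]) == ('global', [])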
@@ -1172,12 +1225,31 @@ class ResultsHandler:
   def __call__(self, out_bufs):
     return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
 
 
-def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals):
+def avals_to_results_handler(
+    nrep,
+    npart,
+    out_specs,
+    unmapped_local_out_avals,
+    global_out_avals: Optional[Sequence[ShapedArray]] = None,
+    out_axis_resources: Optional[Sequence[AxisResource]] = None,
+    global_mesh=None):
   out_indices = [spec_to_indices(aval.shape, spec)
                  if aval is not core.abstract_unit else None
                  for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)]  # pytype: disable=attribute-error
-  handlers = [aval_to_result_handler(spec, idcs, aval)
-              for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)]
+  if global_out_avals and out_axis_resources and global_mesh:
+    handlers = [
+        aval_to_result_handler(spec, idcs, aval, global_aval, out_axis, global_mesh)
+        for spec, idcs, aval, global_aval, out_axis in safe_zip(
+            out_specs, out_indices, unmapped_local_out_avals,
+            global_out_avals, out_axis_resources)
+    ]
+  else:
+    handlers = [
+        aval_to_result_handler(spec, idcs, aval)
+        for spec, idcs, aval, in safe_zip(out_specs, out_indices,
+                                          unmapped_local_out_avals)
+    ]
 
   return ResultsHandler(handlers, out_specs, out_indices, unmapped_local_out_avals)
@@ -1397,24 +1469,6 @@ def _unravel_index(c, axis_env):
 
 # ------------------- xmap -------------------
 
-MeshAxisName = Any
-"""
-ArrayMapping specifies how an ndarray should map to mesh axes.
-
-Note that the ordering is crucial for the cases when this mapping is non-injective
-(i.e. when multiple mesh axes map to the same positional axis). Then, the
-order of entries of the mapping determines a major-to-minor order on mesh axes,
-according to which chunks of the value along the repeated dimension will be assigned.
-
-For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
-The second dimension of the value would get chunked into 6 pieces, and assigned to the
-mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
-that would mean that a flat list of chunks would get assigned to a flattened list of
-mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
-mesh devices ndarray would have to be transposed before flattening and assignment.
-"""
-ArrayMapping = OrderedDictType[MeshAxisName, int]
-
 class Mesh:
 
   def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
@@ -1670,8 +1724,8 @@ def lower_mesh_computation(
   built = c.Build(out_tuple)
   return MeshComputation(
       built, mesh, local_in_untiled_avals,
-      local_out_untiled_avals, in_axes, out_axes,
-      spmd_lowering, tuple_args)
+      local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
+      in_axes, out_axes, spmd_lowering, tuple_args)
 
 
 class MeshComputation:
@@ -1703,6 +1757,7 @@ class MeshExecutable:
                mesh: Mesh,
                local_in_untiled_avals: Sequence[ShapedArray],
                local_out_untiled_avals: Sequence[ShapedArray],
+               global_out_avals: Optional[Sequence[ShapedArray]],
                in_axes: Sequence[ArrayMapping],
                out_axes: Sequence[ArrayMapping],
                spmd_lowering: bool, tuple_args: bool,
@@ -1744,8 +1799,10 @@ class MeshExecutable:
 
     local_output_specs = [local_sharding_spec(aval, aval_out_axes)
                           for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
+    out_axis_resources = [array_mapping_to_axis_resources(o) for o in out_axes]
     handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
-                                           local_output_specs, local_out_untiled_avals)
+                                           local_output_specs, local_out_untiled_avals,
+                                           global_out_avals, out_axis_resources, mesh)
 
     if _allow_compile_replicated and hasattr(backend, "compile_replicated"):
       self.unsafe_call = backend.compile_replicated(
@@ -597,16 +597,99 @@ class GSDAPjitTest(jtu.JaxTestCase):
     gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes, cb)
 
-    @partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
-    def f(x):
-      return x @ x.T
-
-    out = f(gsda_obj)
-    # TODO(yashkatariya): Enable the gsda_out flag and check for GSDA as the
-    # output.
-    self.assertIsInstance(out, pxla.ShardedDeviceArray)
-    self.assertLen(out.device_buffers, 8)
-    self.assertEqual(out.device_buffers[0].shape, (2, 4))
+    with jax._src.config.gsda_out(True):
+      @partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
+      def f(x):
+        return x @ x.T
+      expected_matrix_mul = input_data @ input_data.T
+
+      out = f(gsda_obj)
+      self.assertIsInstance(out, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out.shape, (8, 8))
+      self.assertEqual(out.local_shards[0].data.shape, (2, 4))
+      self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
+      for s in out.local_shards:
+        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
+
+      out1 = f(input_data)
+      self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out1.shape, (8, 8))
+      self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
+      for s in out.local_shards:
+        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
+
+  @jtu.with_mesh([('x', 4), ('y', 2)])
+  def test_pjit_gsda_multi_input_multi_output(self):
+    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    input_data = np.arange(
+        prod(global_input_shape)).reshape(global_input_shape)
+    def cb(index):
+      return input_data[index]
+
+    mesh_axes1 = P('x', 'y')
+    gsda1 = gsda.GlobalShardedDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes1, cb)
+    mesh_axes2 = P('x')
+    gsda2 = gsda.GlobalShardedDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes2, cb)
+    mesh_axes3 = P(('x', 'y'))
+    gsda3 = gsda.GlobalShardedDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes3, cb)
+    mesh_axes4 = P(None)
+    gsda4 = gsda.GlobalShardedDeviceArray.from_callback(
+        global_input_shape, global_mesh, mesh_axes4, cb)
+
+    with jax._src.config.gsda_out(True):
+      @partial(
+          pjit,
+          in_axis_resources=(mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4),
+          out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
+      def f(x, y, z, a):
+        return x @ x.T, y, z, a
+      out1, out2, out3, out4 = f(gsda1, gsda2, gsda3, gsda4)
+
+      self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out1.shape, (8, 8))
+      self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
+      self.assertEqual(out1.local_shards[0].index, (slice(0, 2), slice(0, 4)))
+      self.assertEqual(out1.local_shards[1].index, (slice(0, 2), slice(4, 8)))
+      self.assertListEqual([s.replica_id for s in out1.local_shards],
+                           [0, 0, 0, 0, 0, 0, 0, 0])
+      expected_matrix_mul = input_data @ input_data.T
+      for s in out1.local_shards:
+        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
+
+      self.assertIsInstance(out2, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out2.shape, (8, 2))
+      self.assertEqual(out2.local_shards[0].data.shape, (8, 2))
+      self.assertEqual(out2.local_shards[0].index, (slice(None), slice(None)))
+      self.assertEqual(out2.local_shards[1].index, (slice(None), slice(None)))
+      self.assertListEqual([s.replica_id for s in out2.local_shards],
+                           [0, 1, 2, 3, 4, 5, 6, 7])
+      for s in out2.local_shards:
+        self.assertArraysEqual(s.data, input_data)
+
+      self.assertIsInstance(out3, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out3.shape, (8, 2))
+      self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
+      self.assertEqual(out3.local_shards[0].index, (slice(0, 2), slice(None)))
+      self.assertEqual(out3.local_shards[1].index, (slice(0, 2), slice(None)))
+      self.assertListEqual([s.replica_id for s in out3.local_shards],
+                           [0, 1, 0, 1, 0, 1, 0, 1])
+      for s in out3.local_shards:
+        self.assertArraysEqual(s.data, input_data[s.index])
+
+      self.assertIsInstance(out4, gsda.GlobalShardedDeviceArray)
+      self.assertEqual(out4.shape, (8, 2))
+      self.assertEqual(out4.local_shards[0].data.shape, (1, 2))
+      self.assertEqual(out4.local_shards[0].index, (slice(0, 1), slice(None)))
+      self.assertEqual(out4.local_shards[1].index, (slice(1, 2), slice(None)))
+      self.assertListEqual([s.replica_id for s in out4.local_shards],
+                           [0, 0, 0, 0, 0, 0, 0, 0])
+      for s in out4.local_shards:
+        self.assertArraysEqual(s.data, input_data[s.index])
 
   @jtu.with_mesh([('x', 2), ('y', 2)])
   def test_pjit_gsda_mesh_mismatch(self):
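The shard shapes asserted above follow from dividing the global shape by the mesh extent along each partitioned dimension. For the first test (an arithmetic sketch, not test code):

    global_shape = (8, 8)   # shape of out = x @ x.T
    mesh_shape = (4, 2)     # mesh axes ('x', 'y'), output sharded with P('x', 'y')
    local_shard_shape = tuple(g // m for g, m in zip(global_shape, mesh_shape))
    assert local_shard_shape == (2, 4)  # matches out.local_shards[0].data.shape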
@@ -820,15 +903,15 @@ class PJitErrorTest(jtu.JaxTestCase):
       pjit(lambda x, y: x, p, p)(x, x)  # Error, but make sure we hint at tupling
     # TODO(apaszke): Disable implicit list casts and enable this
     # error = re.escape(
-    # r"pjit in_axis_resources specification must be a tree prefix of the "
-    # r"corresponding value, got specification (None, None, None) for value "
-    # r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
-    # r"are non-trivial pytrees should always be wrapped in a tuple representing "
-    # r"the argument list. In particular, you're passing in a single argument "
-    # r"which means that pjit in_axis_resources might need to be wrapped in a "
-    # r"singleton tuple.")
+    #   r"pjit in_axis_resources specification must be a tree prefix of the "
+    #   r"corresponding value, got specification (None, None, None) for value "
+    #   r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
+    #   r"are non-trivial pytrees should always be wrapped in a tuple representing "
+    #   r"the argument list. In particular, you're passing in a single argument "
+    #   r"which means that pjit in_axis_resources might need to be wrapped in a "
+    #   r"singleton tuple.")
     # with self.assertRaisesRegex(ValueError, error):
-    # pjit(lambda x: x, p, p)([x, x, x])  # Error, but make sure we hint at singleton tuple
+    #   pjit(lambda x: x, p, p)([x, x, x])  # Error, but make sure we hint at singleton tuple
     error = re.escape(
         r"pjit out_axis_resources specification must be a tree prefix of the "
         r"corresponding value, got specification [[None, None, None], None] for "
@@ -853,6 +936,7 @@ class PJitErrorTest(jtu.JaxTestCase):
 
 
 class UtilTest(jtu.JaxTestCase):
 
   def testOpShardingRoundTrip(self):
     FakeDevice = namedtuple('FakeDevice', ['id'])
     mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
@@ -879,6 +963,14 @@ class UtilTest(jtu.JaxTestCase):
         spec[rng.choice(dims)] += (axis,)
       roundtrip(P(*spec))
 
+  @parameterized.named_parameters(
+      ("linear", {'x': 0, 'y': 1, 'z': 2}, (('x',), ('y',), ('z',))),
+      ("combine", {'x': 0, 'y': 0, 'z': 1}, (('x', 'y'), ('z',))),
+      ("skip", {'x': 0, 'y': 0, 'z': 2}, (('x', 'y'), None, ('z',))),
+      ("multi_skip", {'x': 0, 'y': 1, 'z': 3}, (('x',), ('y',), None, ('z',))),
+  )
+  def test_array_mapping_to_axis_resources(self, inp, expected_out):
+    self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
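These parameterized cases double as documentation for the new helper. The "skip" case, spelled out as a runnable sketch (assuming jax is importable; the helper lives in jax.interpreters.pxla, as the test's `pxla.` prefix indicates):

    from collections import OrderedDict
    from jax.interpreters import pxla

    # 'x' and 'y' both shard positional axis 0 (in that major-to-minor order),
    # axis 1 is unsharded (the None slot), and 'z' shards axis 2.
    out = pxla.array_mapping_to_axis_resources(
        OrderedDict([('x', 0), ('y', 0), ('z', 2)]))
    assert out == (('x', 'y'), None, ('z',))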