Output GSDAs from pjit if jax_gsda_out flag is enabled.

PiperOrigin-RevId: 409585439
This commit is contained in:
jax authors 2021-11-12 21:46:37 -08:00
parent 155475de6f
commit e94cc97d70
3 changed files with 52 additions and 207 deletions

View File

@ -498,12 +498,6 @@ log_compiles = config.define_bool_state(
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))
gsda_out = config.define_bool_state(
name='jax_gsda_out',
default=False,
help='If True, pjit will output GSDAs.')
distributed_debug = config.define_bool_state(
name='jax_distributed_debug',
default=False,

View File

@ -412,58 +412,18 @@ def _shard_abstract_array(size, axis: int, x):
return x.update(shape=tuple_delete(x.shape, axis))
shard_aval_handlers[ShapedArray] = _shard_abstract_array
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
AxisResource = Tuple[Optional[Tuple[Any, ...]], ...]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping) -> AxisResource:
if not array_mapping:
return tuple()
max_index = array_mapping[max(array_mapping, key=array_mapping.get)] # type: ignore
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
return tuple(
tuple(reverse_map[i]) if reverse_map[i] else None for i in range(max_index + 1)
)
def aval_to_result_handler(
sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index]],
aval: core.AbstractValue,
global_aval: Optional[ShapedArray] = None,
out_axis_resources: Optional[AxisResource] = None,
global_mesh = None,
) -> Callable[[List[xb.xla_client.Buffer]], Any]:
def aval_to_result_handler(sharding_spec: Optional[ShardingSpec],
indices: Optional[Tuple[Index]],
aval: core.AbstractValue) -> Callable[
[List[xb.xla_client.Buffer]], Any]:
"""Returns a function for handling the raw buffers of a single output aval.
Args:
sharding_spec: Indicates how the output is sharded across devices, or None
sharding_spec: indicates how the output is sharded across devices, or None
for non-array avals.
indices: The pre-computed result of spec_to_indices, or None for non-array
indices: the pre-computed result of spec_to_indices, or None for non-array
avals.
aval: The output AbstractValue.
global_aval: Global output AbstractValue. Used for creating GSDAs.
out_axis_resources: A tuple specifying the sharding of outputs.
Used for creating GSDAs.
global_mesh: The global device mesh that generated this output. Used
for creating GSDAs.
aval: the output AbstractValue.
Returns:
A function for handling the Buffers that will eventually be produced
@ -471,8 +431,7 @@ def aval_to_result_handler(
to the user, e.g. a ShardedDeviceArray.
"""
try:
return pxla_result_handlers[type(aval)](sharding_spec, indices, aval,
global_aval, out_axis_resources, global_mesh)
return pxla_result_handlers[type(aval)](sharding_spec, indices, aval)
except KeyError as err:
raise TypeError("No pxla_result_handler for type: {}".format(type(aval))
) from err
@ -480,26 +439,14 @@ def aval_to_result_handler(
PxlaResultHandler = Callable[..., Callable[[List[xb.xla_client.Buffer]], Any]]
pxla_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {}
pxla_result_handlers[core.AbstractUnit] = lambda *_: lambda _: core.unit
def array_result_handler(sharding_spec, indices, aval: ShapedArray):
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
def array_result_handler(sharding_spec, indices, aval: ShapedArray, global_aval,
out_axis_resources, global_mesh):
if config.jax_gsda_out:
return gsda_array_result_handler(global_aval, global_mesh, out_axis_resources)
else:
return sda_array_result_handler(sharding_spec, indices, aval)
pxla_result_handlers[ShapedArray] = array_result_handler
pxla_result_handlers[ConcreteArray] = array_result_handler
def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
return lambda bufs: make_sharded_device_array(aval, sharding_spec, bufs,
indices)
def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
from ..experimental.gsda import GlobalShardedDeviceArray
return lambda bufs: GlobalShardedDeviceArray(
global_aval.shape, global_mesh, out_axis_resources, bufs)
### lazy device-memory persistence and result handling
@ -1225,31 +1172,12 @@ class ResultsHandler:
def __call__(self, out_bufs):
return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
def avals_to_results_handler(
nrep,
npart,
out_specs,
unmapped_local_out_avals,
global_out_avals: Optional[Sequence[ShapedArray]] = None,
out_axis_resources: Optional[Sequence[AxisResource]] = None,
global_mesh=None):
def avals_to_results_handler(nrep, npart, out_specs, unmapped_local_out_avals):
out_indices = [spec_to_indices(aval.shape, spec)
if aval is not core.abstract_unit else None
for aval, spec in safe_zip(unmapped_local_out_avals, out_specs)] # pytype: disable=attribute-error
if global_out_avals and out_axis_resources and global_mesh:
handlers = [
aval_to_result_handler(spec, idcs, aval, global_aval, out_axis, global_mesh)
for spec, idcs, aval, global_aval, out_axis in safe_zip(
out_specs, out_indices, unmapped_local_out_avals,
global_out_avals, out_axis_resources)
]
else:
handlers = [
aval_to_result_handler(spec, idcs, aval)
for spec, idcs, aval, in safe_zip(out_specs, out_indices,
unmapped_local_out_avals)
]
handlers = [aval_to_result_handler(spec, idcs, aval)
for spec, idcs, aval in safe_zip(out_specs, out_indices, unmapped_local_out_avals)]
return ResultsHandler(handlers, out_specs, out_indices, unmapped_local_out_avals)
@ -1469,6 +1397,24 @@ def _unravel_index(c, axis_env):
# ------------------- xmap -------------------
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
class Mesh:
def __init__(self, devices: np.ndarray, axis_names: Sequence[MeshAxisName]):
@ -1724,8 +1670,8 @@ def lower_mesh_computation(
built = c.Build(out_tuple)
return MeshComputation(
built, mesh, local_in_untiled_avals,
local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
in_axes, out_axes, spmd_lowering, tuple_args)
local_out_untiled_avals, in_axes, out_axes,
spmd_lowering, tuple_args)
class MeshComputation:
@ -1757,7 +1703,6 @@ class MeshExecutable:
mesh: Mesh,
local_in_untiled_avals: Sequence[ShapedArray],
local_out_untiled_avals: Sequence[ShapedArray],
global_out_avals: Optional[Sequence[ShapedArray]],
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool,
@ -1799,10 +1744,8 @@ class MeshExecutable:
local_output_specs = [local_sharding_spec(aval, aval_out_axes)
for aval, aval_out_axes in safe_zip(local_out_untiled_avals, out_axes)]
out_axis_resources = [array_mapping_to_axis_resources(o) for o in out_axes]
handle_outs = avals_to_results_handler(num_local_replicas, num_local_partitions,
local_output_specs, local_out_untiled_avals,
global_out_avals, out_axis_resources, mesh)
local_output_specs, local_out_untiled_avals)
if _allow_compile_replicated and hasattr(backend, "compile_replicated"):
self.unsafe_call = backend.compile_replicated(

View File

@ -597,99 +597,16 @@ class GSDAPjitTest(jtu.JaxTestCase):
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, cb)
with jax._src.config.gsda_out(True):
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
expected_matrix_mul = input_data @ input_data.T
out = f(gsda_obj)
self.assertIsInstance(out, gsda.GlobalShardedDeviceArray)
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.local_shards[0].data.shape, (2, 4))
self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
out1 = f(input_data)
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_multi_input_multi_output(self):
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
mesh_axes1 = P('x', 'y')
gsda1 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes1, cb)
mesh_axes2 = P('x')
gsda2 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes2, cb)
mesh_axes3 = P(('x', 'y'))
gsda3 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
mesh_axes4 = P(None)
gsda4 = gsda.GlobalShardedDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes4, cb)
with jax._src.config.gsda_out(True):
@partial(
pjit,
in_axis_resources=(mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4),
out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
def f(x, y, z, a):
return x @ x.T, y, z, a
out1, out2, out3, out4 = f(gsda1, gsda2, gsda3, gsda4)
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
self.assertEqual(out1.local_shards[0].index, (slice(0, 2), slice(0, 4)))
self.assertEqual(out1.local_shards[1].index, (slice(0, 2), slice(4, 8)))
self.assertListEqual([s.replica_id for s in out1.local_shards],
[0, 0, 0, 0, 0, 0, 0, 0])
expected_matrix_mul = input_data @ input_data.T
for s in out1.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, gsda.GlobalShardedDeviceArray)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.local_shards[0].data.shape, (8, 2))
self.assertEqual(out2.local_shards[0].index, (slice(None), slice(None)))
self.assertEqual(out2.local_shards[1].index, (slice(None), slice(None)))
self.assertListEqual([s.replica_id for s in out2.local_shards],
[0, 1, 2, 3, 4, 5, 6, 7])
for s in out2.local_shards:
self.assertArraysEqual(s.data, input_data)
self.assertIsInstance(out3, gsda.GlobalShardedDeviceArray)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
self.assertEqual(out3.local_shards[0].index, (slice(0, 2), slice(None)))
self.assertEqual(out3.local_shards[1].index, (slice(0, 2), slice(None)))
self.assertListEqual([s.replica_id for s in out3.local_shards],
[0, 1, 0, 1, 0, 1, 0, 1])
for s in out3.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])
self.assertIsInstance(out4, gsda.GlobalShardedDeviceArray)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.local_shards[0].data.shape, (1, 2))
self.assertEqual(out4.local_shards[0].index, (slice(0, 1), slice(None)))
self.assertEqual(out4.local_shards[1].index, (slice(1, 2), slice(None)))
self.assertListEqual([s.replica_id for s in out4.local_shards],
[0, 0, 0, 0, 0, 0, 0, 0])
for s in out4.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
out = f(gsda_obj)
# TODO(yashkatariya): Enable the gsda_out flag and check for GSDA as the
# output.
self.assertIsInstance(out, pxla.ShardedDeviceArray)
self.assertLen(out.device_buffers, 8)
self.assertEqual(out.device_buffers[0].shape, (2, 4))
@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gsda_mesh_mismatch(self):
@ -903,15 +820,15 @@ class PJitErrorTest(jtu.JaxTestCase):
pjit(lambda x, y: x, p, p)(x, x) # Error, but make sure we hint at tupling
# TODO(apaszke): Disable implicit list casts and enable this
# error = re.escape(
# r"pjit in_axis_resources specification must be a tree prefix of the "
# r"corresponding value, got specification (None, None, None) for value "
# r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
# r"are non-trivial pytrees should always be wrapped in a tuple representing "
# r"the argument list. In particular, you're passing in a single argument "
# r"which means that pjit in_axis_resources might need to be wrapped in a "
# r"singleton tuple.")
# r"pjit in_axis_resources specification must be a tree prefix of the "
# r"corresponding value, got specification (None, None, None) for value "
# r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
# r"are non-trivial pytrees should always be wrapped in a tuple representing "
# r"the argument list. In particular, you're passing in a single argument "
# r"which means that pjit in_axis_resources might need to be wrapped in a "
# r"singleton tuple.")
# with self.assertRaisesRegex(ValueError, error):
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
# pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple
error = re.escape(
r"pjit out_axis_resources specification must be a tree prefix of the "
r"corresponding value, got specification [[None, None, None], None] for "
@ -936,7 +853,6 @@ class PJitErrorTest(jtu.JaxTestCase):
class UtilTest(jtu.JaxTestCase):
def testOpShardingRoundTrip(self):
FakeDevice = namedtuple('FakeDevice', ['id'])
mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
@ -963,14 +879,6 @@ class UtilTest(jtu.JaxTestCase):
spec[rng.choice(dims)] += (axis,)
roundtrip(P(*spec))
@parameterized.named_parameters(
("linear", {'x': 0, 'y': 1, 'z': 2}, (('x',), ('y',), ('z',))),
("combine", {'x': 0, 'y': 0, 'z': 1}, (('x', 'y'), ('z',))),
("skip", {'x': 0, 'y': 0, 'z': 2}, (('x', 'y'), None, ('z',))),
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, (('x',), ('y',), None, ('z',))),
)
def test_array_mapping_to_axis_resources(self, inp, expected_out):
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())