Lift non-contiguous mesh restriction for fully replicated values

PiperOrigin-RevId: 420423427
This commit is contained in:
Yash Katariya 2022-01-07 20:50:05 -08:00 committed by jax authors
parent 285c20388b
commit 0be3cfe00b
3 changed files with 21 additions and 13 deletions

View File

@ -46,9 +46,9 @@ from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
split_list, cache, tuple_insert)
xops = xc._xla.ops
class _FromGsdaSingleton:
class _FromGdaSingleton:
pass
FROM_GDA = _FromGsdaSingleton()
FROM_GDA = _FromGdaSingleton()
def _is_from_gda(x):
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
@ -309,10 +309,12 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
# will be raised because get_array_mapping (in local_to_global) of a
# FROM_GDA cannot happen.
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda)
# If all inputs are GDAs, then the avals are global and the mesh should also
# be global. This split is because non-contiguous mesh can only be used if all
# inputs are GDAs.
if all(is_gda):
# If all inputs are either GDAs or fully replicated, then the avals are
# global and the mesh should also be global. This split is because
# non-contiguous mesh can only be used if all inputs are either GDAs or fully
# replicated.
if all(((not _is_from_gda(p) and p.partitions == ()) or ig)
for p, ig in safe_zip(in_axis_resources_flat, is_gda)):
_check_shapes_against_resources(
"pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals,
in_axis_resources_flat)
@ -355,8 +357,8 @@ class ParsedPartitionSpec:
__slots__ = ('partitions', 'unsafe_user_spec', 'sync')
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
self.partitions = tuple(partitions)
self.unsafe_user_spec = user_spec
self.partitions = tuple(partitions)
self.sync = sync
@property
@ -917,14 +919,14 @@ def global_to_local(positional_semantics, mesh, avals, axes):
if isinstance(positional_semantics, maps._PositionalSemantics):
positional_semantics = [positional_semantics] * len(axes)
return [
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.global_to_local(
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.global_to_local(
get_array_mapping(aval_axes), aval)
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]
def local_to_global(positional_semantics, mesh, avals, axes):
return [
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.local_to_global(
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.local_to_global(
get_array_mapping(aval_axes), aval)
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]

View File

@ -316,7 +316,7 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
A list of length matching args, containing lists of per-device buffers
for each argument.
"""
return [_shard_arg(arg, devices, indices[a]) for a, arg in enumerate(args)]
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
@ -2051,10 +2051,11 @@ class MeshComputation:
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
input_specs, input_indices, input_avals = [], [], []
num_local_devices = len(global_mesh.local_devices)
for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
# TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
# as GDA doesn't need it.
if is_gda:
if is_gda or not axis:
aval = gaval
mesh = global_mesh
else:
@ -2063,7 +2064,12 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
if aval is not core.abstract_unit else None)
index = spec_to_indices(aval.shape, spec) if spec is not None else None
# We special case this logic to support fully replicated non-GDA values
# with non-contiguous submeshes
if not axis:
index = tuple(() for _ in range(num_local_devices))
else:
index = spec_to_indices(aval.shape, spec) if spec is not None else None
input_specs.append(spec)
input_indices.append(index)
input_avals.append(aval)

View File

@ -886,7 +886,7 @@ class GDAPjitTest(jtu.JaxTestCase):
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
# pickling in_axis_resources and sending to other processes). Make sure this
# this doesn't cause an error to avoid user confusion.
from_gda_dup = pjit_lib._FromGsdaSingleton()
from_gda_dup = pjit_lib._FromGdaSingleton()
with mesh(global_mesh.devices, global_mesh.axis_names):
pjit(lambda x: x, in_axis_resources=from_gda_dup, out_axis_resources=None)(
input_gda)