mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Lift non-contiguous mesh restriction for fully replicated values
PiperOrigin-RevId: 420423427
This commit is contained in:
parent
285c20388b
commit
0be3cfe00b
@ -46,9 +46,9 @@ from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
split_list, cache, tuple_insert)
|
||||
xops = xc._xla.ops
|
||||
|
||||
class _FromGsdaSingleton:
|
||||
class _FromGdaSingleton:
|
||||
pass
|
||||
FROM_GDA = _FromGsdaSingleton()
|
||||
FROM_GDA = _FromGdaSingleton()
|
||||
|
||||
def _is_from_gda(x):
|
||||
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
|
||||
@ -309,10 +309,12 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
# will be raised because get_array_mapping (in local_to_global) of a
|
||||
# FROM_GDA cannot happen.
|
||||
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gda)
|
||||
# If all inputs are GDAs, then the avals are global and the mesh should also
|
||||
# be global. This split is because non-contiguous mesh can only be used if all
|
||||
# inputs are GDAs.
|
||||
if all(is_gda):
|
||||
# If all inputs are either GDAs or fully replicated, then the avals are
|
||||
# global and the mesh should also be global. This split is because
|
||||
# non-contiguous mesh can only be used if all inputs are either GDAs or fully
|
||||
# replicated.
|
||||
if all(((not _is_from_gda(p) and p.partitions == ()) or ig)
|
||||
for p, ig in safe_zip(in_axis_resources_flat, is_gda)):
|
||||
_check_shapes_against_resources(
|
||||
"pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals,
|
||||
in_axis_resources_flat)
|
||||
@ -355,8 +357,8 @@ class ParsedPartitionSpec:
|
||||
__slots__ = ('partitions', 'unsafe_user_spec', 'sync')
|
||||
|
||||
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
|
||||
self.partitions = tuple(partitions)
|
||||
self.unsafe_user_spec = user_spec
|
||||
self.partitions = tuple(partitions)
|
||||
self.sync = sync
|
||||
|
||||
@property
|
||||
@ -917,14 +919,14 @@ def global_to_local(positional_semantics, mesh, avals, axes):
|
||||
if isinstance(positional_semantics, maps._PositionalSemantics):
|
||||
positional_semantics = [positional_semantics] * len(axes)
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.global_to_local(
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.global_to_local(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
def local_to_global(positional_semantics, mesh, avals, axes):
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.local_to_global(
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL or not aval_axes else mesh.local_to_global(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
@ -316,7 +316,7 @@ def shard_args(devices: Sequence[xb.xla_client.Device],
|
||||
A list of length matching args, containing lists of per-device buffers
|
||||
for each argument.
|
||||
"""
|
||||
return [_shard_arg(arg, devices, indices[a]) for a, arg in enumerate(args)]
|
||||
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
|
||||
|
||||
|
||||
shard_arg_handlers: Dict[Any, Callable[[Any, Any, Any], Sequence[Any]]] = {}
|
||||
@ -2051,10 +2051,11 @@ class MeshComputation:
|
||||
|
||||
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
|
||||
input_specs, input_indices, input_avals = [], [], []
|
||||
num_local_devices = len(global_mesh.local_devices)
|
||||
for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
|
||||
# TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
|
||||
# as GDA doesn't need it.
|
||||
if is_gda:
|
||||
if is_gda or not axis:
|
||||
aval = gaval
|
||||
mesh = global_mesh
|
||||
else:
|
||||
@ -2063,7 +2064,12 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
|
||||
|
||||
spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
|
||||
if aval is not core.abstract_unit else None)
|
||||
index = spec_to_indices(aval.shape, spec) if spec is not None else None
|
||||
# We special-case this logic to support fully replicated non-GDA values
|
||||
# with non-contiguous submeshes
|
||||
if not axis:
|
||||
index = tuple(() for _ in range(num_local_devices))
|
||||
else:
|
||||
index = spec_to_indices(aval.shape, spec) if spec is not None else None
|
||||
input_specs.append(spec)
|
||||
input_indices.append(index)
|
||||
input_avals.append(aval)
|
||||
|
@ -886,7 +886,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
# It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
|
||||
# pickling in_axis_resources and sending to other processes). Make sure this
|
||||
# doesn't cause an error to avoid user confusion.
|
||||
from_gda_dup = pjit_lib._FromGsdaSingleton()
|
||||
from_gda_dup = pjit_lib._FromGdaSingleton()
|
||||
with mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
pjit(lambda x: x, in_axis_resources=from_gda_dup, out_axis_resources=None)(
|
||||
input_gda)
|
||||
|
Loading…
x
Reference in New Issue
Block a user