Let XLA choose in_shardings for inputs whose sharding is unspecified.

This is a strict improvement over the current state, where JAX always chooses a replicated sharding for such inputs.
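
For illustration, here is a minimal sketch of the case this affects (not part of the change; it assumes at least two devices): the jitted function below receives no in_shardings for `x`, but the sharding constraint in its body gives XLA something to propagate back to the parameter, so the compiled executable may now choose that sharding for the input instead of the old replicated default.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: at least two devices are available.
mesh = Mesh(np.array(jax.devices()[:2]), ('x',))

@jax.jit
def f(x):
  # `x` carries no explicit sharding; the constraint below is what XLA can
  # propagate back to the un-annotated input parameter.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x')))

out = f(np.arange(8, dtype=np.float32))
print(out.sharding)  # follows the constraint; what the compiler picked for the
                     # input can be inspected as in the later sketch below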

PiperOrigin-RevId: 610771289
Yash Katariya 2024-02-27 09:06:21 -08:00 committed by jax authors
parent 98e4b9e7c8
commit c42a035e93
5 changed files with 132 additions and 49 deletions


@@ -766,7 +766,7 @@ def make_array_from_single_device_arrays(
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()
When using multiple processes, a common data pipeling is to have data parallelism across devices,
When using multiple processes, a common data pipeline is to have data parallelism across devices,
with each device receiving at least one example. In this case, the following recipe will use
`make_array_from_single_device_arrays` to create a global jax.Array.


@@ -2006,8 +2006,15 @@ def lower_sharding_computation(
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
gs = GSPMDSharding.get_replicated(device_assignment)
if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
# TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
# output sharding of XLA is partially constrained on the trailing dimensions.
in_shardings = tuple(
gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
else i for i, a in safe_zip(in_shardings, global_in_avals))
da_object = _create_da_object(tuple(device_assignment))
@@ -2318,7 +2325,7 @@ if xla_extension_version < 229:
return input_indices
def get_gspmd_shardings_from_executable(
def get_out_shardings_from_executable(
xla_executable,
device_assignment: Sequence[xc.Device],
num_out_avals: int,
@@ -2374,6 +2381,32 @@ def get_gspmd_shardings_from_executable(
for os, mk in safe_zip(out_op_shardings, omk)]
def _get_in_shardings_from_xla(
xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
num_ordered_effects: int
) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
"""Returns input shardings from XLA."""
from jax._src import pjit
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`.
if len(device_assignment) == 1:
return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
if not in_op_shardings:
return None
if num_ordered_effects > 0:
in_op_shardings = in_op_shardings[num_ordered_effects:]
assert len(in_op_shardings) == num_in_avals, (
len(in_op_shardings), num_in_avals)
return [sharding_impls.GSPMDSharding(device_assignment, os)
for os in in_op_shardings]
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
@@ -2526,8 +2559,8 @@ def get_logical_mesh_ids(mesh_shape):
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering,
_allow_propagation_to_outputs, host_callbacks, backend,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values):
# TODO(phawkins): One would normally just write:
@@ -2580,7 +2613,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
get_logical_mesh_ids(list(mesh.shape.values()))
.reshape(-1))
compile_options.parameter_is_tupled_arguments = tuple_args
opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
if xla_extension_version >= 240:
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
if hasattr(backend, "compile_replicated"):
return None, compile_options
@@ -2593,22 +2628,59 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
return xla_executable, compile_options
def _get_shardings_from_executable(
def _maybe_get_and_check_in_shardings(
xla_executable, in_shardings, device_assignment,
global_in_avals, num_ordered_effects):
"""Returns in_shardings extracted from XLA or checks and returns original
shardings.
If in_shardings exist on `jit` or on `jax.Array`, then this function will
check that sharding against what XLA returns as in_shardings. If they don't
match, an error is raised.
If in_sharding is unspecified, then the sharding returned by XLA is returned.
"""
in_shardings_xla = _get_in_shardings_from_xla( # type: ignore
xla_executable, device_assignment, len(global_in_avals),
num_ordered_effects) # type: ignore
if in_shardings_xla is None:
return in_shardings
new_in_shardings = []
for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
global_in_avals):
if is_unspecified(orig):
new_in_shardings.append(xla_s)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # type: ignore
# MANUAL HloSharding comes from other partitioning frameworks.
if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
not xla_hlo_s.is_manual() and
(not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
xla_s.memory_kind != orig.memory_kind)): # type: ignore
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
new_in_shardings.append(orig)
return new_in_shardings
def _get_out_shardings_from_executable(
xla_executable, out_shardings, device_assignment, global_out_avals,
num_ordered_effects, all_default_mem_kind
):
out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
out_shardings_xla = get_out_shardings_from_executable( # type: ignore
xla_executable, device_assignment, len(global_out_avals),
num_ordered_effects, all_default_mem_kind) # type: ignore
if out_shardings_xla is None:
return out_shardings, (False,) * len(global_out_avals)
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
new_out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
if is_unspecified(orig):
out_shardings.append(xla_s)
new_out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) # type: ignore
@@ -2621,9 +2693,9 @@ def _get_shardings_from_executable(
raise AssertionError(
f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
"(User sharding)")
out_shardings.append(orig)
new_out_shardings.append(orig)
are_out_shardings_from_xla.append(False)
return out_shardings, are_out_shardings_from_xla
return new_out_shardings, are_out_shardings_from_xla
def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
@@ -2722,6 +2794,8 @@ class UnloadedMeshExecutable:
else:
da = _create_da_object(tuple(device_assignment))
del device_assignment
allow_prop_to_inputs = tuple(is_unspecified(i) for i in in_shardings)
allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
mesh = None
@@ -2733,8 +2807,8 @@ class UnloadedMeshExecutable:
xla_executable, compile_options = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
tuple(host_callbacks), backend, da, pmap_nreps,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
if hasattr(backend, "compile_replicated"):
@@ -2761,9 +2835,11 @@ class UnloadedMeshExecutable:
else:
if pmap_nreps == 1:
assert mesh is None
# TODO(yashkatariya): Make da directly usable in the downstream code
# without tuple conversion.
out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable(
if xla_extension_version >= 240:
in_shardings = _maybe_get_and_check_in_shardings(
xla_executable, in_shardings, tuple(da), global_in_avals,
len(ordered_effects))
out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
xla_executable, out_shardings, tuple(da), global_out_avals,
len(ordered_effects), all_default_mem_kind)
else:
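
As a rough companion to the `_maybe_get_and_check_in_shardings` helper added above, the sketch below uses only public APIs to inspect what the compiled executable settled on for an un-annotated input; `f` and `x` reuse the assumptions from the sketch in the commit description, and the behavior depends on xla_extension_version >= 240 and a backend without `compile_replicated`. If the caller had pinned an in_sharding instead, the helper checks it against what XLA returns and raises an AssertionError on a mismatch.

import jax
import numpy as np

x = np.arange(8, dtype=np.float32)
compiled = f.lower(x).compile()  # `f` as defined in the earlier sketch

# With sharding propagation to parameters enabled, the entry for `x` need no
# longer be a replicated sharding; with a single-device assignment it is just
# a SingleDeviceSharding, matching `_get_in_shardings_from_xla` above.
print(compiled.input_shardings)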


@@ -247,14 +247,19 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
# The argument
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]",
count_in_P),
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
count_in_replicated),
# The result
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
count_out_P),
])
# TODO(b/326476605): Change the condition below if required.
if in_shardings not in [None, "missing"] and out_shardings is not None:
self.check_sharding(
jax2tf.convert(f_jax), [x],
checks=[
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
count_in_replicated),
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated",
count_out_replicated),
# No other shardings
(r"custom_call_target.*Sharding",
count_in_P + count_in_replicated + count_out_P + count_out_replicated),
])
@@ -437,10 +442,16 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
checks=[
# The input primal argument, and the output grad
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
# The primal result, and the input cotangent
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
])
# TODO(b/326476605): Change the condition below if required.
if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]:
self.check_sharding(f_grad_tf, [x, x.T],
checks=[
(r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
# The primal result, and the input cotangent
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
(r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
])
@jtu.parameterized_filterable(
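
The relaxed assertions above, and the export-test changes below, come down to which sharding annotations actually end up in the module once unspecified inputs are no longer forced to replicated. A rough way to eyeball the same thing from plain JAX, reusing the assumed `f` and `x` from the earlier sketches, is to grep the lowered StableHLO much as these tests grep their HLO:

import re

module_text = f.lower(x).as_text()
# With in_shardings left unspecified there need be no "{replicated}" annotation
# on the input parameter, which is what the loosened checks above allow for.
print(len(re.findall(r"\{replicated\}", module_text)))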


@@ -916,7 +916,6 @@ class JaxExportTest(jtu.JaxTestCase):
res_r.addressable_shards[i].data)
@jtu.parameterized_filterable(
one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
kwargs=[
dict(in_shardings=in_shardings, out_shardings=out_shardings,
with_mesh=with_mesh)
@@ -971,15 +970,17 @@ class JaxExportTest(jtu.JaxTestCase):
else:
primal_out_sharding = "{replicated}"
main = re.compile(
r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
r".*%arg1: tensor<20x10xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
# result
r".*->.*\(tensor<10x20xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
self.assertRegex(vjp_module_str, main)
# TODO(b/326476605): Change the condition below if required.
if in_shardings == "P":
main = re.compile(
r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
r".*%arg1: tensor<20x10xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
# result
r".*->.*\(tensor<10x20xf32>.*"
"mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
self.assertRegex(vjp_module_str, main)
# Custom calls for the primal input shape all match primal_in_sharding
primal_in_calls = re.findall(


@@ -559,8 +559,6 @@ class PJitTest(jtu.BufferDonationTestCase):
# Annotations from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):
@@ -1718,19 +1716,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_data = np.arange(
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
with global_mesh:
f = pjit(lambda x: x,
out_shardings=NamedSharding(
global_mesh, P('x', 'y')))
# Since no in_axis_resources is provided, pjit will assume that
# the numpy input is fully replicated over the mesh.
out = f(input_data)
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data, input_data[s.index])
self.assertArraysEqual(out._value, input_data)
math.prod(input_shape)).reshape(input_shape)
f = pjit(lambda x: x,
out_shardings=NamedSharding(global_mesh, P('x', 'y')))
out = f(input_data)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, input_data)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data, input_data[s.index])
def test_numpy_array_input(self):
input_shape = (8, 2)