Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
Let XLA choose in_shardings for inputs whose sharding is unspecified.
This is a strict improvement over the current state, where JAX always chooses replicated sharding.

PiperOrigin-RevId: 610771289
This commit is contained in:
parent 98e4b9e7c8
commit c42a035e93
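For orientation, here is a minimal sketch of the behavior this commit targets. It is not part of the diff; it only uses public JAX APIs (jax.jit, jax.lax.with_sharding_constraint, lowering/compile, Compiled.input_shardings) and assumes a multi-device backend whose device count divides the array size. With in_shardings left unspecified, the compiled executable's input shardings can now be chosen by XLA, e.g. propagated from a sharding constraint inside the function, instead of always defaulting to replicated.

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Illustrative 1-D mesh over all available devices (not from the diff).
    mesh = Mesh(np.array(jax.devices()), ('x',))

    @jax.jit  # note: no in_shardings are specified here
    def f(x):
      # The constraint gives XLA a sharding it can propagate back to the input.
      return jax.lax.with_sharding_constraint(x * 2, NamedSharding(mesh, P('x')))

    x = np.arange(8, dtype=np.float32)
    compiled = f.lower(x).compile()
    # Before this change the unspecified input was always treated as replicated;
    # now XLA may report a propagated sharding for it instead.
    print(compiled.input_shardings)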
@@ -766,7 +766,7 @@ def make_array_from_single_device_arrays(
     >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
     >>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()

-    When using multiple processes, a common data pipeling is to have data parallelism across devices,
+    When using multiple processes, a common data pipeline is to have data parallelism across devices,
     with each device receiving at least one example. In this case, the following recipe will use
     `make_array_from_single_device_arrays` to create a global jax.Array.
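For context, a compact single-process sketch of the recipe the docstring refers to; the mesh layout, dtype, and data below are illustrative and assume the number of devices evenly divides the leading dimension:

    import math
    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    global_shape = (8, 8)
    # Illustrative 2-D mesh: all devices along 'x', one along 'y'.
    mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))

    # Each process materializes only the shards for its addressable devices.
    global_data = np.arange(math.prod(global_shape), dtype=np.float32).reshape(global_shape)
    arrays = [
        jax.device_put(global_data[index], d)
        for d, index in sharding.addressable_devices_indices_map(global_shape).items()
    ]

    arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
    assert arr.shape == (8, 8)  # the global shape, regardless of jax.device_count()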
@@ -2006,8 +2006,15 @@ def lower_sharding_computation(
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings))

-  gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+  gs = GSPMDSharding.get_replicated(device_assignment)
+  if xla_extension_version < 240 or hasattr(backend, "compile_replicated"):
+    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+
+  # TODO(yashkatariya): Allow prng sharding inference by XLA. Enable this after
+  # output sharding of XLA is partially constrained on the trailing dimensions.
+  in_shardings = tuple(
+      gs if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+      else i for i, a in safe_zip(in_shardings, global_in_avals))

   da_object = _create_da_object(tuple(device_assignment))
@@ -2318,7 +2325,7 @@ if xla_extension_version < 229:
     return input_indices


-def get_gspmd_shardings_from_executable(
+def get_out_shardings_from_executable(
     xla_executable,
     device_assignment: Sequence[xc.Device],
     num_out_avals: int,
@@ -2374,6 +2381,32 @@ def get_gspmd_shardings_from_executable(
           for os, mk in safe_zip(out_op_shardings, omk)]


+def _get_in_shardings_from_xla(
+    xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
+    num_ordered_effects: int
+) -> Sequence[sharding_impls.XLACompatibleSharding] | None:
+  """Returns input shardings from XLA."""
+  from jax._src import pjit
+
+  # When the device assignment only has 1 device, SPMD partitioner will not run.
+  # Hence the op shardings will not be set on the `hlo_module`.
+  if len(device_assignment) == 1:
+    return [sharding_impls.SingleDeviceSharding(device_assignment[0])] * num_in_avals
+
+  in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable)
+  if not in_op_shardings:
+    return None
+
+  if num_ordered_effects > 0:
+    in_op_shardings = in_op_shardings[num_ordered_effects:]
+
+  assert len(in_op_shardings) == num_in_avals, (
+      len(in_op_shardings), num_in_avals)
+
+  return [sharding_impls.GSPMDSharding(device_assignment, os)
+          for os in in_op_shardings]
+
+
 # TODO(yashkatariya): Remove this function after `AUTO` can return shardings
 # without mesh.
 def _get_mesh_pspec_shardings_from_executable(
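As a rough illustration of the single-device fast path called out in the comment above (with one device the SPMD partitioner does not run, so no op shardings are attached to the HLO module), a plain single-device sharding is all there is to report; this sketch uses public APIs and is not code from the commit:

    import jax
    from jax.sharding import SingleDeviceSharding

    # With a single device there is nothing for the SPMD partitioner to do.
    dev = jax.devices()[0]
    x = jax.device_put(1.0, dev)
    assert isinstance(x.sharding, SingleDeviceSharding)
    # The helper above would analogously return
    # [SingleDeviceSharding(dev)] * num_in_avals for such a device assignment.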
@@ -2526,8 +2559,8 @@ def get_logical_mesh_ids(mesh_shape):

 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
-                        tuple_args, auto_spmd_lowering,
-                        _allow_propagation_to_outputs, host_callbacks, backend,
+                        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
+                        allow_prop_to_outputs, host_callbacks, backend,
                         da, pmap_nreps, compiler_options_keys,
                         compiler_options_values):
   # TODO(phawkins): One would normally just write:
@@ -2580,7 +2613,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
         get_logical_mesh_ids(list(mesh.shape.values()))
         .reshape(-1))
   compile_options.parameter_is_tupled_arguments = tuple_args
-  opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
+  if xla_extension_version >= 240:
+    opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
+  opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)

   if hasattr(backend, "compile_replicated"):
     return None, compile_options
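For context on the two executable build options toggled here, a hedged sketch of setting them directly on XLA compile options. The field names match the diff; allow_spmd_sharding_propagation_to_parameters only exists on sufficiently new jaxlib builds, which is what the xla_extension_version guard checks, so it is probed with hasattr below.

    from jax._src.lib import xla_client as xc

    compile_options = xc.CompileOptions()
    build_opts = compile_options.executable_build_options

    # One boolean per output: True lets XLA choose that output's sharding.
    build_opts.allow_spmd_sharding_propagation_to_output = [True]

    # One boolean per parameter: True lets XLA choose that input's sharding.
    # Guarded because older runtimes do not expose this field (cf. the version check).
    if hasattr(build_opts, "allow_spmd_sharding_propagation_to_parameters"):
      build_opts.allow_spmd_sharding_propagation_to_parameters = [True, False]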
@@ -2593,22 +2628,59 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
   return xla_executable, compile_options


-def _get_shardings_from_executable(
+def _maybe_get_and_check_in_shardings(
+    xla_executable, in_shardings, device_assignment,
+    global_in_avals, num_ordered_effects):
+  """Returns in_shardings extracted from XLA or checks and returns original
+  shardings.
+
+  If in_shardings exist on `jit` or on `jax.Array`, then this function will
+  check that sharding against what XLA returns as in_shardings. If they don't
+  match, an error is raised.
+
+  If in_sharding is unspecified, then the sharding returned by XLA is returned.
+  """
+  in_shardings_xla = _get_in_shardings_from_xla(  # type: ignore
+      xla_executable, device_assignment, len(global_in_avals),
+      num_ordered_effects)  # type: ignore
+  if in_shardings_xla is None:
+    return in_shardings
+
+  new_in_shardings = []
+  for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
+                                    global_in_avals):
+    if is_unspecified(orig):
+      new_in_shardings.append(xla_s)
+    else:
+      xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # type: ignore
+      # MANUAL HloSharding comes from other partitioning frameworks.
+      if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
+          not xla_hlo_s.is_manual() and
+          (not op_shardings.are_op_shardings_equal(xla_hlo_s, orig_hlo_s) or
+           xla_s.memory_kind != orig.memory_kind)):  # type: ignore
+        raise AssertionError(
+            f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
+            "(User sharding)")
+      new_in_shardings.append(orig)
+  return new_in_shardings
+
+
+def _get_out_shardings_from_executable(
     xla_executable, out_shardings, device_assignment, global_out_avals,
     num_ordered_effects, all_default_mem_kind
 ):
-  out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
+  out_shardings_xla = get_out_shardings_from_executable(  # type: ignore
       xla_executable, device_assignment, len(global_out_avals),
       num_ordered_effects, all_default_mem_kind)  # type: ignore
   if out_shardings_xla is None:
     return out_shardings, (False,) * len(global_out_avals)

-  orig_out_shardings = out_shardings
-  out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
-  for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
+  new_out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
+  for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
     if is_unspecified(orig):
-      out_shardings.append(xla_s)
+      new_out_shardings.append(xla_s)
       are_out_shardings_from_xla.append(True)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
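In user-facing terms, the docstring's contract roughly corresponds to the sketch below (public APIs only, not code from this commit; assumes the number of devices divides the array size): an explicitly given in_sharding is kept and must agree with what the executable reports, while an unspecified one is filled in from the executable.

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ('x',))
    s = NamedSharding(mesh, P('x'))
    x = np.zeros(8, dtype=np.float32)

    explicit = jax.jit(lambda a: a, in_shardings=s).lower(x).compile()
    unspecified = jax.jit(lambda a: a).lower(x).compile()

    print(explicit.input_shardings)     # the user-provided sharding, checked against XLA's
    print(unspecified.input_shardings)  # whatever XLA chose for the unspecified input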
@@ -2621,9 +2693,9 @@ def _get_shardings_from_executable(
         raise AssertionError(
             f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
             "(User sharding)")
-      out_shardings.append(orig)
+      new_out_shardings.append(orig)
       are_out_shardings_from_xla.append(False)
-  return out_shardings, are_out_shardings_from_xla
+  return new_out_shardings, are_out_shardings_from_xla


 def finalize_out_shardings(out_shardings, are_out_shardings_from_xla,
@@ -2722,6 +2794,8 @@ class UnloadedMeshExecutable:
     else:
       da = _create_da_object(tuple(device_assignment))
     del device_assignment
+
+    allow_prop_to_inputs = tuple(is_unspecified(i) for i in in_shardings)
     allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)

     mesh = None
@@ -2733,8 +2807,8 @@ class UnloadedMeshExecutable:

     xla_executable, compile_options = _cached_compilation(
         hlo, name, mesh, spmd_lowering,
-        tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
-        tuple(host_callbacks), backend, da, pmap_nreps,
+        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
+        allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
         compiler_options_keys, compiler_options_values)

     if hasattr(backend, "compile_replicated"):
@@ -2761,9 +2835,11 @@ class UnloadedMeshExecutable:
     else:
       if pmap_nreps == 1:
         assert mesh is None
         # TODO(yashkatariya): Make da directly usable in the downstream code
         # without tuple conversion.
-        out_shardings, are_out_shardings_from_xla = _get_shardings_from_executable(
+        if xla_extension_version >= 240:
+          in_shardings = _maybe_get_and_check_in_shardings(
+              xla_executable, in_shardings, tuple(da), global_in_avals,
+              len(ordered_effects))
+        out_shardings, are_out_shardings_from_xla = _get_out_shardings_from_executable(
             xla_executable, out_shardings, tuple(da), global_out_avals,
             len(ordered_effects), all_default_mem_kind)
       else:
@@ -247,14 +247,19 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
             # The argument
             (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]",
              count_in_P),
             (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
              count_in_replicated),
             # The result
             (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
              count_out_P),
         ])
+    # TODO(b/326476605): Change the condition below if required.
+    if in_shardings not in [None, "missing"] and out_shardings is not None:
+      self.check_sharding(
+          jax2tf.convert(f_jax), [x],
+          checks=[
+              (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
+               count_in_replicated),
+              (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated",
+               count_out_replicated),
+              # No other shardings
+              (r"custom_call_target.*Sharding",
+               count_in_P + count_in_replicated + count_out_P + count_out_replicated),
+          ])
@@ -437,10 +442,16 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
         checks=[
             # The input primal argument, and the output grad
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
            # The primal result, and the input cotangent
            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
        ])
+    # TODO(b/326476605): Change the condition below if required.
+    if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]:
+      self.check_sharding(f_grad_tf, [x, x.T],
+          checks=[
+              (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
+              # The primal result, and the input cotangent
+              (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
+              (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", count_out_replicated),
+          ])

   @jtu.parameterized_filterable(
@@ -916,7 +916,6 @@ class JaxExportTest(jtu.JaxTestCase):
                                  res_r.addressable_shards[i].data)

   @jtu.parameterized_filterable(
-    one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
     kwargs=[
       dict(in_shardings=in_shardings, out_shardings=out_shardings,
            with_mesh=with_mesh)
@@ -971,15 +970,17 @@ class JaxExportTest(jtu.JaxTestCase):
     else:
       primal_out_sharding = "{replicated}"

-    main = re.compile(
-        r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
-        r".*%arg1: tensor<20x10xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
-        # result
-        r".*->.*\(tensor<10x20xf32>.*"
-        "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
-    self.assertRegex(vjp_module_str, main)
+    # TODO(b/326476605): Change the condition below if required.
+    if in_shardings == "P":
+      main = re.compile(
+          r"func.func public @main\(%arg0: tensor<10x20xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\""
+          r".*%arg1: tensor<20x10xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_out_sharding) + "\""
+          # result
+          r".*->.*\(tensor<10x20xf32>.*"
+          "mhlo.sharding = \"" + re.escape(primal_in_sharding) + "\"")
+      self.assertRegex(vjp_module_str, main)

     # Custom calls for the primal input shape all match primal_in_sharding
     primal_in_calls = re.findall(
@@ -559,8 +559,6 @@ class PJitTest(jtu.BufferDonationTestCase):
     # Annotations from with_sharding_constraint
     self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
     self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
-    # Annotation from pjit
-    self.assertIn("sharding={replicated}", hlo.as_hlo_text())

   def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):
||||
@ -1718,19 +1716,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
input_shape = (8, 2)
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
input_data = np.arange(
|
||||
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
|
||||
with global_mesh:
|
||||
f = pjit(lambda x: x,
|
||||
out_shardings=NamedSharding(
|
||||
global_mesh, P('x', 'y')))
|
||||
# Since no in_axis_resources is provided, pjit will assume that
|
||||
# the numpy input is fully replicated over the mesh.
|
||||
out = f(input_data)
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
for s in out.addressable_shards:
|
||||
self.assertEqual(s.data.shape, (2, 1))
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
math.prod(input_shape)).reshape(input_shape)
|
||||
|
||||
f = pjit(lambda x: x,
|
||||
out_shardings=NamedSharding(global_mesh, P('x', 'y')))
|
||||
out = f(input_data)
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
self.assertArraysEqual(out, input_data)
|
||||
for s in out.addressable_shards:
|
||||
self.assertEqual(s.data.shape, (2, 1))
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
def test_numpy_array_input(self):
|
||||
input_shape = (8, 2)
|
||||
|