Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).

`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
This commit is contained in:
Yash Katariya 2023-05-20 22:59:52 -07:00 committed by jax authors
parent e0b5003880
commit b71829f882
4 changed files with 124 additions and 127 deletions

View File

@ -62,7 +62,7 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTOAxisResource, UnspecifiedValue, UNSPECIFIED,
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
@ -1693,7 +1693,7 @@ TilingMethod = Union[TileVectorize, TileManual]
def check_if_any_auto(
shardings: Iterable[Union[sharding_impls.XLACompatibleSharding,
AUTOAxisResource, UnspecifiedValue]]) -> bool:
AUTO, UnspecifiedValue]]) -> bool:
for s in shardings:
if is_auto(s):
return True
@ -1755,8 +1755,7 @@ class DeviceAssignmentMismatchError(Exception):
ShardingInfo = Tuple[
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue,
AUTOAxisResource],
Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports
@ -1775,13 +1774,14 @@ def _get_and_check_device_assignment(
devices = tuple(devices)
for i, s_type, source_info in shardings:
if is_auto(i) or is_unspecified(i):
if is_unspecified(i):
continue
# Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
# skipped.
if first_sharding_info is None:
first_sharding_info = (i._device_assignment, s_type, source_info) # type: ignore
arr_device_assignment = i._device_assignment # type: ignore
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore
else (i._device_assignment, s_type, source_info)) # type: ignore
arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
@ -1815,7 +1815,7 @@ def cache_wrap(fn):
@cache_wrap
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
keep_unused, donated_invars):
keep_unused, donated_invars, auto_spmd_lowering):
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
if isinstance(fun_or_jaxpr, lu.WrappedFun):
@ -1830,7 +1830,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts
if (keep_unused or
if (keep_unused or auto_spmd_lowering or
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
@ -2006,10 +2006,14 @@ def lower_sharding_computation(
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
auto_spmd_lowering = (
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars)
donated_invars, auto_spmd_lowering)
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
@ -2091,14 +2095,13 @@ def lower_sharding_computation(
module,
False,
donated_invars,
mesh=None,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=True,
tuple_args=tuple_args,
auto_spmd_lowering=False,
auto_spmd_lowering=auto_spmd_lowering,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=host_callbacks,
@ -2112,7 +2115,7 @@ def lower_sharding_computation(
def _to_logical_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO]
) -> Optional[sharding_impls.XLACompatibleSharding]:
if is_unspecified(sharding) or is_auto(sharding):
return None
@ -2131,8 +2134,8 @@ def lower_mesh_computation(
api_name: str,
fun_name: str,
mesh: Mesh,
in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource,
in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO,
UnspecifiedValue]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
@ -2143,11 +2146,6 @@ def lower_mesh_computation(
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings))
if auto_spmd_lowering and not spmd_lowering:
raise ValueError('Enable spmd_lowering to use auto spmd lowering.')
global_axis_sizes = mesh.shape
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@ -2171,7 +2169,6 @@ def lower_mesh_computation(
else:
raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
assert not callable(out_shardings)
assert not auto_spmd_lowering
assert isinstance(fun_or_jaxpr, lu.WrappedFun)
# This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
# is why `.spec` can be accessed.
@ -2181,7 +2178,6 @@ def lower_mesh_computation(
in_jaxpr_avals = global_in_avals
else:
assert isinstance(tiling_method, TileVectorize)
assert not auto_spmd_lowering
# In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
# why `.spec` can be accessed.
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore
@ -2274,14 +2270,13 @@ def lower_mesh_computation(
lowering_result.module,
False,
donated_invars,
mesh=mesh,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
out_shardings=out_shardings,
spmd_lowering=spmd_lowering,
tuple_args=tuple_args,
auto_spmd_lowering=auto_spmd_lowering,
auto_spmd_lowering=False,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
host_callbacks=lowering_result.host_callbacks,
@ -2501,12 +2496,6 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
else:
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
@ -2610,13 +2599,10 @@ class UnloadedMeshExecutable:
@staticmethod
def from_hlo(name: str,
hlo: ir.Module,
# TODO(yashkatariya): Remove `mesh` from here once AUTO can work
# without mesh.
mesh: Optional[Mesh],
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource]],
out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource,
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]],
out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO,
UnspecifiedValue]],
spmd_lowering: bool,
tuple_args: bool,
@ -2641,6 +2627,14 @@ class UnloadedMeshExecutable:
device_assignment, _DeviceAssignment) else tuple(device_assignment)
del device_assignment
allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
mesh = None
if auto_spmd_lowering:
for i in it.chain.from_iterable([in_shardings, out_shardings]):
if is_auto(i):
mesh = i.mesh # type: ignore
break
xla_executable, compile_options = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
@ -2661,7 +2655,7 @@ class UnloadedMeshExecutable:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings_tuple = [
(x, True) if is_auto(o) else (o, False)

View File

@ -54,7 +54,7 @@ from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
AUTOAxisResource, UNSPECIFIED, UnspecifiedValue,
AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.traceback_util import api_boundary
@ -72,10 +72,10 @@ zip, unsafe_zip = safe_zip, zip
traceback_util.register_exclusion(__file__)
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]
logger = logging.getLogger(__name__)
@ -342,6 +342,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums) = infer_params_fn(*args, **kwargs)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
in_shardings = _resolve_in_shardings(
args_flat, params['in_shardings'], params['out_shardings'], mesh)
lowering = _pjit_lower(
@ -349,6 +350,14 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
params['resource_env'], params['donated_invars'], params['name'],
params['keep_unused'], params['inline'], always_lower=True,
lowering_platform=_experimental_lowering_platform)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
arg_names = _get_arg_names(fun, in_tree, args_flat)
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None
if kwargs:
args_kwargs_in_tree = in_tree
@ -1210,29 +1219,9 @@ def _pjit_lower_cached(
mesh = None
api_name = 'jit'
# Convert to `NamedSharding` when `jax_array` is not enabled. This is
# because GDA/SDA/DA are dependent on mesh for generating outputs.
# NamedSharding is required for host-local inputs too.
any_auto = pxla.check_if_any_auto(it.chain(in_shardings, out_shardings))
if any_auto:
in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast( # type:ignore[no-redef]
Tuple[MeshShardingMinusUnspecified, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0]) # type: ignore
if isinstance(i, GSPMDSharding) else i
for i in in_shardings
))
out_shardings: Tuple[MeshSharding, ...] = cast( # type: ignore[no-redef]
Tuple[MeshSharding, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0]) # type: ignore
if isinstance(o, GSPMDSharding) else o
for o in out_shardings
))
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
return pxla.lower_mesh_computation(
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
@ -1929,10 +1918,11 @@ def _fast_path_get_device_assignment(
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
da = None
for i in shardings:
if is_auto(i) or is_unspecified(i):
if is_unspecified(i):
continue
da = i._device_assignment # type: ignore
break
if is_auto(i):
return i.mesh._flat_devices_tuple # type: ignore
return i._device_assignment # type: ignore
return da

View File

@ -719,12 +719,14 @@ class GSPMDSharding(XLACompatibleSharding):
return cls(tuple(device_assignment), proto)
class AUTOAxisResource:
pass
AUTO = AUTOAxisResource()
class AUTO:
def __init__(self, mesh: mesh_lib.Mesh):
self.mesh = mesh
def is_auto(x):
return isinstance(x, AUTOAxisResource)
return isinstance(x, AUTO)
class UnspecifiedValue:
@ -757,8 +759,7 @@ mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
UnspecifiedValue]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
@ -779,11 +780,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
return PartitionSpec(*partitions)
def get_array_mapping(
axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue]
axis_resources: Union[ParsedPartitionSpec, AUTO, UnspecifiedValue]
) -> ArrayMappingOrAutoOrUnspecified:
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
# Don't use `is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)):
if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)

View File

@ -1335,8 +1335,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
input_data = np.arange(
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with global_mesh:
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
f = jax.jit(lambda x: x, in_shardings=AUTO(global_mesh),
out_shardings=AUTO(global_mesh))
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
@ -1355,7 +1355,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with global_mesh:
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
out_shardings=AUTO(global_mesh))
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()
@ -1379,7 +1380,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with global_mesh:
f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO)
f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO(global_mesh),
out_shardings=AUTO(global_mesh))
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp, inp).compile()
self.assertLen(compiled.output_shardings, 3)
@ -1390,33 +1392,42 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
('2d_array', (4, 2), ('x', 'y'), P('y', 'x')),
('1d_array', (8,), ('x'), P('x')),
)
def test_pjit_arr_partial_auto_sharding_array(
def test_jit_arr_partial_auto_sharding_array(
self, mesh_shape, mesh_axis_names, pspec):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
in_resource = NamedSharding(global_mesh, pspec)
with global_mesh:
f = pjit(
inp_s = NamedSharding(mesh, pspec)
f = jax.jit(
lambda x, y: (x, y),
in_shardings=(in_resource, AUTO),
out_shardings=AUTO,
)
in_shardings=(inp_s, AUTO(mesh)),
out_shardings=AUTO(mesh))
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp, inp).compile()
inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
inputs = [create_array(global_input_shape, mesh, ip, input_data)[0]
for ip in compiled.input_shardings[0]]
self.assertEqual(compiled.input_shardings[0][0], inp_s)
out1, out2 = compiled(*inputs)
for o in [out1, out2]:
self.assertIsInstance(o, array.ArrayImpl)
self.assertArraysEqual(o._value, input_data)
def test_jit_different_mesh_in_auto(self):
mesh1 = jtu.create_global_mesh((4,), ('x',))
dev = jax.devices()
mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x')
f = jax.jit(lambda x, y: (x, y),
in_shardings=(NamedSharding(mesh2, P('x')), AUTO(mesh1)))
inp = core.ShapedArray((8, 2), np.float32)
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for jitted computation"):
f.lower(inp, inp).compile()
@unittest.skip('The error is not raised yet. Enable this back once we raise '
'the error in pjit again.')
def test_pjit_array_error(self):
@ -1429,7 +1440,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with global_mesh:
f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
out_shardings=AUTO(global_mesh))
inp = core.ShapedArray(input_data.shape, input_data.dtype)
compiled = f.lower(inp).compile()