From b71829f882c56ff94829b63b2b61481426e80d4c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 20 May 2023 22:59:52 -0700 Subject: [PATCH] Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh). `with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD. PiperOrigin-RevId: 533801596 --- jax/_src/interpreters/pxla.py | 100 ++++++++++++++++------------------ jax/_src/pjit.py | 62 +++++++++------------ jax/_src/sharding_impls.py | 17 +++--- tests/pjit_test.py | 72 ++++++++++++++---------- 4 files changed, 124 insertions(+), 127 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 44fff1af8..a4f6b4d55 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -62,7 +62,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, - AUTOAxisResource, UnspecifiedValue, UNSPECIFIED, + AUTO, UnspecifiedValue, UNSPECIFIED, get_array_mapping as _get_array_mapping, is_auto, is_unspecified ) from jax._src.util import (unzip3, safe_map, safe_zip, partition_list, @@ -1693,7 +1693,7 @@ TilingMethod = Union[TileVectorize, TileManual] def check_if_any_auto( shardings: Iterable[Union[sharding_impls.XLACompatibleSharding, - AUTOAxisResource, UnspecifiedValue]]) -> bool: + AUTO, UnspecifiedValue]]) -> bool: for s in shardings: if is_auto(s): return True @@ -1755,8 +1755,7 @@ class DeviceAssignmentMismatchError(Exception): ShardingInfo = Tuple[ - Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, - AUTOAxisResource], + Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO], MismatchType, Optional[Any]] # Any is dispatch.SourceInfo to avoid circular imports @@ -1775,13 +1774,14 @@ def _get_and_check_device_assignment( devices = tuple(devices) for i, s_type, source_info in shardings: - if is_auto(i) or is_unspecified(i): + if is_unspecified(i): continue - # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been - # skipped. + if first_sharding_info is None: - first_sharding_info = (i._device_assignment, s_type, source_info) # type: ignore - arr_device_assignment = i._device_assignment # type: ignore + first_sharding_info = ( + (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore + else (i._device_assignment, s_type, source_info)) # type: ignore + arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore if not devices: if first_sharding_info[0] != arr_device_assignment: raise DeviceAssignmentMismatchError([ @@ -1815,7 +1815,7 @@ def cache_wrap(fn): @cache_wrap def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name, - keep_unused, donated_invars): + keep_unused, donated_invars, auto_spmd_lowering): name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) if isinstance(fun_or_jaxpr, lu.WrappedFun): @@ -1830,7 +1830,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name, global_out_avals = fun_or_jaxpr.out_avals consts = fun_or_jaxpr.consts - if (keep_unused or + if (keep_unused or auto_spmd_lowering or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape) for a in global_in_avals)): kept_var_idx = set(range(len(global_in_avals))) @@ -2006,10 +2006,14 @@ def lower_sharding_computation( the singleton UNSPECIFIED to all out_avals. """ # 1. 
Trace to jaxpr and preprocess/verify it + auto_spmd_lowering = ( + check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else + check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore + (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce( fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused, - donated_invars) + donated_invars, auto_spmd_lowering) jaxpr = closed_jaxpr.jaxpr in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) @@ -2091,14 +2095,13 @@ def lower_sharding_computation( module, False, donated_invars, - mesh=None, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, out_shardings=out_shardings, spmd_lowering=True, tuple_args=tuple_args, - auto_spmd_lowering=False, + auto_spmd_lowering=auto_spmd_lowering, unordered_effects=unordered_effects, ordered_effects=ordered_effects, host_callbacks=host_callbacks, @@ -2112,7 +2115,7 @@ def lower_sharding_computation( def _to_logical_sharding( - aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource] + aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO] ) -> Optional[sharding_impls.XLACompatibleSharding]: if is_unspecified(sharding) or is_auto(sharding): return None @@ -2131,9 +2134,9 @@ def lower_mesh_computation( api_name: str, fun_name: str, mesh: Mesh, - in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource]], - out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource, - UnspecifiedValue]], + in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]], + out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO, + UnspecifiedValue]], donated_invars: Sequence[bool], spmd_lowering: bool, global_in_avals: Sequence[core.ShapedArray], @@ -2143,11 +2146,6 @@ def lower_mesh_computation( backend = xb.get_device_backend(mesh.devices.flat[0]) name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) - auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings)) - - if auto_spmd_lowering and not spmd_lowering: - raise ValueError('Enable spmd_lowering to use auto spmd lowering.') - global_axis_sizes = mesh.shape log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG @@ -2171,7 +2169,6 @@ def lower_mesh_computation( else: raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}") assert not callable(out_shardings) - assert not auto_spmd_lowering assert isinstance(fun_or_jaxpr, lu.WrappedFun) # This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which # is why `.spec` can be accessed. @@ -2181,7 +2178,6 @@ def lower_mesh_computation( in_jaxpr_avals = global_in_avals else: assert isinstance(tiling_method, TileVectorize) - assert not auto_spmd_lowering # In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is # why `.spec` can be accessed. 
in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval) # type: ignore @@ -2274,14 +2270,13 @@ def lower_mesh_computation( lowering_result.module, False, donated_invars, - mesh=mesh, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, out_shardings=out_shardings, spmd_lowering=spmd_lowering, tuple_args=tuple_args, - auto_spmd_lowering=auto_spmd_lowering, + auto_spmd_lowering=False, unordered_effects=unordered_effects, ordered_effects=ordered_effects, host_callbacks=lowering_result.host_callbacks, @@ -2501,26 +2496,20 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, device_assignment = da.device_assignment if isinstance( da, _DeviceAssignment) else da - dev: np.ndarray - if auto_spmd_lowering: - assert mesh is not None and spmd_lowering - dev = mesh.devices - num_replicas, num_partitions = 1, mesh.size + # TODO(phawkins): One would normally just write: + # dev = np.array(device_assignment) + # The formulation below is substantially faster if there are many devices. + # If we were to optimize __getattr__ on xc.Device we might not need this + # workaround. + dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])( + np.arange(len(device_assignment)) + ) + if pmap_nreps > 1: + num_replicas, num_partitions = pmap_nreps, 1 + elif spmd_lowering: + num_replicas, num_partitions = 1, dev.size else: - # TODO(phawkins): One would normally just write: - # dev = np.array(device_assignment) - # The formulation below is substantially faster if there are many devices. - # If we were to optimize __getattr__ on xc.Device we might not need this - # workaround. - dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])( - np.arange(len(device_assignment)) - ) - if pmap_nreps > 1: - num_replicas, num_partitions = pmap_nreps, 1 - elif spmd_lowering: - num_replicas, num_partitions = 1, dev.size - else: - num_replicas, num_partitions = dev.size, 1 + num_replicas, num_partitions = dev.size, 1 if pmap_nreps > 1: # In `jit` device_assignment is set to None when num_replicas > 1. Do @@ -2610,14 +2599,11 @@ class UnloadedMeshExecutable: @staticmethod def from_hlo(name: str, hlo: ir.Module, - # TODO(yashkatariya): Remove `mesh` from here once AUTO can work - # without mesh. 
- mesh: Optional[Mesh], global_in_avals: Sequence[ShapedArray], global_out_avals: Sequence[ShapedArray], - in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource]], - out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource, - UnspecifiedValue]], + in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]], + out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO, + UnspecifiedValue]], spmd_lowering: bool, tuple_args: bool, auto_spmd_lowering: bool, @@ -2641,6 +2627,14 @@ class UnloadedMeshExecutable: device_assignment, _DeviceAssignment) else tuple(device_assignment) del device_assignment allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings) + + mesh = None + if auto_spmd_lowering: + for i in it.chain.from_iterable([in_shardings, out_shardings]): + if is_auto(i): + mesh = i.mesh # type: ignore + break + xla_executable, compile_options = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_outputs, @@ -2661,7 +2655,7 @@ class UnloadedMeshExecutable: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( xla_executable, mesh) - in_shardings = [x if is_auto(i) else i + in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i) # type: ignore for x, i in safe_zip(in_shardings_xla, in_shardings)] out_shardings_tuple = [ (x, True) if is_auto(o) else (o, False) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 84b9928a6..945c72420 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -54,7 +54,7 @@ from jax._src.lib import xla_client as xc from jax._src.sharding_impls import ( NamedSharding, XLACompatibleSharding, GSPMDSharding, XLADeviceAssignment, SingleDeviceSharding, PmapSharding, - AUTOAxisResource, UNSPECIFIED, UnspecifiedValue, + AUTO, UNSPECIFIED, UnspecifiedValue, ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding) from jax._src.traceback_util import api_boundary @@ -72,10 +72,10 @@ zip, unsafe_zip = safe_zip, zip traceback_util.register_exclusion(__file__) -PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource] -PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource] -MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource] -MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource] +PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO] +PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO] +MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO] +MeshShardingMinusUnspecified = Union[NamedSharding, AUTO] logger = logging.getLogger(__name__) @@ -342,13 +342,22 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames, donate_argnums) = infer_params_fn(*args, **kwargs) resource_env = params['resource_env'] mesh = None if resource_env is None else resource_env.physical_mesh - in_shardings = _resolve_in_shardings( - args_flat, params['in_shardings'], params['out_shardings'], mesh) - lowering = _pjit_lower( - params['jaxpr'], in_shardings, params['out_shardings'], - params['resource_env'], params['donated_invars'], params['name'], - params['keep_unused'], params['inline'], always_lower=True, - lowering_platform=_experimental_lowering_platform) + try: + in_shardings = _resolve_in_shardings( + args_flat, params['in_shardings'], params['out_shardings'], mesh) + lowering = _pjit_lower( + 
params['jaxpr'], in_shardings, params['out_shardings'], + params['resource_env'], params['donated_invars'], params['name'], + params['keep_unused'], params['inline'], always_lower=True, + lowering_platform=_experimental_lowering_platform) + except pxla.DeviceAssignmentMismatchError as e: + fails, = e.args + api_name = 'jit' if params['resource_env'] is None else 'pjit' + arg_names = _get_arg_names(fun, in_tree, args_flat) + fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) + msg = _device_assignment_mismatch_error( + fun_name, fails, args_flat, api_name, arg_names) + raise ValueError(msg) from None if kwargs: args_kwargs_in_tree = in_tree @@ -1210,29 +1219,9 @@ def _pjit_lower_cached( mesh = None api_name = 'jit' - # Convert to `NamedSharding` when `jax_array` is not enabled. This is - # because GDA/SDA/DA are dependent on mesh for generating outputs. - # NamedSharding is required for host-local inputs too. - any_auto = pxla.check_if_any_auto(it.chain(in_shardings, out_shardings)) - if any_auto: - in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast( # type:ignore[no-redef] - Tuple[MeshShardingMinusUnspecified, ...], tuple( - NamedSharding._from_parsed_pspec( - mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0]) # type: ignore - if isinstance(i, GSPMDSharding) else i - for i in in_shardings - )) - out_shardings: Tuple[MeshSharding, ...] = cast( # type: ignore[no-redef] - Tuple[MeshSharding, ...], tuple( - NamedSharding._from_parsed_pspec( - mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0]) # type: ignore - if isinstance(o, GSPMDSharding) else o - for o in out_shardings - )) - # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path # because `xmap` only supports SPMDAxisContext right now. - if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'): + if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'): return pxla.lower_mesh_computation( jaxpr, api_name, name, mesh, in_shardings, out_shardings, donated_invars, @@ -1929,10 +1918,11 @@ def _fast_path_get_device_assignment( shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]: da = None for i in shardings: - if is_auto(i) or is_unspecified(i): + if is_unspecified(i): continue - da = i._device_assignment # type: ignore - break + if is_auto(i): + return i.mesh._flat_devices_tuple # type: ignore + return i._device_assignment # type: ignore return da diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9b32880ed..5aba33de8 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -719,12 +719,14 @@ class GSPMDSharding(XLACompatibleSharding): return cls(tuple(device_assignment), proto) -class AUTOAxisResource: - pass -AUTO = AUTOAxisResource() +class AUTO: + + def __init__(self, mesh: mesh_lib.Mesh): + self.mesh = mesh + def is_auto(x): - return isinstance(x, AUTOAxisResource) + return isinstance(x, AUTO) class UnspecifiedValue: @@ -757,8 +759,7 @@ mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, the mesh devices ndarray would have to be transposed before flattening and assignment. 
""" ArrayMapping = OrderedDictType[MeshAxisName, int] -ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource, - UnspecifiedValue] +ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue] def array_mapping_to_axis_resources(array_mapping: ArrayMapping): if not array_mapping: @@ -779,11 +780,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): return PartitionSpec(*partitions) def get_array_mapping( - axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue] + axis_resources: Union[ParsedPartitionSpec, AUTO, UnspecifiedValue] ) -> ArrayMappingOrAutoOrUnspecified: # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. # Don't use `is_auto` here to satisfy pytype and mypy. - if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)): + if isinstance(axis_resources, (AUTO, UnspecifiedValue)): return axis_resources return OrderedDict((axis, i) for i, axes in enumerate(axis_resources) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 765c95dd5..3daa8b146 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1335,16 +1335,16 @@ class AutoShardingPjitTest(jtu.JaxTestCase): input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) - with global_mesh: - f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO) + f = jax.jit(lambda x: x, in_shardings=AUTO(global_mesh), + out_shardings=AUTO(global_mesh)) - inp = core.ShapedArray(input_data.shape, input_data.dtype) - compiled = f.lower(inp).compile() - inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0] - for ip in compiled.input_shardings[0]] - out = compiled(*inputs) - self.assertIsInstance(out, array.ArrayImpl) - self.assertArraysEqual(out._value, input_data) + inp = core.ShapedArray(input_data.shape, input_data.dtype) + compiled = f.lower(inp).compile() + inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0] + for ip in compiled.input_shardings[0]] + out = compiled(*inputs) + self.assertIsInstance(out, array.ArrayImpl) + self.assertArraysEqual(out._value, input_data) def test_xla_arr_sharding_mismatch(self): if xla_bridge.get_backend().runtime_type == 'stream_executor': @@ -1355,7 +1355,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase): math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with global_mesh: - f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO) + f = pjit(lambda x: x, in_shardings=AUTO(global_mesh), + out_shardings=AUTO(global_mesh)) inp = core.ShapedArray(input_data.shape, input_data.dtype) compiled = f.lower(inp).compile() @@ -1379,7 +1380,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase): math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with global_mesh: - f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO) + f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO(global_mesh), + out_shardings=AUTO(global_mesh)) inp = core.ShapedArray(input_data.shape, input_data.dtype) compiled = f.lower(inp, inp, inp).compile() self.assertLen(compiled.output_shardings, 3) @@ -1390,32 +1392,41 @@ class AutoShardingPjitTest(jtu.JaxTestCase): ('2d_array', (4, 2), ('x', 'y'), P('y', 'x')), ('1d_array', (8,), ('x'), P('x')), ) - def test_pjit_arr_partial_auto_sharding_array( + def test_jit_arr_partial_auto_sharding_array( self, mesh_shape, mesh_axis_names, pspec): if xla_bridge.get_backend().runtime_type == 'stream_executor': raise unittest.SkipTest('AutoSharding 
is not supported on stream_executor yet.') - global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + inp_s = NamedSharding(mesh, pspec) + f = jax.jit( + lambda x, y: (x, y), + in_shardings=(inp_s, AUTO(mesh)), + out_shardings=AUTO(mesh)) - in_resource = NamedSharding(global_mesh, pspec) + inp = core.ShapedArray(input_data.shape, input_data.dtype) + compiled = f.lower(inp, inp).compile() + inputs = [create_array(global_input_shape, mesh, ip, input_data)[0] + for ip in compiled.input_shardings[0]] + self.assertEqual(compiled.input_shardings[0][0], inp_s) + out1, out2 = compiled(*inputs) + for o in [out1, out2]: + self.assertIsInstance(o, array.ArrayImpl) + self.assertArraysEqual(o._value, input_data) - with global_mesh: - f = pjit( - lambda x, y: (x, y), - in_shardings=(in_resource, AUTO), - out_shardings=AUTO, - ) - - inp = core.ShapedArray(input_data.shape, input_data.dtype) - compiled = f.lower(inp, inp).compile() - inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0] - for ip in compiled.input_shardings[0]] - out1, out2 = compiled(*inputs) - for o in [out1, out2]: - self.assertIsInstance(o, array.ArrayImpl) - self.assertArraysEqual(o._value, input_data) + def test_jit_different_mesh_in_auto(self): + mesh1 = jtu.create_global_mesh((4,), ('x',)) + dev = jax.devices() + mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x') + f = jax.jit(lambda x, y: (x, y), + in_shardings=(NamedSharding(mesh2, P('x')), AUTO(mesh1))) + inp = core.ShapedArray((8, 2), np.float32) + with self.assertRaisesRegex( + ValueError, + "Received incompatible devices for jitted computation"): + f.lower(inp, inp).compile() @unittest.skip('The error is not raised yet. Enable this back once we raise ' 'the error in pjit again.') @@ -1429,7 +1440,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase): math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with global_mesh: - f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO) + f = pjit(lambda x: x, in_shardings=AUTO(global_mesh), + out_shardings=AUTO(global_mesh)) inp = core.ShapedArray(input_data.shape, input_data.dtype) compiled = f.lower(inp).compile()
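
For readers of this patch, the following is a minimal sketch of the usage the change enables, modeled on the updated tests above. It is not part of the patch: the `AUTO` import path, the mesh layout, and the example function are assumptions, and compiling with auto sharding still requires a backend that supports XLA's auto SPMD partitioning (the tests skip it on `stream_executor`).

```python
# Hedged sketch of the new AUTO(mesh) usage with jax.jit; import location,
# mesh shape, and the example function are assumptions, not from the patch.
import numpy as np
import jax
from jax.sharding import Mesh
from jax.experimental.pjit import AUTO  # assumed import location

# Build a 1-D mesh over all available devices (layout is an assumption).
mesh = Mesh(np.array(jax.devices()), ('x',))

# New API: AUTO takes the mesh directly; no `with mesh:` context is needed.
f = jax.jit(lambda x: x * 2,
            in_shardings=AUTO(mesh),
            out_shardings=AUTO(mesh))

# Lower and compile against abstract inputs, then inspect the shardings
# that the auto SPMD pass chose for the arguments.
x = jax.ShapeDtypeStruct((8, 4), np.float32)
compiled = f.lower(x).compile()
print(compiled.input_shardings[0])
```

As in the updated tests, the compiler-chosen `input_shardings` would then be used to place the concrete input arrays before calling `compiled` directly.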