diff --git a/jax/_src/config.py b/jax/_src/config.py
index 0a71c0b5e..f122ffb05 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -784,7 +784,7 @@ def _update_jax_memories_thread_local(val):
 
 enable_memories = config.define_bool_state(
     'jax_enable_memories',
-    default=False,
+    default=True,
     upgrade=True,
     update_global_hook=_update_jax_memories_global,
     update_thread_local_hook=_update_jax_memories_thread_local,
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index d9d064c77..e7c39b953 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -529,7 +529,7 @@ def device_put_transpose_rule(ct, _, device, src):
 ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)
 
-def _device_put_lowering(ctx, x, *, device, src):
+def _tpu_device_put_lowering(ctx, x, *, device, src):
   if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
@@ -540,4 +540,14 @@ def _device_put_lowering(ctx, x, *, device, src):
         ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
-mlir.register_lowering(device_put_p, _device_put_lowering)
+mlir.register_lowering(device_put_p, _tpu_device_put_lowering, platform='tpu')
+
+
+def _common_device_put_lowering(ctx, x, *, device, src):
+  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+      device.memory_kind is not None):
+    raise NotImplementedError(
+        "Passing memory_kind to device_put via Shardings is not supported on"
+        f" platform {ctx.module_context.platform}")
+  return [x]
+mlir.register_lowering(device_put_p, _common_device_put_lowering)
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index ad561d5e4..ddc7770c0 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -676,6 +676,7 @@ def lower_jaxpr_to_module(
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
     num_partitions: int = 1,
+    all_default_mem_kind: bool = True,
     override_lowering_rules: None | (
         tuple[tuple[core.Primitive, LoweringRule]]) = None,
 ) -> LoweringResult:
@@ -701,10 +702,14 @@ def lower_jaxpr_to_module(
       map(sharded_aval, jaxpr.in_avals, arg_shardings))
   out_avals = (jaxpr.out_avals if result_shardings is None else
                map(sharded_aval, jaxpr.out_avals, result_shardings))
-  arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
-                      if arg_shardings is not None else None)
-  result_memory_kinds = (map(_get_mem_kind, result_shardings)
-                         if result_shardings is not None else None)
+  if all_default_mem_kind:
+    arg_memory_kinds = None
+    result_memory_kinds = None
+  else:
+    arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
+                        if arg_shardings is not None else None)
+    result_memory_kinds = (map(_get_mem_kind, result_shardings)
+                           if result_shardings is not None else None)
 
   if platform in _platforms_with_donation:
     input_output_aliases, donated_args = _set_up_aliases(
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index a251efaaf..14dfe46a8 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -66,7 +66,8 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
     UNSPECIFIED,
-    get_array_mapping as _get_array_mapping, is_auto, is_unspecified
+    get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
+    is_unspecified_or_auto
 )
 from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
                            tuple_delete, distributed_debug_log,
@@ -1783,7 +1784,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
                             da_object, lowering_platform,
-                            donated_invars, name_stack, override_lowering_rules):
+                            donated_invars, name_stack, all_default_mem_kind,
+                            override_lowering_rules):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1855,6 +1857,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=nreps,
       num_partitions=num_partitions,
+      all_default_mem_kind=all_default_mem_kind,
       override_lowering_rules=override_lowering_rules)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
@@ -1925,15 +1928,27 @@ def _create_da_object(  # pytype: disable=invalid-annotation
   return _DeviceAssignment(device_assignment)
 
 
-def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
+def jaxpr_transfer_mem_kinds(
+    jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
   for eqn in jaxpr.eqns:
     if (eqn.primitive is dispatch.device_put_p and
         isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
-      return True
+      yield eqn.params['device']
   for subjaxpr in core.subjaxprs(jaxpr):
-    if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
-      return True
-  return False
+    yield from jaxpr_transfer_mem_kinds(subjaxpr)
+
+
+def are_all_shardings_default_mem_kind(da_object, shardings):
+  try:
+    default_mem_kind = da_object.default_memory_kind
+  except:
+    return True
+  for i in shardings:
+    if is_unspecified_or_auto(i):
+      continue
+    if i.memory_kind != default_mem_kind:
+      return False
+  return True
 
 
 @profiler.annotate_function
@@ -1988,19 +2003,26 @@ def lower_sharding_computation(
               for js, source_info in jaxpr_sharding]),
      devices_from_context)
 
+  transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
+
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings) or
-      jaxpr_has_dp_with_transfer_mem_kind(jaxpr))
+      transfer_mem_kind_in_jaxpr)
 
   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
   da_object = _create_da_object(tuple(device_assignment))
 
+  all_default_mem_kind = are_all_shardings_default_mem_kind(
+      da_object,
+      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding],  # type: ignore
+               transfer_mem_kind_in_jaxpr))
+
   if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
@@ -2022,7 +2044,7 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, da_object, lowering_platform,
-      donated_invars, name_stack, override_lowering_rules)
+      donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)
 
   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2050,7 +2072,8 @@ def lower_sharding_computation(
       committed=committed,
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-      shape_poly_state=shape_poly_state)
+      shape_poly_state=shape_poly_state,
+      all_default_mem_kind=all_default_mem_kind)
 
 
 def _to_logical_sharding(
@@ -2295,18 +2318,25 @@ def _get_input_indices(
 
 
 def get_gspmd_shardings_from_executable(
-    xla_executable, device_assignment: Sequence[xc.Device],
-    num_out_avals: int
+    xla_executable,
+    device_assignment: Sequence[xc.Device],
+    num_out_avals: int,
+    num_ordered_effects: int,
+    all_default_mem_kind: bool,
 ) -> Sequence[sharding_impls.XLACompatibleSharding]:
   from jax._src import pjit
 
-  if config.jax_enable_memories:
+  if all_default_mem_kind:
+    omk = [None] * num_out_avals
+  else:
     try:
       omk = xla_executable.get_output_memory_kinds()[0]
+      if num_ordered_effects > 0:
+        omk = omk[num_ordered_effects:]
     except:
       omk = [None] * num_out_avals
-  else:
-    omk = [None] * num_out_avals
+
+  assert len(omk) == num_out_avals, (len(omk), num_out_avals)
 
   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`. In that case,
@@ -2560,7 +2590,8 @@ class UnloadedMeshExecutable:
       pmap_nreps: int = 1,
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
-      compiler_options=None
+      all_default_mem_kind: bool = True,
+      compiler_options=None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
@@ -2616,7 +2647,8 @@ class UnloadedMeshExecutable:
     # without tuple conversion.
     device_assignment = tuple(da)
     out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-        xla_executable, device_assignment, len(global_out_avals))  # type: ignore
+        xla_executable, device_assignment, len(global_out_avals),
+        len(ordered_effects), all_default_mem_kind)  # type: ignore
     orig_out_shardings = out_shardings
     out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
     for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py
index 62d01d044..259b69aa7 100644
--- a/jax/experimental/export/export.py
+++ b/jax/experimental/export/export.py
@@ -653,7 +653,8 @@ def _check_lowering(lowering) -> None:
       "spmd_lowering", "auto_spmd_lowering", "tuple_args",
       "ordered_effects", "unordered_effects", "keepalive", "host_callbacks",
       "pmap_nreps", "committed",
-      "device_assignment", "jaxpr_debug_info", "shape_poly_state"]
+      "device_assignment", "jaxpr_debug_info", "shape_poly_state",
+      "all_default_mem_kind"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
diff --git a/tests/api_test.py b/tests/api_test.py
index 1e43c807e..f131aa9e3 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -10278,7 +10278,7 @@ class OverrideLoweringTest(jtu.JaxTestCase):
         jax.jit(f)
         .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
                _experimental_override_lowering_rules=rules).as_text())
-    self.assertNotIn("stablehlo.custom_call", lowered_ir)
+    self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
 
 
 if __name__ == '__main__':
diff --git a/tests/array_test.py b/tests/array_test.py
index 972277317..d5685045f 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -868,7 +868,7 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
        "valid for values of rank at least 4, but was applied to a value of rank 2"):
       new_mps.is_compatible_aval(shape)
 
@@ -884,15 +884,16 @@ class ShardingTest(jtu.JaxTestCase):
     op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     op.replicate_on_last_tile_dim = True
     s = jax.sharding.GSPMDSharding(jax.devices(), op)
-    self.assertEqual(
-        repr(s),
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn(
         'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
-        'last_tile_dim_replicate})')
+        'last_tile_dim_replicate}', repr(s))
 
     op2 = xc.OpSharding()
     op2.type = xc.OpSharding.Type.REPLICATED
     s2 = jax.sharding.GSPMDSharding(jax.devices(), op2)
-    self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn('GSPMDSharding({replicated}', repr(s2))
 
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y"), (4, 2), (), False),
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 748868d9f..c4b391117 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1226,7 +1226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
         ValueError,
         r"One of with_sharding_constraint.*Sharding "
         r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
        "valid for values of rank at least 4, but was applied to a value of rank 1"):
       pjit_f(jnp.array([1, 2, 3]))
 
@@ -3658,6 +3658,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
         ' manager.*SingleDeviceSharding'):
       jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)
 
+  @jtu.skip_on_devices("tpu")
+  def test_device_put_memory_kind_not_tpu(self):
+    @jax.jit
+    def f(x):
+      y = x * 2
+      return jax.device_put(
+          y, sharding_impls.TransferToMemoryKind('unpinned_host'))
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        'Passing memory_kind to device_put via Shardings is not supported on'
+        ' platform.*'):
+      f(jnp.arange(8))
+
 
 class TempSharding(Sharding):
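
Usage sketch (not part of the patch): the snippet below mirrors the new test_device_put_memory_kind_not_tpu case above and shows the behaviour the split device_put lowering introduces on non-TPU backends. The import of TransferToMemoryKind from jax._src.sharding_impls is an internal path taken from the test, not a public API.

# Illustrative sketch only: on a CPU/GPU backend, requesting a memory kind via
# device_put inside jit now reaches _common_device_put_lowering and raises;
# on TPU the memory-kind-aware lowering is registered instead.
import jax
import jax.numpy as jnp
from jax._src import sharding_impls  # internal import, as used in the test above

@jax.jit
def f(x):
  y = x * 2
  # Request host-memory placement for the result.
  return jax.device_put(
      y, sharding_impls.TransferToMemoryKind('unpinned_host'))

try:
  f(jnp.arange(8))
except NotImplementedError as e:
  # Expected on non-TPU platforms after this change:
  # "Passing memory_kind to device_put via Shardings is not supported on platform ..."
  print(e)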