Set the default of the jax_enable_memories flag to True.

If all memory kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows existing code to work without any changes.

If a non-default memory kind is present in the jaxpr, then the custom calls are inserted.
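
For context, a minimal sketch (not part of this change) of what a non-default memory kind looks like at the API level. It assumes a backend, such as TPU, that exposes an 'unpinned_host' memory space alongside the default device memory, and that the sharding supports with_memory_kind:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), ('x',))
    s_default = NamedSharding(mesh, P())                    # default memory kind
    s_host = s_default.with_memory_kind('unpinned_host')    # non-default memory kind (assumed available)

    x = jax.device_put(jnp.arange(8), s_default)  # lowering stays as before
    y = jax.device_put(jnp.arange(8), s_host)     # lowering annotates the placement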

PiperOrigin-RevId: 564457393
Author: Yash Katariya
Date: 2023-09-11 11:54:29 -07:00 (committed by jax authors)
Parent: bfc12bdda9
Commit: a36598b2a7
8 changed files with 94 additions and 32 deletions


@@ -784,7 +784,7 @@ def _update_jax_memories_thread_local(val):
 enable_memories = config.define_bool_state(
     'jax_enable_memories',
-    default=False,
+    default=True,
     upgrade=True,
     update_global_hook=_update_jax_memories_global,
     update_thread_local_hook=_update_jax_memories_thread_local,


@@ -529,7 +529,7 @@ def device_put_transpose_rule(ct, _, device, src):
 ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)

-def _device_put_lowering(ctx, x, *, device, src):
+def _tpu_device_put_lowering(ctx, x, *, device, src):
   if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
@@ -540,4 +540,14 @@ def _device_put_lowering(ctx, x, *, device, src):
           ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
-mlir.register_lowering(device_put_p, _device_put_lowering)
+mlir.register_lowering(device_put_p, _tpu_device_put_lowering, platform='tpu')
+
+
+def _common_device_put_lowering(ctx, x, *, device, src):
+  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+      device.memory_kind is not None):
+    raise NotImplementedError(
+        "Passing memory_kind to device_put via Shardings is not supported on"
+        f" platform {ctx.module_context.platform}")
+  return [x]
+mlir.register_lowering(device_put_p, _common_device_put_lowering)
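
The practical effect of the split registration above: the memory-kind-aware rule only runs on TPU, while every other platform falls through to the common rule and reports the limitation. An illustrative snippet of the non-TPU behaviour (mirroring the new test added at the end of this commit; TransferToMemoryKind and the 'unpinned_host' name are taken from the diff):

    import jax
    import jax.numpy as jnp
    from jax._src import sharding_impls

    @jax.jit
    def f(x):
      return jax.device_put(x * 2, sharding_impls.TransferToMemoryKind('unpinned_host'))

    # On a CPU or GPU backend this is expected to raise:
    #   NotImplementedError: Passing memory_kind to device_put via Shardings is
    #   not supported on platform cpu
    try:
      f(jnp.arange(8))
    except NotImplementedError as e:
      print(e)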


@@ -676,6 +676,7 @@ def lower_jaxpr_to_module(
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
     num_partitions: int = 1,
+    all_default_mem_kind: bool = True,
     override_lowering_rules: None | (
         tuple[tuple[core.Primitive, LoweringRule]]) = None,
 ) -> LoweringResult:
@@ -701,10 +702,14 @@ def lower_jaxpr_to_module(
               map(sharded_aval, jaxpr.in_avals, arg_shardings))
   out_avals = (jaxpr.out_avals if result_shardings is None else
                map(sharded_aval, jaxpr.out_avals, result_shardings))
-  arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
-                      if arg_shardings is not None else None)
-  result_memory_kinds = (map(_get_mem_kind, result_shardings)
-                         if result_shardings is not None else None)
+  if all_default_mem_kind:
+    arg_memory_kinds = None
+    result_memory_kinds = None
+  else:
+    arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
+                        if arg_shardings is not None else None)
+    result_memory_kinds = (map(_get_mem_kind, result_shardings)
+                           if result_shardings is not None else None)

   if platform in _platforms_with_donation:
     input_output_aliases, donated_args = _set_up_aliases(
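
One way to see the effect of this gate (illustrative, not part of the diff): when every sharding uses the default memory kind, the memory-kind lists stay None, so no annotate_device_placement custom call is emitted and the text of an ordinary jit lowering is unchanged even with jax_enable_memories on by default:

    import jax
    import jax.numpy as jnp

    def f(x):
      return x * 2

    lowered_text = jax.jit(f).lower(jnp.arange(8)).as_text()
    # With all-default memory kinds, the lowering contains no placement annotation.
    assert 'annotate_device_placement' not in lowered_text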


@@ -66,7 +66,8 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
     AUTO, UnspecifiedValue, UNSPECIFIED,
-    get_array_mapping as _get_array_mapping, is_auto, is_unspecified
+    get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
+    is_unspecified_or_auto
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
@@ -1783,7 +1784,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
                             da_object, lowering_platform,
-                            donated_invars, name_stack, override_lowering_rules):
+                            donated_invars, name_stack, all_default_mem_kind,
+                            override_lowering_rules):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1855,6 +1857,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
         num_replicas=nreps,
         num_partitions=num_partitions,
+        all_default_mem_kind=all_default_mem_kind,
         override_lowering_rules=override_lowering_rules)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
@@ -1925,15 +1928,27 @@ def _create_da_object(  # pytype: disable=invalid-annotation
   return _DeviceAssignment(device_assignment)

-def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
+def jaxpr_transfer_mem_kinds(
+    jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
   for eqn in jaxpr.eqns:
     if (eqn.primitive is dispatch.device_put_p and
         isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
-      return True
+      yield eqn.params['device']
   for subjaxpr in core.subjaxprs(jaxpr):
-    if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
-      return True
-  return False
+    yield from jaxpr_transfer_mem_kinds(subjaxpr)
+
+
+def are_all_shardings_default_mem_kind(da_object, shardings):
+  try:
+    default_mem_kind = da_object.default_memory_kind
+  except:
+    return True
+  for i in shardings:
+    if is_unspecified_or_auto(i):
+      continue
+    if i.memory_kind != default_mem_kind:
+      return False
+  return True

 @profiler.annotate_function
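
Roughly, the two new helpers answer "does this computation touch a non-default memory space anywhere?". A hypothetical standalone sketch of the same check (the real helpers operate on jaxpr equations and the device assignment's default_memory_kind; the function and names below are illustrative only):

    from typing import Iterable, Optional

    def all_default_mem_kind_sketch(memory_kinds: Iterable[Optional[str]],
                                    default_mem_kind: Optional[str]) -> bool:
      # A sharding with no explicit memory kind counts as default; anything else
      # must match the device assignment's default memory kind.
      return all(mk is None or mk == default_mem_kind for mk in memory_kinds)

    # e.g. all_default_mem_kind_sketch(['device', None], 'device')  -> True
    #      all_default_mem_kind_sketch(['unpinned_host'], 'device') -> False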
@@ -1988,19 +2003,26 @@ def lower_sharding_computation(
                 for js, source_info in jaxpr_sharding]),
       devices_from_context)

+  transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
+
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings) or
-      jaxpr_has_dp_with_transfer_mem_kind(jaxpr))
+      transfer_mem_kind_in_jaxpr)

   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

   da_object = _create_da_object(tuple(device_assignment))

+  all_default_mem_kind = are_all_shardings_default_mem_kind(
+      da_object,
+      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding],  # type: ignore
+               transfer_mem_kind_in_jaxpr))
+
   if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
@@ -2022,7 +2044,7 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
       semantic_out_shardings, da_object, lowering_platform,
-      donated_invars, name_stack, override_lowering_rules)
+      donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2050,7 +2072,8 @@ def lower_sharding_computation(
       committed=committed,
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-      shape_poly_state=shape_poly_state)
+      shape_poly_state=shape_poly_state,
+      all_default_mem_kind=all_default_mem_kind)


 def _to_logical_sharding(
@@ -2295,18 +2318,25 @@ def _get_input_indices(

 def get_gspmd_shardings_from_executable(
-    xla_executable, device_assignment: Sequence[xc.Device],
-    num_out_avals: int
+    xla_executable,
+    device_assignment: Sequence[xc.Device],
+    num_out_avals: int,
+    num_ordered_effects: int,
+    all_default_mem_kind: bool,
 ) -> Sequence[sharding_impls.XLACompatibleSharding]:
   from jax._src import pjit

-  if config.jax_enable_memories:
+  if all_default_mem_kind:
+    omk = [None] * num_out_avals
+  else:
     try:
       omk = xla_executable.get_output_memory_kinds()[0]
+      if num_ordered_effects > 0:
+        omk = omk[num_ordered_effects:]
     except:
       omk = [None] * num_out_avals
-  else:
-    omk = [None] * num_out_avals
+
+  assert len(omk) == num_out_avals, (len(omk), num_out_avals)

   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`. In that case,
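
The new num_ordered_effects handling appears to account for the extra leading outputs that ordered effects add to the compiled module: their memory-kind entries are dropped so the remaining list lines up with the user-visible output avals. A small illustrative sketch with made-up values:

    # Hypothetical values returned by the executable: one leading entry for an
    # ordered-effect output, then one per user-visible output.
    omk = ['device', 'device', 'unpinned_host']
    num_ordered_effects = 1
    num_out_avals = 2

    omk = omk[num_ordered_effects:]   # drop the effect entries -> ['device', 'unpinned_host']
    assert len(omk) == num_out_avals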
@@ -2560,7 +2590,8 @@ class UnloadedMeshExecutable:
       pmap_nreps: int = 1,
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
-      compiler_options=None
+      all_default_mem_kind: bool = True,
+      compiler_options=None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
@@ -2616,7 +2647,8 @@ class UnloadedMeshExecutable:
     # without tuple conversion.
     device_assignment = tuple(da)
     out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-        xla_executable, device_assignment, len(global_out_avals))  # type: ignore
+        xla_executable, device_assignment, len(global_out_avals),
+        len(ordered_effects), all_default_mem_kind)  # type: ignore
     orig_out_shardings = out_shardings
     out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
     for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,


@@ -653,7 +653,8 @@ def _check_lowering(lowering) -> None:
       "spmd_lowering", "auto_spmd_lowering",
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
-      "device_assignment", "jaxpr_debug_info", "shape_poly_state"]
+      "device_assignment", "jaxpr_debug_info", "shape_poly_state",
+      "all_default_mem_kind"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")


@@ -10278,7 +10278,7 @@ class OverrideLoweringTest(jtu.JaxTestCase):
         jax.jit(f)
         .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
                _experimental_override_lowering_rules=rules).as_text())
-    self.assertNotIn("stablehlo.custom_call", lowered_ir)
+    self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)


 if __name__ == '__main__':


@@ -868,7 +868,7 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 2"):
       new_mps.is_compatible_aval(shape)
@@ -884,15 +884,16 @@ class ShardingTest(jtu.JaxTestCase):
     op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     op.replicate_on_last_tile_dim = True
     s = jax.sharding.GSPMDSharding(jax.devices(), op)
-    self.assertEqual(
-        repr(s),
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn(
         'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
-        'last_tile_dim_replicate})')
+        'last_tile_dim_replicate}', repr(s))

     op2 = xc.OpSharding()
     op2.type = xc.OpSharding.Type.REPLICATED
     s2 = jax.sharding.GSPMDSharding(jax.devices(), op2)
-    self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn('GSPMDSharding({replicated}', repr(s2))

   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y"), (4, 2), (), False),


@@ -1226,7 +1226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
         ValueError,
         r"One of with_sharding_constraint.*Sharding "
         r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 1"):
       pjit_f(jnp.array([1, 2, 3]))
@@ -3658,6 +3658,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
         ' manager.*SingleDeviceSharding'):
       jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)

+  @jtu.skip_on_devices("tpu")
+  def test_device_put_memory_kind_not_tpu(self):
+    @jax.jit
+    def f(x):
+      y = x * 2
+      return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host'))
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        'Passing memory_kind to device_put via Shardings is not supported on'
+        ' platform.*'):
+      f(jnp.arange(8))
+

 class TempSharding(Sharding):