Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted, so existing code keeps working without any changes. If a non-default memory kind is present in the jaxpr, the custom calls are inserted.

PiperOrigin-RevId: 564457393
parent bfc12bdda9
commit a36598b2a7
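For illustration (not part of the patch), a minimal sketch of the behavior described above: when a program uses only the default memory kind, the lowered StableHLO should contain no annotate_device_placement custom calls. Inspecting the text via jit(...).lower(...).as_text() is simply a convenient way to observe this; it is not code from this change.

import jax
import jax.numpy as jnp

def f(x):
  return x * 2  # no memory kinds requested, so only default placement is used

# Lowered IR for a default-placement program: no annotate_device_placement custom calls.
text = jax.jit(f).lower(jnp.arange(8.0)).as_text()
print("annotate_device_placement" in text)  # expected: False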
@@ -784,7 +784,7 @@ def _update_jax_memories_thread_local(val):

 enable_memories = config.define_bool_state(
     'jax_enable_memories',
-    default=False,
+    default=True,
     upgrade=True,
     update_global_hook=_update_jax_memories_global,
     update_thread_local_hook=_update_jax_memories_thread_local,
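Because this is an upgrade flag that now defaults to True, code that needs the old behavior can opt back out. A brief sketch, assuming the usual ways a define_bool_state flag is surfaced (jax.config.update and an upper-cased environment variable):

import os
# Either set the environment variable before importing jax ...
os.environ["JAX_ENABLE_MEMORIES"] = "false"

import jax
# ... or flip the flag at runtime, which also triggers the update hooks above.
jax.config.update("jax_enable_memories", False)
print(jax.config.jax_enable_memories)  # False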
@@ -529,7 +529,7 @@ def device_put_transpose_rule(ct, _, device, src):
 ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)

-def _device_put_lowering(ctx, x, *, device, src):
+def _tpu_device_put_lowering(ctx, x, *, device, src):
   if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
@@ -540,4 +540,14 @@ def _device_put_lowering(ctx, x, *, device, src):
         ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
-mlir.register_lowering(device_put_p, _device_put_lowering)
+mlir.register_lowering(device_put_p, _tpu_device_put_lowering, platform='tpu')
+
+
+def _common_device_put_lowering(ctx, x, *, device, src):
+  if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
+      device.memory_kind is not None):
+    raise NotImplementedError(
+        "Passing memory_kind to device_put via Shardings is not supported on"
+        f" platform {ctx.module_context.platform}")
+  return [x]
+mlir.register_lowering(device_put_p, _common_device_put_lowering)
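The effect of the split registration, sketched from the user side (this mirrors the test added at the end of this diff, including its internal import path): on a non-TPU backend the common lowering rejects memory-kind transfers.

import jax
import jax.numpy as jnp
from jax._src import sharding_impls  # internal module, as used by the new test

@jax.jit
def f(x):
  # Request a transfer to a named memory kind inside jit.
  return jax.device_put(x * 2, sharding_impls.TransferToMemoryKind('unpinned_host'))

try:
  f(jnp.arange(8))
except NotImplementedError as e:
  # On CPU/GPU this hits _common_device_put_lowering and reports the platform.
  print(e)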
@@ -676,6 +676,7 @@ def lower_jaxpr_to_module(
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
     num_partitions: int = 1,
+    all_default_mem_kind: bool = True,
     override_lowering_rules: None | (
         tuple[tuple[core.Primitive, LoweringRule]]) = None,
 ) -> LoweringResult:
@@ -701,10 +702,14 @@ def lower_jaxpr_to_module(
                map(sharded_aval, jaxpr.in_avals, arg_shardings))
   out_avals = (jaxpr.out_avals if result_shardings is None else
                map(sharded_aval, jaxpr.out_avals, result_shardings))
-  arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
-                      if arg_shardings is not None else None)
-  result_memory_kinds = (map(_get_mem_kind, result_shardings)
-                         if result_shardings is not None else None)
+  if all_default_mem_kind:
+    arg_memory_kinds = None
+    result_memory_kinds = None
+  else:
+    arg_memory_kinds = (map(_get_mem_kind, arg_shardings)
+                        if arg_shardings is not None else None)
+    result_memory_kinds = (map(_get_mem_kind, result_shardings)
+                           if result_shardings is not None else None)

   if platform in _platforms_with_donation:
     input_output_aliases, donated_args = _set_up_aliases(
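Conversely, a hypothetical TPU-only sketch of when the else branch above is taken (assumptions: a TPU backend, availability of the 'unpinned_host' memory kind, Sharding.with_memory_kind, and jax.jit's out_shardings parameter; none of these are introduced by this patch):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))
# A non-default memory kind on an output sharding means all_default_mem_kind is False,
# so per-result memory kinds are computed and placement annotations are emitted.
s_host = NamedSharding(mesh, P('x')).with_memory_kind('unpinned_host')
text = jax.jit(lambda a: a + 1, out_shardings=s_host).lower(jnp.arange(8.0)).as_text()
print("annotate_device_placement" in text)  # expected: True on a TPU backend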
@@ -66,7 +66,8 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
     AUTO, UnspecifiedValue, UNSPECIFIED,
-    get_array_mapping as _get_array_mapping, is_auto, is_unspecified
+    get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
+    is_unspecified_or_auto
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
@@ -1783,7 +1784,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
                             da_object, lowering_platform,
-                            donated_invars, name_stack, override_lowering_rules):
+                            donated_invars, name_stack, all_default_mem_kind,
+                            override_lowering_rules):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1855,6 +1857,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=nreps,
       num_partitions=num_partitions,
+      all_default_mem_kind=all_default_mem_kind,
       override_lowering_rules=override_lowering_rules)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
@@ -1925,15 +1928,27 @@ def _create_da_object(  # pytype: disable=invalid-annotation
   return _DeviceAssignment(device_assignment)


-def jaxpr_has_dp_with_transfer_mem_kind(jaxpr: core.Jaxpr) -> bool:
+def jaxpr_transfer_mem_kinds(
+    jaxpr: core.Jaxpr) -> Iterator[sharding_impls.TransferToMemoryKind]:
   for eqn in jaxpr.eqns:
     if (eqn.primitive is dispatch.device_put_p and
         isinstance(eqn.params['device'], sharding_impls.TransferToMemoryKind)):
-      return True
+      yield eqn.params['device']
   for subjaxpr in core.subjaxprs(jaxpr):
-    if jaxpr_has_dp_with_transfer_mem_kind(subjaxpr):
-      return True
-  return False
+    yield from jaxpr_transfer_mem_kinds(subjaxpr)
+
+
+def are_all_shardings_default_mem_kind(da_object, shardings):
+  try:
+    default_mem_kind = da_object.default_memory_kind
+  except:
+    return True
+  for i in shardings:
+    if is_unspecified_or_auto(i):
+      continue
+    if i.memory_kind != default_mem_kind:
+      return False
+  return True


 @profiler.annotate_function
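A toy illustration of the refactor above (stand-in classes only, not the JAX internals): collecting the actual TransferToMemoryKind values instead of a boolean lets the same list drive both the committed check (by truthiness) and the default-memory-kind check (by contents).

from dataclasses import dataclass

@dataclass(frozen=True)
class FakeTransfer:           # stand-in for TransferToMemoryKind
  memory_kind: str

def all_default_mem_kind(kinds, default='device'):   # stand-in for the real check
  return all(k.memory_kind == default for k in kinds)

transfers = [FakeTransfer('unpinned_host')]
committed = bool(transfers)                 # truthiness replaces the old has_* boolean
print(committed, all_default_mem_kind(transfers))    # True False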
@@ -1988,19 +2003,26 @@ def lower_sharding_computation(
           for js, source_info in jaxpr_sharding]),
       devices_from_context)

+  transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
+
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
       any(not is_unspecified(o) for o in out_shardings) or
-      jaxpr_has_dp_with_transfer_mem_kind(jaxpr))
+      transfer_mem_kind_in_jaxpr)

   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

   da_object = _create_da_object(tuple(device_assignment))

+  all_default_mem_kind = are_all_shardings_default_mem_kind(
+      da_object,
+      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding],  # type: ignore
+               transfer_mem_kind_in_jaxpr))
+
   if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.jax_spmd_mode != 'allow_all':
       raise RuntimeError(
@@ -2022,7 +2044,7 @@ def lower_sharding_computation(
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
        semantic_out_shardings, da_object, lowering_platform,
-       donated_invars, name_stack, override_lowering_rules)
+       donated_invars, name_stack, all_default_mem_kind, override_lowering_rules)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2050,7 +2072,8 @@ def lower_sharding_computation(
       committed=committed,
       pmap_nreps=nreps,
       jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-      shape_poly_state=shape_poly_state)
+      shape_poly_state=shape_poly_state,
+      all_default_mem_kind=all_default_mem_kind)


 def _to_logical_sharding(
@@ -2295,18 +2318,25 @@ def _get_input_indices(


 def get_gspmd_shardings_from_executable(
-    xla_executable, device_assignment: Sequence[xc.Device],
-    num_out_avals: int
+    xla_executable,
+    device_assignment: Sequence[xc.Device],
+    num_out_avals: int,
+    num_ordered_effects: int,
+    all_default_mem_kind: bool,
 ) -> Sequence[sharding_impls.XLACompatibleSharding]:
   from jax._src import pjit

   if config.jax_enable_memories:
-    try:
-      omk = xla_executable.get_output_memory_kinds()[0]
-    except:
-      omk = [None] * num_out_avals
+    if all_default_mem_kind:
+      omk = [None] * num_out_avals
+    else:
+      try:
+        omk = xla_executable.get_output_memory_kinds()[0]
+        if num_ordered_effects > 0:
+          omk = omk[num_ordered_effects:]
+      except:
+        omk = [None] * num_out_avals
   else:
     omk = [None] * num_out_avals

   assert len(omk) == num_out_avals, (len(omk), num_out_avals)

   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`. In that case,
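A small sketch of the new slicing above (assumption: the executable's reported output memory kinds include entries for ordered-effect tokens ahead of the user outputs, which is why they are dropped before the length check):

# Toy values: one ordered-effect token output followed by two user outputs.
raw_kinds = ['device', 'device', 'unpinned_host']
num_ordered_effects = 1
num_out_avals = 2

omk = raw_kinds[num_ordered_effects:] if num_ordered_effects > 0 else raw_kinds
assert len(omk) == num_out_avals, (len(omk), num_out_avals)
print(omk)  # ['device', 'unpinned_host']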
@@ -2560,7 +2590,8 @@ class UnloadedMeshExecutable:
       pmap_nreps: int = 1,
       jaxpr_debug_info: core.JaxprDebugInfo | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
-      compiler_options=None
+      all_default_mem_kind: bool = True,
+      compiler_options=None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
@@ -2616,7 +2647,8 @@ class UnloadedMeshExecutable:
     # without tuple conversion.
     device_assignment = tuple(da)
     out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-        xla_executable, device_assignment, len(global_out_avals))  # type: ignore
+        xla_executable, device_assignment, len(global_out_avals),
+        len(ordered_effects), all_default_mem_kind)  # type: ignore
     orig_out_shardings = out_shardings
     out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
     for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
@@ -653,7 +653,8 @@ def _check_lowering(lowering) -> None:
       "spmd_lowering", "auto_spmd_lowering",
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
-      "device_assignment", "jaxpr_debug_info", "shape_poly_state"]
+      "device_assignment", "jaxpr_debug_info", "shape_poly_state",
+      "all_default_mem_kind"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
@@ -10278,7 +10278,7 @@ class OverrideLoweringTest(jtu.JaxTestCase):
         jax.jit(f)
         .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
                _experimental_override_lowering_rules=rules).as_text())
-    self.assertNotIn("stablehlo.custom_call", lowered_ir)
+    self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)


 if __name__ == '__main__':
@@ -868,7 +868,7 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 2"):
       new_mps.is_compatible_aval(shape)

@@ -884,15 +884,16 @@ class ShardingTest(jtu.JaxTestCase):
     op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     op.replicate_on_last_tile_dim = True
     s = jax.sharding.GSPMDSharding(jax.devices(), op)
-    self.assertEqual(
-        repr(s),
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn(
         'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
-        'last_tile_dim_replicate})')
+        'last_tile_dim_replicate}', repr(s))

     op2 = xc.OpSharding()
     op2.type = xc.OpSharding.Type.REPLICATED
     s2 = jax.sharding.GSPMDSharding(jax.devices(), op2)
-    self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
+    # memory kind also appears in the repr but only for TPU.
+    self.assertIn('GSPMDSharding({replicated}', repr(s2))

   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y"), (4, 2), (), False),
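The looser assertions reflect that the repr can now carry a memory kind suffix on TPU. A quick way to see the repr on the local backend (the exact TPU suffix format is not shown in this diff and is left unspecified here):

import jax

s = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
print(repr(s))  # starts with 'GSPMDSharding({replicated}'; on TPU a memory kind may follow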
@@ -1226,7 +1226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
         ValueError,
         r"One of with_sharding_constraint.*Sharding "
         r"NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
-        r"spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
+        r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 1"):
       pjit_f(jnp.array([1, 2, 3]))

@@ -3658,6 +3658,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
         ' manager.*SingleDeviceSharding'):
       jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)

+  @jtu.skip_on_devices("tpu")
+  def test_device_put_memory_kind_not_tpu(self):
+    @jax.jit
+    def f(x):
+      y = x * 2
+      return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host'))
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        'Passing memory_kind to device_put via Shardings is not supported on'
+        ' platform.*'):
+      f(jnp.arange(8))
+
+
 class TempSharding(Sharding):