Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Make lowering oblivious to real physical devices. Instead, cache lowering only on the HloSharding (which is based on logical device numbers).

Make an exception for callbacks and custom_partitioning, because they need access to the device_assignment during lowering.

PiperOrigin-RevId: 589244695
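For orientation, here is a minimal standalone sketch of the shape this change gives ShardingContext. This is illustrative code only, not the actual jax._src.sharding_impls implementation; the field names mirror the diff below. The logical device count is always present, while the concrete device objects become optional and are only supplied on the callback/custom_partitioning path.

# Illustrative only: mirrors the new ShardingContext fields from the diff
# (num_devices required, device_assignment optional) without importing JAX.
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass(frozen=True)
class ShardingContextSketch:
  num_devices: int
  device_assignment: Optional[Sequence[Any]] = None

  def __post_init__(self):
    # Same consistency check the diff adds in __post_init__.
    if self.device_assignment is not None:
      assert self.num_devices == len(self.device_assignment)


# Non-callback lowering: only the count matters, so the context (and anything
# keyed on it) looks identical across different physical devices.
ShardingContextSketch(2)
# Callback / custom_partitioning lowering: the real devices are threaded through.
ShardingContextSketch(2, ("gpu:0", "gpu:1"))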
parent c86116e72a
commit 5fb8ceca73
@@ -156,14 +156,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
             f" {type(sharding)}"
         )
       device = next(iter(sharding.device_set))
+      device_assignment = axis_context.device_assignment
+      if device_assignment is None:
+        raise AssertionError(
+            "Please file a bug at https://github.com/google/jax/issues")
       try:
-        device_index = axis_context.device_assignment.index(device)
+        device_index = device_assignment.index(device)
       except IndexError as e:
         raise ValueError(
             "Sharding provided to pure_callback specifies a device"
             f" {device} that is not in the device assignment"
-            f" ({axis_context.device_assignment})"
-        ) from e
+            f" ({device_assignment})") from e
     else:
       device_index = 0

@@ -346,6 +346,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,

   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:
@@ -191,6 +191,7 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
   else:
     return False

+@util.weakref_lru_cache
 def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
   """Whether there is a primitive given by user anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
@@ -829,8 +829,7 @@ def lower_jaxpr_to_module(
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
       shape_poly_state=ShapePolyLoweringState(
-          dim_vars,
-          lowering_parameters.platforms))
+          dim_vars, lowering_parameters.platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -1752,7 +1752,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            in_layouts, out_layouts, da_object,
+                            in_layouts, out_layouts, num_devices, device_assignment,
                             donated_invars, name_stack, all_default_mem_kind,
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
@@ -1760,9 +1760,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
-  # code without tuple conversion.
-  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
@@ -1787,8 +1784,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(device_assignment)
-    num_partitions = len(device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
@@ -1800,7 +1797,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,

   module_name = f"{api_name}_{fun_name}"

-  if len(device_assignment) > 1:
+  if num_devices > 1:
     unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
     unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
         unsupported_effects)
@@ -1972,12 +1969,16 @@ def lower_sharding_computation(
   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
+  is_callback_p = (dispatch.jaxpr_has_primitive(jaxpr, 'inspect_sharding') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'custom_partitioning') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'pure_callback') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'io_callback'))
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
        closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-       semantic_out_shardings, in_layouts, out_layouts, da_object,
-       donated_invars, name_stack, all_default_mem_kind,
-       lowering_parameters=lowering_parameters)
+       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+       tuple(da_object) if is_callback_p else None, donated_invars, name_stack,
+       all_default_mem_kind, lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
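To make the gating above concrete, here is a hedged standalone sketch; the helper names are hypothetical and are not JAX internals. The concrete device tuple reaches the cached lowering, and therefore its cache key, only when one of the callback-like primitives is present in the jaxpr.

# Illustrative only: names below are made up for this sketch, not JAX APIs.
from functools import lru_cache
from typing import FrozenSet, Optional, Tuple

CALLBACK_PRIMS = frozenset(
    {'inspect_sharding', 'custom_partitioning', 'pure_callback', 'io_callback'})


def needs_device_assignment(prims_in_jaxpr: FrozenSet[str]) -> bool:
  # Mirrors the is_callback_p check: any of these primitives pulls real devices in.
  return bool(prims_in_jaxpr & CALLBACK_PRIMS)


@lru_cache(maxsize=None)
def cached_lowering(num_devices: int,
                    device_assignment: Optional[Tuple[str, ...]]) -> str:
  return f"module(n={num_devices}, da={device_assignment})"


prims = frozenset({'mul'})   # an ordinary program with no callbacks
da = ('tpu:2', 'tpu:3')      # physical devices of the current run
cached_lowering(2, tuple(da) if needs_device_assignment(prims) else None)
# A later run on ('tpu:0', 'tpu:1') with the same logical layout passes (2, None)
# again and hits the cache; a program containing pure_callback would not.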
@@ -1400,12 +1400,12 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
                                     out_shardings, api_name):
   mod_ctx = ctx.module_context
   axis_ctx = ctx.module_context.axis_context
-  da = None
+  num_devices = None
   if isinstance(axis_ctx, sharding_impls.ShardingContext):
-    da = tuple(axis_ctx.device_assignment)
+    num_devices = axis_ctx.num_devices
   elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
-    da = axis_ctx.mesh._flat_devices_tuple
-  key = (pjit_p, name, jaxpr, effects, da,
+    num_devices = axis_ctx.mesh.size
+  key = (pjit_p, name, jaxpr, effects, num_devices,
          pxla.SemanticallyEqualShardings(in_shardings),
          pxla.SemanticallyEqualShardings(out_shardings), api_name)

@@ -1173,7 +1173,12 @@ class ShardingContext:

   This context also uses the GSPMD partitioner.
   """
-  device_assignment: Sequence[xc.Device]
+  num_devices: int
+  device_assignment: Sequence[xc.Device] | None = None
+
+  def __post_init__(self):
+    if self.device_assignment is not None:
+      assert self.num_devices == len(self.device_assignment)

   # Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
   @property
@@ -182,7 +182,7 @@ def _tpu_custom_call_lowering(
         " call in a shard_map or xmap."
     )
   elif isinstance(axis_context, sharding_impls.ShardingContext):
-    if len(axis_context.device_assignment) != 1:
+    if axis_context.num_devices != 1:
       raise NotImplementedError(
           "Mosaic kernels cannot be automatically partitioned. Please wrap the"
           " call in a shard_map or xmap."
@@ -480,6 +480,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,

   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:
@@ -715,7 +715,7 @@ def _wrap_main_func(
   # Make a context just for lowering the dimension value computations
   module_context = mlir.ModuleContext(
       backend_or_name="cpu", platforms=["cpu"],
-      axis_context=sharding_impls.ShardingContext([]),
+      axis_context=sharding_impls.ShardingContext(0),
       name_stack=source_info_util.new_name_stack(),
       keepalives=[], channel_iterator=itertools.count(1),
       host_callbacks=[], module=wrapped_module, context=context,
@@ -1171,16 +1171,16 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,

   axis_context = ctx.module_context.axis_context
   if isinstance(axis_context, sharding_impls.ShardingContext):
-    ctx_device_assignment = axis_context.device_assignment
+    num_devices = axis_context.num_devices
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
-    ctx_device_assignment = list(axis_context.mesh.devices.flat)
+    num_devices = axis_context.mesh.size
   else:
     raise NotImplementedError(type(axis_context))
-  if len(ctx_device_assignment) != exported.nr_devices:
+  if num_devices != exported.nr_devices:
     raise NotImplementedError(
         f"Exported module {exported.fun_name} was lowered for "
         f"{exported.nr_devices} devices and is called in a context with "
-        f"{len(ctx_device_assignment)} devices"
+        f"{num_devices} devices"
     )

   # Apply in_shardings
@@ -3770,6 +3770,52 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with mesh:
       f()  # doesn't crash

+  def test_lowering_cache_hit_different_devices(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices.
+      b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
+      f(b)  # lowering cache *hit*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 1)
+
+  def test_lowering_cache_miss_different_devices_and_sharding(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices and sharding
+      b = jax.device_put(out_a, NamedSharding(mesh2, P()))
+      f(b)  # lowering cache *miss*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 2)
+
+
 class TempSharding(Sharding):