Make lowering oblivious to the real physical devices. Instead, cache lowering on the HloShardings only (which are based on logical device numbers) together with the number of devices.

Make an exception for callbacks and custom_partitioning, because they need access to the concrete device_assignment during lowering.

PiperOrigin-RevId: 589244695
This commit is contained in:
parent c86116e72a
commit 5fb8ceca73
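The gist of the change, as a toy sketch in plain Python (not JAX internals; all names and device strings below are illustrative): the lowering cache key keeps only the logical description of the computation, i.e. the shardings and the number of devices, and drops the concrete devices, so the same program placed on different physical devices of the same count reuses one lowered module.

import functools

@functools.lru_cache(maxsize=None)
def cached_lower(num_devices: int, partition_spec: tuple) -> str:
  # Stand-in for the expensive jaxpr -> HLO lowering step.
  return f"module for {num_devices} devices, spec={partition_spec}"

devices_a = ("gpu:0", "gpu:1")  # hypothetical device names
devices_b = ("gpu:2", "gpu:3")  # different devices, same count

# Both calls map to the same cache entry because only len(devices) is keyed.
assert cached_lower(len(devices_a), ("x",)) is cached_lower(len(devices_b), ("x",))
assert cached_lower.cache_info().hits == 1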
@@ -156,14 +156,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
             f" {type(sharding)}"
         )
       device = next(iter(sharding.device_set))
+      device_assignment = axis_context.device_assignment
+      if device_assignment is None:
+        raise AssertionError(
+            "Please file a bug at https://github.com/google/jax/issues")
       try:
-        device_index = axis_context.device_assignment.index(device)
+        device_index = device_assignment.index(device)
       except IndexError as e:
         raise ValueError(
             "Sharding provided to pure_callback specifies a device"
             f" {device} that is not in the device assignment"
-            f" ({axis_context.device_assignment})"
-        ) from e
+            f" ({device_assignment})") from e
     else:
       device_index = 0

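Why callbacks are the exception: lowering a pure_callback with a SingleDeviceSharding has to translate that concrete device into its index in the device assignment, which a purely logical HloSharding cannot provide. A minimal sketch of that lookup (a hypothetical helper, not the JAX API; a plain list stands in for the device assignment):

def callback_device_index(device, device_assignment):
  # Mirrors the new assertion above: callback lowering must be given a
  # concrete device assignment, so None indicates an internal bug.
  if device_assignment is None:
    raise AssertionError("expected a concrete device assignment")
  try:
    return device_assignment.index(device)
  except ValueError as e:
    raise ValueError(
        f"device {device} is not in the device assignment"
        f" ({device_assignment})") from e

assert callback_device_index("tpu:1", ["tpu:0", "tpu:1"]) == 1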
@@ -346,6 +346,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,

   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:
@@ -191,6 +191,7 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
   else:
     return False

+@util.weakref_lru_cache
 def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
   """Whether there is a primitive given by user anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
@@ -829,8 +829,7 @@ def lower_jaxpr_to_module(
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
       shape_poly_state=ShapePolyLoweringState(
-          dim_vars,
-          lowering_parameters.platforms))
+          dim_vars, lowering_parameters.platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -1752,7 +1752,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            in_layouts, out_layouts, da_object,
+                            in_layouts, out_layouts, num_devices, device_assignment,
                             donated_invars, name_stack, all_default_mem_kind,
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
@@ -1760,9 +1760,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
-  # code without tuple conversion.
-  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
@@ -1787,8 +1784,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(device_assignment)
-    num_partitions = len(device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
@@ -1800,7 +1797,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,

   module_name = f"{api_name}_{fun_name}"

-  if len(device_assignment) > 1:
+  if num_devices > 1:
     unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
     unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
         unsupported_effects)
@@ -1972,12 +1969,16 @@ def lower_sharding_computation(
   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
+  is_callback_p = (dispatch.jaxpr_has_primitive(jaxpr, 'inspect_sharding') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'custom_partitioning') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'pure_callback') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'io_callback'))
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-      semantic_out_shardings, in_layouts, out_layouts, da_object,
-      donated_invars, name_stack, all_default_mem_kind,
-      lowering_parameters=lowering_parameters)
+      semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+      tuple(da_object) if is_callback_p else None, donated_invars, name_stack,
+      all_default_mem_kind, lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
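The hunk above only threads the concrete device tuple into the cached lowering when the jaxpr actually contains a device-dependent primitive; otherwise it passes None so the cache key stays device-free. A hedged sketch of that gate (illustrative names, not the real signatures):

CALLBACK_PRIMS = ('inspect_sharding', 'custom_partitioning',
                  'pure_callback', 'io_callback')

def device_assignment_for_lowering(jaxpr_primitives, device_assignment):
  # Pass real devices only if some primitive needs them during lowering.
  needs_devices = any(p in jaxpr_primitives for p in CALLBACK_PRIMS)
  return tuple(device_assignment) if needs_devices else None

assert device_assignment_for_lowering({'add', 'mul'}, ['d0', 'd1']) is None
assert device_assignment_for_lowering({'pure_callback'}, ['d0', 'd1']) == ('d0', 'd1')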
@@ -1400,12 +1400,12 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
                                     out_shardings, api_name):
   mod_ctx = ctx.module_context
   axis_ctx = ctx.module_context.axis_context
-  da = None
+  num_devices = None
   if isinstance(axis_ctx, sharding_impls.ShardingContext):
-    da = tuple(axis_ctx.device_assignment)
+    num_devices = axis_ctx.num_devices
   elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
-    da = axis_ctx.mesh._flat_devices_tuple
-  key = (pjit_p, name, jaxpr, effects, da,
+    num_devices = axis_ctx.mesh.size
+  key = (pjit_p, name, jaxpr, effects, num_devices,
          pxla.SemanticallyEqualShardings(in_shardings),
          pxla.SemanticallyEqualShardings(out_shardings), api_name)

@@ -1173,7 +1173,12 @@ class ShardingContext:

   This context also uses the GSPMD partitioner.
   """
-  device_assignment: Sequence[xc.Device]
+  num_devices: int
+  device_assignment: Sequence[xc.Device] | None = None
+
+  def __post_init__(self):
+    if self.device_assignment is not None:
+      assert self.num_devices == len(self.device_assignment)

   # Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
   @property
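Shape of the new ShardingContext, restated as a standalone sketch (field names taken from the diff above; the class name is suffixed to make clear it is not the real one): num_devices becomes the primary datum, and the concrete device_assignment is optional, supplied only when an op genuinely needs devices.

from dataclasses import dataclass
from typing import Any, Sequence

@dataclass(frozen=True)
class ShardingContextSketch:
  num_devices: int
  device_assignment: Sequence[Any] | None = None

  def __post_init__(self):
    if self.device_assignment is not None:
      assert self.num_devices == len(self.device_assignment)

ctx_logical = ShardingContextSketch(8)                      # device-oblivious lowering
ctx_callback = ShardingContextSketch(2, ("dev0", "dev1"))   # callbacks get real devices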
@@ -182,7 +182,7 @@ def _tpu_custom_call_lowering(
           " call in a shard_map or xmap."
       )
   elif isinstance(axis_context, sharding_impls.ShardingContext):
-    if len(axis_context.device_assignment) != 1:
+    if axis_context.num_devices != 1:
       raise NotImplementedError(
           "Mosaic kernels cannot be automatically partitioned. Please wrap the"
           " call in a shard_map or xmap."
@@ -480,6 +480,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,

   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:
@@ -715,7 +715,7 @@ def _wrap_main_func(
     # Make a context just for lowering the dimension value computations
     module_context = mlir.ModuleContext(
         backend_or_name="cpu", platforms=["cpu"],
-        axis_context=sharding_impls.ShardingContext([]),
+        axis_context=sharding_impls.ShardingContext(0),
         name_stack=source_info_util.new_name_stack(),
         keepalives=[], channel_iterator=itertools.count(1),
         host_callbacks=[], module=wrapped_module, context=context,
@@ -1171,16 +1171,16 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,

   axis_context = ctx.module_context.axis_context
   if isinstance(axis_context, sharding_impls.ShardingContext):
-    ctx_device_assignment = axis_context.device_assignment
+    num_devices = axis_context.num_devices
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
-    ctx_device_assignment = list(axis_context.mesh.devices.flat)
+    num_devices = axis_context.mesh.size
   else:
     raise NotImplementedError(type(axis_context))
-  if len(ctx_device_assignment) != exported.nr_devices:
+  if num_devices != exported.nr_devices:
     raise NotImplementedError(
         f"Exported module {exported.fun_name} was lowered for "
         f"{exported.nr_devices} devices and is called in a context with "
-        f"{len(ctx_device_assignment)} devices"
+        f"{num_devices} devices"
     )

   # Apply in_shardings
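The export-side check above now compares only device counts. Restated as a tiny standalone sketch (illustrative function name, not the export API): an exported module records how many devices it was lowered for, and calling it under a context with a different device count is rejected.

def check_call_context(exported_nr_devices: int, ctx_num_devices: int) -> None:
  # Reject calling contexts whose device count differs from the one the
  # module was lowered for.
  if ctx_num_devices != exported_nr_devices:
    raise NotImplementedError(
        f"module was lowered for {exported_nr_devices} devices and is called"
        f" in a context with {ctx_num_devices} devices")

check_call_context(2, 2)    # ok
# check_call_context(2, 4)  # would raise NotImplementedError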
@@ -3770,6 +3770,52 @@ class ArrayPjitTest(jtu.JaxTestCase):
       with mesh:
         f()  # doesn't crash

+  def test_lowering_cache_hit_different_devices(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices.
+      b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
+      f(b)  # lowering cache *hit*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 1)
+
+  def test_lowering_cache_miss_different_devices_and_sharding(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices and sharding
+      b = jax.device_put(out_a, NamedSharding(mesh2, P()))
+      f(b)  # lowering cache *miss*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 2)
+

 class TempSharding(Sharding):
