Make lowering oblivious to the real physical devices. Instead, cache lowering only on the HloSharding, which is based on logical device numbers.

Make an exception for callbacks and custom_partitioning because they need access to device_assignment during lowering.
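
A minimal sketch of the intended effect, adapted from the tests added at the bottom of this commit (it assumes a runtime with at least four devices): a jitted function's lowering can now be reused across meshes built from different physical devices, as long as the shardings are logically equivalent.

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh1 = Mesh(jax.devices()[:2], 'x')
    mesh2 = Mesh(jax.devices()[2:4], 'y')

    @jax.jit
    def f(x):
      return x * 2

    a = jax.device_put(np.arange(8), NamedSharding(mesh1, P('x')))
    out_a = f(a)  # lowers and compiles
    b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))  # different devices, same logical sharding
    f(b)          # lowering cache hit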

PiperOrigin-RevId: 589244695
Yash Katariya 2023-12-08 14:35:27 -08:00 committed by jax authors
parent c86116e72a
commit 5fb8ceca73
11 changed files with 87 additions and 26 deletions

View File

@@ -156,14 +156,17 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
           f" {type(sharding)}"
       )
     device = next(iter(sharding.device_set))
+    device_assignment = axis_context.device_assignment
+    if device_assignment is None:
+      raise AssertionError(
+          "Please file a bug at https://github.com/google/jax/issues")
     try:
-      device_index = axis_context.device_assignment.index(device)
+      device_index = device_assignment.index(device)
     except IndexError as e:
       raise ValueError(
           "Sharding provided to pure_callback specifies a device"
           f" {device} that is not in the device assignment"
-          f" ({axis_context.device_assignment})"
-      ) from e
+          f" ({device_assignment})") from e
   else:
     device_index = 0
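
A toy restatement of the lookup above (plain Python, not JAX internals), showing why callback lowering remains the exception that needs the concrete device assignment: the device named by a SingleDeviceSharding has to be turned into an index into that assignment.

    import jax

    device_assignment = tuple(jax.devices())
    target = jax.devices()[0]  # stands in for the sharding's single device
    try:
      device_index = device_assignment.index(target)
    except ValueError as e:
      raise ValueError(
          f"device {target} is not in the device assignment"
          f" ({device_assignment})") from e
    print(device_index)  # 0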

View File

@@ -346,6 +346,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:

View File

@@ -191,6 +191,7 @@ def should_tuple_args(num_args: int, platform: str) -> bool:
   else:
     return False

+@util.weakref_lru_cache
 def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
   """Whether there is a primitive given by user anywhere inside a Jaxpr."""
   for eqn in jaxpr.eqns:
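
The helper being cached here is what the lowering path uses (see the pxla hunks below) to detect the callback and custom-partitioning primitives that still need the real device assignment. A small usage sketch, relying on the internal jax._src.dispatch module purely for illustration:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax._src import dispatch  # internal module, shown only for illustration

    def f(x):
      # pure_callback inserts the 'pure_callback' primitive into the jaxpr
      return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    jaxpr = jax.make_jaxpr(f)(jnp.arange(4.0)).jaxpr
    assert dispatch.jaxpr_has_primitive(jaxpr, 'pure_callback')
    assert not dispatch.jaxpr_has_primitive(jaxpr, 'custom_partitioning')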

View File

@@ -829,8 +829,7 @@ def lower_jaxpr_to_module(
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
       shape_poly_state=ShapePolyLoweringState(
-          dim_vars,
-          lowering_parameters.platforms))
+          dim_vars, lowering_parameters.platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.

View File

@@ -1752,7 +1752,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
 @weakref_lru_cache
 def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             semantic_in_shardings, semantic_out_shardings,
-                            in_layouts, out_layouts, da_object,
+                            in_layouts, out_layouts, num_devices, device_assignment,
                             donated_invars, name_stack, all_default_mem_kind,
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
@@ -1760,9 +1760,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   out_shardings = semantic_out_shardings.shardings
   global_in_avals = closed_jaxpr.in_avals
   global_out_avals = closed_jaxpr.out_avals
-  # TODO(yashkatariya): Make device_assignment directly usable in the downstream
-  # code without tuple conversion.
-  device_assignment = tuple(da_object)

   log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
   if logger.isEnabledFor(log_priority):
@@ -1787,8 +1784,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(device_assignment)
-    num_partitions = len(device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
@@ -1800,7 +1797,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   module_name = f"{api_name}_{fun_name}"

-  if len(device_assignment) > 1:
+  if num_devices > 1:
     unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
     unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
         unsupported_effects)
@@ -1972,12 +1969,16 @@ def lower_sharding_computation(
   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
+  is_callback_p = (dispatch.jaxpr_has_primitive(jaxpr, 'inspect_sharding') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'custom_partitioning') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'pure_callback') or
+                   dispatch.jaxpr_has_primitive(jaxpr, 'io_callback'))
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
-      semantic_out_shardings, in_layouts, out_layouts, da_object,
-      donated_invars, name_stack, all_default_mem_kind,
-      lowering_parameters=lowering_parameters)
+      semantic_out_shardings, in_layouts, out_layouts, len(da_object),
+      tuple(da_object) if is_callback_p else None, donated_invars, name_stack,
+      all_default_mem_kind, lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
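
A toy model of the new cache keying above (plain functools, not the actual weakref_lru_cache): when no callback or custom-partitioning primitives are present, the device assignment passed to the cached lowering is replaced by None, so two lowerings that differ only in physical devices share one entry.

    import functools

    @functools.lru_cache
    def toy_cached_lowering(num_devices, device_assignment):
      return object()  # stands in for the lowered HLO module

    devs_a, devs_b = ('dev:0', 'dev:1'), ('dev:2', 'dev:3')
    has_callbacks = False  # plays the role of is_callback_p above

    mod_a = toy_cached_lowering(len(devs_a), devs_a if has_callbacks else None)
    mod_b = toy_cached_lowering(len(devs_b), devs_b if has_callbacks else None)
    assert mod_a is mod_b  # cache hit: same num_devices, devices elided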

View File

@@ -1400,12 +1400,12 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
                                     out_shardings, api_name):
   mod_ctx = ctx.module_context
   axis_ctx = ctx.module_context.axis_context
-  da = None
+  num_devices = None
   if isinstance(axis_ctx, sharding_impls.ShardingContext):
-    da = tuple(axis_ctx.device_assignment)
+    num_devices = axis_ctx.num_devices
   elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
-    da = axis_ctx.mesh._flat_devices_tuple
-  key = (pjit_p, name, jaxpr, effects, da,
+    num_devices = axis_ctx.mesh.size
+  key = (pjit_p, name, jaxpr, effects, num_devices,
          pxla.SemanticallyEqualShardings(in_shardings),
          pxla.SemanticallyEqualShardings(out_shardings), api_name)

View File

@@ -1173,7 +1173,12 @@ class ShardingContext:
   This context also uses the GSPMD partitioner.
   """
-  device_assignment: Sequence[xc.Device]
+  num_devices: int
+  device_assignment: Sequence[xc.Device] | None = None
+
+  def __post_init__(self):
+    if self.device_assignment is not None:
+      assert self.num_devices == len(self.device_assignment)

   # Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
   @property
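
A sketch of how the reworked ShardingContext is constructed elsewhere in this commit (it assumes at least two devices and imports the internal jax._src.sharding_impls module): most callers now pass only a device count, while the callback and custom_partitioning paths also thread the concrete device assignment.

    import jax
    from jax._src.sharding_impls import ShardingContext  # internal module

    ctx_logical = ShardingContext(num_devices=2)                  # no physical devices needed
    ctx_callbacks = ShardingContext(2, tuple(jax.devices()[:2]))  # callbacks need them
    # __post_init__ checks consistency when a device assignment is supplied:
    # ShardingContext(3, tuple(jax.devices()[:2])) would trip the assertion.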

View File

@@ -182,7 +182,7 @@ def _tpu_custom_call_lowering(
         " call in a shard_map or xmap."
     )
   elif isinstance(axis_context, sharding_impls.ShardingContext):
-    if len(axis_context.device_assignment) != 1:
+    if axis_context.num_devices != 1:
       raise NotImplementedError(
           "Mosaic kernels cannot be automatically partitioned. Please wrap the"
           " call in a shard_map or xmap."

View File

@@ -480,6 +480,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
+    if devices is None:
+      raise AssertionError(
+          'Please file a bug at https://github.com/google/jax/issues')
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = list(axis_context.mesh.devices.flat)
   else:

View File

@@ -715,7 +715,7 @@ def _wrap_main_func(
   # Make a context just for lowering the dimension value computations
   module_context = mlir.ModuleContext(
       backend_or_name="cpu", platforms=["cpu"],
-      axis_context=sharding_impls.ShardingContext([]),
+      axis_context=sharding_impls.ShardingContext(0),
       name_stack=source_info_util.new_name_stack(),
       keepalives=[], channel_iterator=itertools.count(1),
       host_callbacks=[], module=wrapped_module, context=context,
@@ -1171,16 +1171,16 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   axis_context = ctx.module_context.axis_context
   if isinstance(axis_context, sharding_impls.ShardingContext):
-    ctx_device_assignment = axis_context.device_assignment
+    num_devices = axis_context.num_devices
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
-    ctx_device_assignment = list(axis_context.mesh.devices.flat)
+    num_devices = axis_context.mesh.size
   else:
     raise NotImplementedError(type(axis_context))
-  if len(ctx_device_assignment) != exported.nr_devices:
+  if num_devices != exported.nr_devices:
     raise NotImplementedError(
         f"Exported module {exported.fun_name} was lowered for "
         f"{exported.nr_devices} devices and is called in a context with "
-        f"{len(ctx_device_assignment)} devices"
+        f"{num_devices} devices"
     )
   # Apply in_shardings

View File

@@ -3770,6 +3770,52 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with mesh:
       f()  # doesn't crash

+  def test_lowering_cache_hit_different_devices(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices.
+      b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
+      f(b)  # lowering cache *hit*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 1)
+
+  def test_lowering_cache_miss_different_devices_and_sharding(self):
+    if jax.device_count() < 4:
+      self.skipTest('Requires >=4 devices')
+
+    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
+
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    def g(a):
+      a = jax.device_put(a, NamedSharding(mesh1, P('x')))
+      out_a = f(a)  # lowering cached
+
+      # same num_devices but different devices and sharding
+      b = jax.device_put(out_a, NamedSharding(mesh2, P()))
+      f(b)  # lowering cache *miss*
+
+    with jtu.count_jit_and_pmap_compiles() as count:
+      g(np.arange(8))
+    self.assertEqual(count[0], 2)
+
 class TempSharding(Sharding):