From edbe49fb2ada1c5db6577167e06da0111e3f0be5 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 25 Oct 2023 10:39:47 -0700 Subject: [PATCH] Cleanup the handling of single- and multi-platform lowering in ModuleContext Previously, we introduced support for multi-platform lowering, by adding a new LoweringParameters object that can be used to specify a cross-lowering platform or even multiple platforms. But we had kept the ModuleContext.platform in place because some lowering rules were still referencing it. Now we replace ModuleContext.platform with ModuleContext.platforms, which removes the redundancy, simplifies the code, and makes it clearer that the lowering rules should not simply assume single-platform lowering. PiperOrigin-RevId: 576575376 --- jax/_src/api.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/mlir.py | 104 +++++++++++------------- jax/_src/interpreters/pxla.py | 8 +- jax/_src/lax/parallel.py | 7 +- jax/_src/maps.py | 9 +- jax/_src/pallas/triton/lowering.py | 6 +- jax/experimental/custom_partitioning.py | 2 +- jax/experimental/export/export.py | 8 +- jax/experimental/host_callback.py | 6 +- 11 files changed, 73 insertions(+), 83 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 05f8dd52e..6c5e5a5be 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -565,7 +565,7 @@ def xla_computation(fun: Callable, core.ClosedJaxpr(jaxpr, consts), ordered_effects=ordered_effects, backend_or_name=backend, - platform=platform, + platforms=[platform], axis_context=sharding_impls.ReplicaAxisContext(axis_env_), name_stack=source_info_util.new_name_stack( wrap_name(fun_name, "xla_computation")), diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 6d5d6d5d5..69219870d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -385,7 +385,7 @@ def inspect_sharding_partition(shapes, arg_shardings, result_shape, jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) trivial_comp = mlir.build_xla_computation_helper(closed_jaxpr, - name="tmp_xla_computation", platform=module_context.platform, + name="tmp_xla_computation", platforms=module_context.platforms, backend_or_name=module_context.backend_or_name, axis_context=module_context.axis_context) # The trivial computation built here has a dummy tuple as the result, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index ecd746615..22fc52bbd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -554,6 +554,6 @@ def _common_device_put_lowering(ctx, x, *, device, src): device.memory_kind is not None): raise NotImplementedError( "Passing memory_kind to device_put via Shardings is not supported on" - f" platform {ctx.module_context.platform}") + f" platforms {ctx.module_context.platforms}") return [x] mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e576c4907..416c72b11 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -430,8 +430,7 @@ class LoweringParameters: # The current lowering platforms, a non-empty tuple containing some of # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are # doing multi-platform lowering, otherwise it can specify cross-platform - # lowering. The value None specifies default lowering platform, for the - # platform specified by `ModuleContext.platform`. + # lowering. The value None specifies the default lowering platform. 
# This is used only in export and jax2tf. platforms: tuple[str, ...] | None = None @@ -454,23 +453,6 @@ class LoweringParameters: # native execution (and we can remove this parameter). replace_tokens_with_dummy: bool = True - @property - def override_platform(self) -> str | None: - """Overrides the lowering platform for cross-platform lowering. - - One of 'cpu', 'cuda', 'rocm', 'tpu'. - If None, use the default JAX mechanisms to pick the lowering platform. - This is currently used for export and jax2tf. - """ - if self.platforms is not None: - return self.platforms[0] - else: - return None - - @property - def is_multi_platform(self) -> bool: - return self.platforms is not None and len(self.platforms) > 1 - @dataclasses.dataclass class ModuleContext: @@ -480,7 +462,7 @@ class ModuleContext: ip: ir.InsertionPoint symbol_table: ir.SymbolTable backend_or_name: str | xb.XlaBackend | None - platform: str + platforms: Sequence[str] axis_context: AxisContext name_stack: source_info_util.NameStack keepalives: list[Any] @@ -503,7 +485,7 @@ class ModuleContext: self, *, backend_or_name: str | xb.XlaBackend | None, - platform: str, + platforms: Sequence[str], axis_context: AxisContext, name_stack: source_info_util.NameStack, keepalives: list[Any], @@ -519,13 +501,13 @@ class ModuleContext: cached_call_jaxpr_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, shape_poly_state = None): - assert platform is not None + self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) self.ip = ip or ir.InsertionPoint(self.module.body) self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) self.backend_or_name = backend_or_name - self.platform = platform + self.platforms = platforms self.axis_context = axis_context self.name_stack = name_stack self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None @@ -536,12 +518,18 @@ class ModuleContext: self.cached_call_jaxpr_lowerings = ({} if cached_call_jaxpr_lowerings is None else cached_call_jaxpr_lowerings) - self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((), - (platform,)) + self.shape_poly_state = ( + shape_poly_state or ShapePolyLoweringState((), tuple(platforms))) self.lowering_parameters = lowering_parameters @property def backend(self) -> xb.XlaBackend: + # TODO(necula): clean the use of backend and backend_or_name vs. platforms + if len(self.platforms) > 1: + raise NotImplementedError( + "accessing .backend in multi-lowering setting. This can occur when " + "lowering a primitive that has not been adapted to multi-platform " + "lowering") if self.backend_or_name is None or isinstance(self.backend_or_name, str): return xb.get_backend(self.backend_or_name) return self.backend_or_name @@ -722,7 +710,7 @@ def lower_jaxpr_to_module( *, ordered_effects: list[core.Effect], backend_or_name: str | xb.XlaBackend | None, - platform: str, + platforms: Sequence[str], axis_context: AxisContext, name_stack: source_info_util.NameStack, donated_args: Sequence[bool], @@ -741,13 +729,7 @@ def lower_jaxpr_to_module( Handles the quirks of the argument/return value passing conventions of the runtime. """ - platforms: tuple[str, ...] 
- platform = xb.canonicalize_platform(platform) - if lowering_parameters.is_multi_platform: - platforms = tuple(map(xb.canonicalize_platform, - lowering_parameters.platforms)) # type: ignore - else: - platforms = (platform,) + platforms = tuple(map(xb.canonicalize_platform, platforms)) input_output_aliases = None in_avals = (jaxpr.in_avals if arg_shardings is None else @@ -809,7 +791,7 @@ def lower_jaxpr_to_module( if result_shardings is not None else result_shardings) ctx = ModuleContext(backend_or_name=backend_or_name, - platform=platform, axis_context=axis_context, + platforms=platforms, axis_context=axis_context, name_stack=name_stack, keepalives=keepalives, channel_iterator=channel_iter, @@ -1349,7 +1331,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, dim_var_values: the list of dimension variables values in the current IR function, in the order of ctx.shape_poly_state.dim_vars. """ - assert ctx.platform != "gpu" + assert "gpu" not in ctx.platforms def read(v: core.Atom) -> Sequence[ir.Value]: if type(v) is core.Literal: return ir_constants(xla.canonicalize_dtype(v.val)) @@ -1393,14 +1375,15 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, ctx.name_stack) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) - if not ctx.lowering_parameters.is_multi_platform: + if len(ctx.platforms) == 1: # Classic, single-platform lowering # TODO(necula): unify the code paths when multi-platform is finished + platform = ctx.platforms[0] if override_rule is not None: rule = override_rule - elif eqn.primitive in _platform_specific_lowerings[ctx.platform]: - rule = _platform_specific_lowerings[ctx.platform][eqn.primitive] - elif eqn.primitive in xla._backend_specific_translations[ctx.platform]: + elif eqn.primitive in _platform_specific_lowerings[platform]: + rule = _platform_specific_lowerings[platform][eqn.primitive] + elif eqn.primitive in xla._backend_specific_translations[platform]: rule = xla_fallback_lowering(eqn.primitive) elif eqn.primitive in _lowerings: rule = _lowerings[eqn.primitive] @@ -1409,7 +1392,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, else: raise NotImplementedError( f"MLIR translation rule for primitive '{eqn.primitive.name}' not " - f"found for platform {ctx.platform}") + f"found for platform {platform}") else: rules: list[MultiPlatformLoweringRule] # See mlir.lower_multi_platform for the `rules` format @@ -1418,7 +1401,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, else: # First the platform-specific rules rules = [] - for p in ctx.lowering_parameters.platforms: # type: ignore + for p in ctx.platforms: if eqn.primitive in _platform_specific_lowerings[p]: rules.append( ([p], _platform_specific_lowerings[p][eqn.primitive])) @@ -1449,7 +1432,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) rule_inputs = map(_unwrap_singleton_ir_values, in_nodes) - if not ctx.lowering_parameters.is_multi_platform: + if len(ctx.platforms) == 1: # Classic, single-platform lowering ans = rule(rule_ctx, *rule_inputs, **eqn.params) else: @@ -1528,12 +1511,7 @@ def lower_multi_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. 
""" - platforms: Sequence[str] - if ctx.module_context.lowering_parameters.is_multi_platform: - assert ctx.module_context.lowering_parameters.platforms is not None - platforms = ctx.module_context.lowering_parameters.platforms - else: - platforms = (ctx.module_context.platform,) + platforms: Sequence[str] = ctx.module_context.platforms platforms_with_specific_rules: Sequence[str] = util.flatten( [ps for ps, _ in rules if ps is not None]) platforms_with_default_rule = [p for p in platforms @@ -1681,11 +1659,16 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, return func_op -def check_backend_matches(inner_backend, outer_backend): +def check_backend_matches(inner_backend: Optional[str], + lowering_platforms: Sequence[str]): # For nested calls, the outermost call sets the backend for all inner calls; # it's an error if the inner call has a conflicting explicit backend spec. if inner_backend is None: return + outer_backend, *more_lowering_platforms = lowering_platforms + if more_lowering_platforms: + raise NotImplementedError( + "Multi-platform lowering when a backend= parameter is specified") if (inner_backend != outer_backend and outer_backend not in xb.expand_platform_alias(inner_backend)): raise ValueError( @@ -1693,13 +1676,15 @@ def check_backend_matches(inner_backend, outer_backend): f"inner-jit backend specification {inner_backend}.") -def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, +def _call_lowering(fn_name, stack_name, call_jaxpr, backend, + ctx: ModuleContext, avals_in, avals_out, tokens_in, *args, dim_var_values: Sequence[ir.Value], arg_names=None, result_names=None): + del stack_name, avals_in if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) - check_backend_matches(backend, ctx.platform) + check_backend_matches(backend, ctx.platforms) effects = list(tokens_in.effects()) output_types = map(aval_to_ir_types, avals_out) output_types = [token_type()] * len(effects) + output_types @@ -1717,7 +1702,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens))) return out_nodes, tokens_out -def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): +def core_call_lowering(ctx: LoweringRuleContext, + *args, name, backend=None, call_jaxpr): out_nodes, tokens = _call_lowering( name, name, call_jaxpr, backend, ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args, @@ -2137,8 +2123,11 @@ def xla_fallback_lowering(prim: core.Primitive): raise NotImplementedError( f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623") + if len(module_ctx.platforms) > 1: + raise NotImplementedError( + "fallback lowering not implemented for multi-platform lowering") xla_computation = xla.primitive_subcomputation( - module_ctx.platform, axis_env, prim, ctx.avals_in, + module_ctx.platforms[0], axis_env, prim, ctx.avals_in, ctx.avals_out, **params) xla_module = xla_computation_to_mlir_module(xla_computation) callee_name = merge_mlir_modules( @@ -2301,7 +2290,9 @@ def emit_python_callback( result_layouts: Sequence[Sequence[int] | None] | None = None, ) -> tuple[Sequence[ir.Value], Any, Any]: """Emits MLIR that calls back to a provided Python function.""" - platform = ctx.module_context.platform + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for python_callback") + platform = ctx.module_context.platforms[0] if platform not in 
{"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") @@ -2423,7 +2414,8 @@ def emit_python_callback( return results, token, ifrt_callback def build_xla_computation_helper( - closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str, + closed_jaxpr: core.ClosedJaxpr, *, name: str, + platforms: Sequence[str], backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation: """Helper to generate pmap-style XLA computations for custom partitioners.""" if closed_jaxpr.effects: @@ -2432,7 +2424,7 @@ def build_xla_computation_helper( backend_or_name=backend_or_name, ordered_effects=[], name_stack=source_info_util.NameStack(), donated_args=[False] * len(closed_jaxpr.jaxpr.invars), - axis_context=axis_context, platform=platform, + axis_context=axis_context, platforms=platforms, lowering_parameters=LoweringParameters()) return xc._xla.mlir.mlir_module_to_xla_computation( module_to_string(lowering_result.module), use_tuple_args=False, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a09fec093..aa179a577 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -758,7 +758,7 @@ def lower_parallel_callable( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, donated_args=donated_invars, @@ -1362,7 +1362,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, call_jaxpr, backend=None, in_axes, out_axes, donated_invars, is_explicit_global_axis_size): del donated_invars # Unused. - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. 
if ctx.module_context.axis_env.names and devices is not None: @@ -1835,7 +1835,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, ordered_effects=ordered_effects, backend_or_name=backend, # Optionally, override the lowering platform - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -2201,7 +2201,7 @@ def lower_mesh_computation( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 8e1316bd7..d297438bd 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -725,12 +725,7 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups): for arg, named_shape in zip(args, named_shapes)] def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): - # TODO(necula): clean this up when we have module_context.platforms - if ctx.module_context.lowering_parameters.is_multi_platform: - for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms) - else: - for_tpu = (ctx.module_context.platform == "tpu") - if axis_index_groups is not None and for_tpu: + if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) if any(len(g) != len_0 for g in axis_index_groups): raise ValueError("axis_index_groups must all be the same size for TPU lowering") diff --git a/jax/_src/maps.py b/jax/_src/maps.py index ca7ff271b..86d3bd740 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1305,7 +1305,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) + del backend, donated_invars # The only way for any of those two assertions to be violated is when xmap # is using the SPMD lowering, but then this rule shouldn't even trigger. assert spmd_in_axes is None and spmd_out_axes is None @@ -1381,7 +1382,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) + del backend, donated_invars plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) @@ -1447,9 +1449,10 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): + del donated_invars assert spmd_in_axes is None and spmd_out_axes is None # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule. 
- mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d0dfd0eb7..2868fe02e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1503,7 +1503,9 @@ def pallas_call_lowering( **compiler_params ) num_warps = compiler_params.get("num_warps", 4) - if ctx.module_context.platform == 'rocm': + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for Pallas kernels") + if ctx.module_context.platforms[0] == 'rocm': num_stages = compiler_params.get("num_stages", 1) else: num_stages = compiler_params.get("num_stages", 3) @@ -1521,7 +1523,7 @@ def pallas_call_lowering( debug=debug, ) #Triton returns a tuple for ROCm. We just want file path to be passed - if ctx.module_context.platform == 'rocm': + if ctx.module_context.platforms[0] == 'rocm': compilation_result.ptx = compilation_result.ptx[1] if debug: diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index ba6cbe44e..3aa77adbf 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -181,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, built = mlir.build_xla_computation_helper( closed_jaxpr, name="tmp_xla_computation", - platform=module_context.platform, + platforms=module_context.platforms, backend_or_name=module_context.backend_or_name, axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), ) diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index bfeb921ce..d0220d8bd 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -689,7 +689,7 @@ def _wrap_main_func( with ir.InsertionPoint(entry_block): # Make a context just for lowering the dimension value computations module_context = mlir.ModuleContext( - backend_or_name="cpu", platform="cpu", + backend_or_name="cpu", platforms=["cpu"], axis_context=sharding_impls.ShardingContext([]), name_stack=source_info_util.new_name_stack(), keepalives=[], channel_iterator=itertools.count(1), @@ -1179,11 +1179,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, submodule_args = [] # All the platforms for the current lowering must be among the platforms # for which the callee was lowered. 
- if ctx.module_context.lowering_parameters.is_multi_platform: - assert ctx.module_context.lowering_parameters.platforms is not None - lowering_platforms = ctx.module_context.lowering_parameters.platforms - else: - lowering_platforms = (ctx.module_context.platform,) + lowering_platforms = ctx.module_context.platforms callee_lowering_platform_index: list[int] = [] for platform in lowering_platforms: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 1bb111ef2..89f9bd6de 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1101,7 +1101,9 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, flat_results_aval=(), **params): """MLIR Lowering for `CustomCall`-based HCB.""" - platform = ctx.module_context.platform + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for host_callback") + platform = ctx.module_context.platforms[0] use_outfeed = _use_outfeed(platform) if use_outfeed: # Fall back to XLA path if we are using the outfeed @@ -1118,7 +1120,7 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, else: # TODO(necula): It seems that on CPU, with custom call, the device_index # does not work, and the callback is always run on device_index=0 - if (device_index != 0 and ctx.module_context.platform == "cpu"): + if (device_index != 0 and "cpu" in ctx.module_context.platforms): raise ValueError( "The device_index feature on CPU works only when using outfeed.") # We expect the current tokens at the end, inserted by _rewrite_jaxpr.
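The convention this patch establishes for lowering rules can be summarized with a short sketch (illustrative only, not part of the patch): a rule reads the non-empty sequence ctx.module_context.platforms instead of the removed ctx.module_context.platform field, and either handles every requested platform or raises NotImplementedError rather than silently assuming single-platform lowering, mirroring _allreduce_lowering, emit_python_callback, and pallas_call_lowering above. The primitive my_prim and the rule _my_prim_lowering below are hypothetical names used only for illustration.

# Illustrative sketch only (not part of the patch). Assumes a hypothetical
# primitive `my_prim`; the pattern follows the hunks above.
from jax._src import core
from jax._src.interpreters import mlir

my_prim = core.Primitive("my_prim")

def _my_prim_lowering(ctx: mlir.LoweringRuleContext, x, **params):
  platforms = ctx.module_context.platforms  # e.g. ("cpu",) or ("cpu", "tpu")
  if len(platforms) > 1:
    # A rule that has not been adapted to multi-platform lowering should fail
    # loudly, as xla_fallback_lowering and emit_python_callback now do.
    raise NotImplementedError("multi-platform lowering for my_prim")
  if platforms[0] == "tpu":
    pass  # TPU-specific lowering would go here
  return [x]  # identity lowering, just for the sketch

mlir.register_lowering(my_prim, _my_prim_lowering)

Rules registered for a specific platform via mlir.register_lowering(prim, rule, platform=...) need no change: as the jaxpr_subcomp hunk shows, the multi-platform path collects the per-platform rules for each entry of ctx.platforms and falls back to the platform-independent rule where none exists.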