diff --git a/jax/_src/api.py b/jax/_src/api.py index 05f8dd52e..6c5e5a5be 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -565,7 +565,7 @@ def xla_computation(fun: Callable, core.ClosedJaxpr(jaxpr, consts), ordered_effects=ordered_effects, backend_or_name=backend, - platform=platform, + platforms=[platform], axis_context=sharding_impls.ReplicaAxisContext(axis_env_), name_stack=source_info_util.new_name_stack( wrap_name(fun_name, "xla_computation")), diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 6d5d6d5d5..69219870d 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -385,7 +385,7 @@ def inspect_sharding_partition(shapes, arg_shardings, result_shape, jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) trivial_comp = mlir.build_xla_computation_helper(closed_jaxpr, - name="tmp_xla_computation", platform=module_context.platform, + name="tmp_xla_computation", platforms=module_context.platforms, backend_or_name=module_context.backend_or_name, axis_context=module_context.axis_context) # The trivial computation built here has a dummy tuple as the result, diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index ecd746615..22fc52bbd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -554,6 +554,6 @@ def _common_device_put_lowering(ctx, x, *, device, src): device.memory_kind is not None): raise NotImplementedError( "Passing memory_kind to device_put via Shardings is not supported on" - f" platform {ctx.module_context.platform}") + f" platforms {ctx.module_context.platforms}") return [x] mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e576c4907..416c72b11 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -430,8 +430,7 @@ class LoweringParameters: # The current lowering platforms, a non-empty tuple containing some of # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are # doing multi-platform lowering, otherwise it can specify cross-platform - # lowering. The value None specifies default lowering platform, for the - # platform specified by `ModuleContext.platform`. + # lowering. The value None specifies the default lowering platform. # This is used only in export and jax2tf. platforms: tuple[str, ...] | None = None @@ -454,23 +453,6 @@ class LoweringParameters: # native execution (and we can remove this parameter). replace_tokens_with_dummy: bool = True - @property - def override_platform(self) -> str | None: - """Overrides the lowering platform for cross-platform lowering. - - One of 'cpu', 'cuda', 'rocm', 'tpu'. - If None, use the default JAX mechanisms to pick the lowering platform. - This is currently used for export and jax2tf. - """ - if self.platforms is not None: - return self.platforms[0] - else: - return None - - @property - def is_multi_platform(self) -> bool: - return self.platforms is not None and len(self.platforms) > 1 - @dataclasses.dataclass class ModuleContext: @@ -480,7 +462,7 @@ class ModuleContext: ip: ir.InsertionPoint symbol_table: ir.SymbolTable backend_or_name: str | xb.XlaBackend | None - platform: str + platforms: Sequence[str] axis_context: AxisContext name_stack: source_info_util.NameStack keepalives: list[Any] @@ -503,7 +485,7 @@ class ModuleContext: self, *, backend_or_name: str | xb.XlaBackend | None, - platform: str, + platforms: Sequence[str], axis_context: AxisContext, name_stack: source_info_util.NameStack, keepalives: list[Any], @@ -519,13 +501,13 @@ class ModuleContext: cached_call_jaxpr_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, shape_poly_state = None): - assert platform is not None + self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) self.ip = ip or ir.InsertionPoint(self.module.body) self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) self.backend_or_name = backend_or_name - self.platform = platform + self.platforms = platforms self.axis_context = axis_context self.name_stack = name_stack self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None @@ -536,12 +518,18 @@ class ModuleContext: self.cached_call_jaxpr_lowerings = ({} if cached_call_jaxpr_lowerings is None else cached_call_jaxpr_lowerings) - self.shape_poly_state = shape_poly_state or ShapePolyLoweringState((), - (platform,)) + self.shape_poly_state = ( + shape_poly_state or ShapePolyLoweringState((), tuple(platforms))) self.lowering_parameters = lowering_parameters @property def backend(self) -> xb.XlaBackend: + # TODO(necula): clean the use of backend and backend_or_name vs. platforms + if len(self.platforms) > 1: + raise NotImplementedError( + "accessing .backend in multi-lowering setting. This can occur when " + "lowering a primitive that has not been adapted to multi-platform " + "lowering") if self.backend_or_name is None or isinstance(self.backend_or_name, str): return xb.get_backend(self.backend_or_name) return self.backend_or_name @@ -722,7 +710,7 @@ def lower_jaxpr_to_module( *, ordered_effects: list[core.Effect], backend_or_name: str | xb.XlaBackend | None, - platform: str, + platforms: Sequence[str], axis_context: AxisContext, name_stack: source_info_util.NameStack, donated_args: Sequence[bool], @@ -741,13 +729,7 @@ def lower_jaxpr_to_module( Handles the quirks of the argument/return value passing conventions of the runtime. """ - platforms: tuple[str, ...] - platform = xb.canonicalize_platform(platform) - if lowering_parameters.is_multi_platform: - platforms = tuple(map(xb.canonicalize_platform, - lowering_parameters.platforms)) # type: ignore - else: - platforms = (platform,) + platforms = tuple(map(xb.canonicalize_platform, platforms)) input_output_aliases = None in_avals = (jaxpr.in_avals if arg_shardings is None else @@ -809,7 +791,7 @@ def lower_jaxpr_to_module( if result_shardings is not None else result_shardings) ctx = ModuleContext(backend_or_name=backend_or_name, - platform=platform, axis_context=axis_context, + platforms=platforms, axis_context=axis_context, name_stack=name_stack, keepalives=keepalives, channel_iterator=channel_iter, @@ -1349,7 +1331,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, dim_var_values: the list of dimension variables values in the current IR function, in the order of ctx.shape_poly_state.dim_vars. """ - assert ctx.platform != "gpu" + assert "gpu" not in ctx.platforms def read(v: core.Atom) -> Sequence[ir.Value]: if type(v) is core.Literal: return ir_constants(xla.canonicalize_dtype(v.val)) @@ -1393,14 +1375,15 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, ctx.name_stack) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) - if not ctx.lowering_parameters.is_multi_platform: + if len(ctx.platforms) == 1: # Classic, single-platform lowering # TODO(necula): unify the code paths when multi-platform is finished + platform = ctx.platforms[0] if override_rule is not None: rule = override_rule - elif eqn.primitive in _platform_specific_lowerings[ctx.platform]: - rule = _platform_specific_lowerings[ctx.platform][eqn.primitive] - elif eqn.primitive in xla._backend_specific_translations[ctx.platform]: + elif eqn.primitive in _platform_specific_lowerings[platform]: + rule = _platform_specific_lowerings[platform][eqn.primitive] + elif eqn.primitive in xla._backend_specific_translations[platform]: rule = xla_fallback_lowering(eqn.primitive) elif eqn.primitive in _lowerings: rule = _lowerings[eqn.primitive] @@ -1409,7 +1392,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, else: raise NotImplementedError( f"MLIR translation rule for primitive '{eqn.primitive.name}' not " - f"found for platform {ctx.platform}") + f"found for platform {platform}") else: rules: list[MultiPlatformLoweringRule] # See mlir.lower_multi_platform for the `rules` format @@ -1418,7 +1401,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, else: # First the platform-specific rules rules = [] - for p in ctx.lowering_parameters.platforms: # type: ignore + for p in ctx.platforms: if eqn.primitive in _platform_specific_lowerings[p]: rules.append( ([p], _platform_specific_lowerings[p][eqn.primitive])) @@ -1449,7 +1432,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) rule_inputs = map(_unwrap_singleton_ir_values, in_nodes) - if not ctx.lowering_parameters.is_multi_platform: + if len(ctx.platforms) == 1: # Classic, single-platform lowering ans = rule(rule_ctx, *rule_inputs, **eqn.params) else: @@ -1528,12 +1511,7 @@ def lower_multi_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] - if ctx.module_context.lowering_parameters.is_multi_platform: - assert ctx.module_context.lowering_parameters.platforms is not None - platforms = ctx.module_context.lowering_parameters.platforms - else: - platforms = (ctx.module_context.platform,) + platforms: Sequence[str] = ctx.module_context.platforms platforms_with_specific_rules: Sequence[str] = util.flatten( [ps for ps, _ in rules if ps is not None]) platforms_with_default_rule = [p for p in platforms @@ -1681,11 +1659,16 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, return func_op -def check_backend_matches(inner_backend, outer_backend): +def check_backend_matches(inner_backend: Optional[str], + lowering_platforms: Sequence[str]): # For nested calls, the outermost call sets the backend for all inner calls; # it's an error if the inner call has a conflicting explicit backend spec. if inner_backend is None: return + outer_backend, *more_lowering_platforms = lowering_platforms + if more_lowering_platforms: + raise NotImplementedError( + "Multi-platform lowering when a backend= parameter is specified") if (inner_backend != outer_backend and outer_backend not in xb.expand_platform_alias(inner_backend)): raise ValueError( @@ -1693,13 +1676,15 @@ def check_backend_matches(inner_backend, outer_backend): f"inner-jit backend specification {inner_backend}.") -def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, +def _call_lowering(fn_name, stack_name, call_jaxpr, backend, + ctx: ModuleContext, avals_in, avals_out, tokens_in, *args, dim_var_values: Sequence[ir.Value], arg_names=None, result_names=None): + del stack_name, avals_in if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) - check_backend_matches(backend, ctx.platform) + check_backend_matches(backend, ctx.platforms) effects = list(tokens_in.effects()) output_types = map(aval_to_ir_types, avals_out) output_types = [token_type()] * len(effects) + output_types @@ -1717,7 +1702,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens))) return out_nodes, tokens_out -def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr): +def core_call_lowering(ctx: LoweringRuleContext, + *args, name, backend=None, call_jaxpr): out_nodes, tokens = _call_lowering( name, name, call_jaxpr, backend, ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args, @@ -2137,8 +2123,11 @@ def xla_fallback_lowering(prim: core.Primitive): raise NotImplementedError( f"Shape polymorphism for xla_fallback_lowering is not implemented ({ctx.primitive}); b/261682623") + if len(module_ctx.platforms) > 1: + raise NotImplementedError( + "fallback lowering not implemented for multi-platform lowering") xla_computation = xla.primitive_subcomputation( - module_ctx.platform, axis_env, prim, ctx.avals_in, + module_ctx.platforms[0], axis_env, prim, ctx.avals_in, ctx.avals_out, **params) xla_module = xla_computation_to_mlir_module(xla_computation) callee_name = merge_mlir_modules( @@ -2301,7 +2290,9 @@ def emit_python_callback( result_layouts: Sequence[Sequence[int] | None] | None = None, ) -> tuple[Sequence[ir.Value], Any, Any]: """Emits MLIR that calls back to a provided Python function.""" - platform = ctx.module_context.platform + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for python_callback") + platform = ctx.module_context.platforms[0] if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") @@ -2423,7 +2414,8 @@ def emit_python_callback( return results, token, ifrt_callback def build_xla_computation_helper( - closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str, + closed_jaxpr: core.ClosedJaxpr, *, name: str, + platforms: Sequence[str], backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation: """Helper to generate pmap-style XLA computations for custom partitioners.""" if closed_jaxpr.effects: @@ -2432,7 +2424,7 @@ def build_xla_computation_helper( backend_or_name=backend_or_name, ordered_effects=[], name_stack=source_info_util.NameStack(), donated_args=[False] * len(closed_jaxpr.jaxpr.invars), - axis_context=axis_context, platform=platform, + axis_context=axis_context, platforms=platforms, lowering_parameters=LoweringParameters()) return xc._xla.mlir.mlir_module_to_xla_computation( module_to_string(lowering_result.module), use_tuple_args=False, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a09fec093..aa179a577 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -758,7 +758,7 @@ def lower_parallel_callable( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, donated_args=donated_invars, @@ -1362,7 +1362,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, call_jaxpr, backend=None, in_axes, out_axes, donated_invars, is_explicit_global_axis_size): del donated_invars # Unused. - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. if ctx.module_context.axis_env.names and devices is not None: @@ -1835,7 +1835,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, ordered_effects=ordered_effects, backend_or_name=backend, # Optionally, override the lowering platform - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -2201,7 +2201,7 @@ def lower_mesh_computation( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platform=lowering_parameters.override_platform or backend.platform, + platforms=lowering_parameters.platforms or (backend.platform,), axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 8e1316bd7..d297438bd 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -725,12 +725,7 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups): for arg, named_shape in zip(args, named_shapes)] def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): - # TODO(necula): clean this up when we have module_context.platforms - if ctx.module_context.lowering_parameters.is_multi_platform: - for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms) - else: - for_tpu = (ctx.module_context.platform == "tpu") - if axis_index_groups is not None and for_tpu: + if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) if any(len(g) != len_0 for g in axis_index_groups): raise ValueError("axis_index_groups must all be the same size for TPU lowering") diff --git a/jax/_src/maps.py b/jax/_src/maps.py index ca7ff271b..86d3bd740 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -1305,7 +1305,8 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) + del backend, donated_invars # The only way for any of those two assertions to be violated is when xmap # is using the SPMD lowering, but then this rule shouldn't even trigger. assert spmd_in_axes is None and spmd_out_axes is None @@ -1381,7 +1382,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) + del backend, donated_invars plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) @@ -1447,9 +1449,10 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, donated_invars, global_axis_sizes, spmd_in_axes, spmd_out_axes, axis_resources, resource_env, backend): + del donated_invars assert spmd_in_axes is None and spmd_out_axes is None # This first part (up to vtile_manual) is shared with non-MANUAL SPMD rule. - mlir.check_backend_matches(backend, ctx.module_context.platform) + mlir.check_backend_matches(backend, ctx.module_context.platforms) plan = EvaluationPlan.from_axis_resources( axis_resources, resource_env, global_axis_sizes) manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values())) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index d0dfd0eb7..2868fe02e 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1503,7 +1503,9 @@ def pallas_call_lowering( **compiler_params ) num_warps = compiler_params.get("num_warps", 4) - if ctx.module_context.platform == 'rocm': + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for Pallas kernels") + if ctx.module_context.platforms[0] == 'rocm': num_stages = compiler_params.get("num_stages", 1) else: num_stages = compiler_params.get("num_stages", 3) @@ -1521,7 +1523,7 @@ def pallas_call_lowering( debug=debug, ) #Triton returns a tuple for ROCm. We just want file path to be passed - if ctx.module_context.platform == 'rocm': + if ctx.module_context.platforms[0] == 'rocm': compilation_result.ptx = compilation_result.ptx[1] if debug: diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index ba6cbe44e..3aa77adbf 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -181,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, built = mlir.build_xla_computation_helper( closed_jaxpr, name="tmp_xla_computation", - platform=module_context.platform, + platforms=module_context.platforms, backend_or_name=module_context.backend_or_name, axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), ) diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index bfeb921ce..d0220d8bd 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -689,7 +689,7 @@ def _wrap_main_func( with ir.InsertionPoint(entry_block): # Make a context just for lowering the dimension value computations module_context = mlir.ModuleContext( - backend_or_name="cpu", platform="cpu", + backend_or_name="cpu", platforms=["cpu"], axis_context=sharding_impls.ShardingContext([]), name_stack=source_info_util.new_name_stack(), keepalives=[], channel_iterator=itertools.count(1), @@ -1179,11 +1179,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, submodule_args = [] # All the platforms for the current lowering must be among the platforms # for which the callee was lowered. - if ctx.module_context.lowering_parameters.is_multi_platform: - assert ctx.module_context.lowering_parameters.platforms is not None - lowering_platforms = ctx.module_context.lowering_parameters.platforms - else: - lowering_platforms = (ctx.module_context.platform,) + lowering_platforms = ctx.module_context.platforms callee_lowering_platform_index: list[int] = [] for platform in lowering_platforms: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 1bb111ef2..89f9bd6de 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1101,7 +1101,9 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, flat_results_aval=(), **params): """MLIR Lowering for `CustomCall`-based HCB.""" - platform = ctx.module_context.platform + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for host_callback") + platform = ctx.module_context.platforms[0] use_outfeed = _use_outfeed(platform) if use_outfeed: # Fall back to XLA path if we are using the outfeed @@ -1118,7 +1120,7 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, else: # TODO(necula): It seems that on CPU, with custom call, the device_index # does not work, and the callback is always run on device_index=0 - if (device_index != 0 and ctx.module_context.platform == "cpu"): + if (device_index != 0 and "cpu" in ctx.module_context.platforms): raise ValueError( "The device_index feature on CPU works only when using outfeed.") # We expect the current tokens at the end, inserted by _rewrite_jaxpr.