From fdb6af82d2eb233bcaee0bead2d7cd8eb5a5211c Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 7 Jan 2025 08:42:23 -0800 Subject: [PATCH] Clean up `backend_or_name` vs. `platforms` in lowering code. It turns out that the backend is rarely needed when lowering; one of the few cases that does need it is lowering callbacks. Whenever we need the backend for lowering, we must be in single-platform lowering mode (`len(platforms) == 1`) and we can look up the backend from `platforms[0]`. However, in some rare cases we can have a custom `XlaBackend` whose platform matches `platforms[0]`, and then that backend is used instead of the one looked up from `platforms[0]`. We rename `backend_or_name` to just `backend` and we restrict its type to be an optional `XlaBackend` (not a platform string). PiperOrigin-RevId: 712926140 --- jax/_src/custom_partitioning.py | 2 +- jax/_src/export/_export.py | 2 +- jax/_src/interpreters/mlir.py | 35 +++++++++++++++++++-------------- jax/_src/interpreters/pxla.py | 4 ++-- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 4cf34f200..6b0ef293e 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, closed_jaxpr, name="tmp_xla_computation", platforms=module_context.platforms, - backend_or_name=module_context.backend_or_name, + backend=module_context.backend, axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)), ) result_sharding = _pack_result_sharding(result_shape, result_shardings) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 8ba43083d..6b1945746 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -896,7 +896,7 @@ def _wrap_main_func( with ir.InsertionPoint(entry_block): # Make a context just for lowering the dimension value computations module_context = mlir.ModuleContext( - backend_or_name="cpu", platforms=["cpu"], + backend=None, platforms=["cpu"], 
axis_context=sharding_impls.ShardingContext(0), keepalives=[], channel_iterator=itertools.count(1), host_callbacks=[], module=wrapped_module, context=context, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 85da4e53b..c6a634a05 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -698,10 +698,11 @@ class ModuleContext: module: ir.Module ip: ir.InsertionPoint symbol_table: ir.SymbolTable - backend_or_name: str | xb.XlaBackend | None # The lowering platforms for the module. Can be more than one only when # exporting. platforms: Sequence[str] + # See ModuleContext.get_backend() for backend and platforms usage. + backend: xb.XlaBackend | None axis_context: AxisContext keepalives: list[Any] channel_iterator: Iterator[int] @@ -725,8 +726,8 @@ class ModuleContext: def __init__( self, *, - backend_or_name: str | xb.XlaBackend | None, platforms: Sequence[str], + backend: xb.XlaBackend | None, axis_context: AxisContext, keepalives: list[Any], channel_iterator: Iterator[int], @@ -745,7 +746,7 @@ class ModuleContext: self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) self.ip = ip or ir.InsertionPoint(self.module.body) self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) - self.backend_or_name = backend_or_name + self.backend = backend self.platforms = platforms self.axis_context = axis_context self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None @@ -760,17 +761,20 @@ class ModuleContext: self.all_default_mem_kind = all_default_mem_kind self.lowering_parameters = lowering_parameters - @property - def backend(self) -> xb.XlaBackend: - # TODO(necula): clean the use of backend and backend_or_name vs. platforms + def get_backend(self) -> xb.XlaBackend: if len(self.platforms) > 1: raise NotImplementedError( "accessing .backend in multi-lowering setting. 
This can occur when " "lowering a primitive that has not been adapted to multi-platform " "lowering") - if self.backend_or_name is None or isinstance(self.backend_or_name, str): - return xb.get_backend(self.backend_or_name) - return self.backend_or_name + if self.backend is not None: + if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]: + raise ValueError( + "the platform for the specified backend " + f"{xb.canonicalize_platform(self.backend.platform)} is different " + f"from the lowering platform {self.platforms[0]}") + return self.backend + return xb.get_backend(self.platforms[0]) def new_channel(self) -> int: channel = next(self.channel_iterator) @@ -1072,14 +1076,14 @@ def _get_unconstrained_dimensions(s, aval): return (us, all_unconstrained(s, aval), ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None)) - def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, *, ordered_effects: list[core.Effect], - backend_or_name: str | xb.XlaBackend | None, + # See ModuleContext.get_backend() for backend and platforms usage. 
platforms: Sequence[str], + backend: xb.XlaBackend | None, axis_context: AxisContext, name_stack: source_info_util.NameStack, donated_args: Sequence[bool], @@ -1170,7 +1174,7 @@ def lower_jaxpr_to_module( else: dim_vars = () - ctx = ModuleContext(backend_or_name=backend_or_name, + ctx = ModuleContext(backend=backend, platforms=platforms, axis_context=axis_context, keepalives=keepalives, channel_iterator=channel_iter, @@ -2892,7 +2896,7 @@ def emit_python_callback( if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") - backend = ctx.module_context.backend + backend = ctx.module_context.get_backend() result_shapes = util.flatten( [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) operand_shapes = util.flatten( @@ -3012,13 +3016,14 @@ def emit_python_callback( def build_mlir_module_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platforms: Sequence[str], - backend_or_name: str, axis_context: AxisContext) -> ir.Module: + backend: xb.XlaBackend | None, + axis_context: AxisContext) -> ir.Module: """Helper to generate pmap-style XLA computations for custom partitioners.""" unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects) if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}') lowering_result = lower_jaxpr_to_module(name, closed_jaxpr, - backend_or_name=backend_or_name, ordered_effects=[], + backend=backend, ordered_effects=[], name_stack=source_info_util.NameStack(), donated_args=[False] * len(closed_jaxpr.jaxpr.invars), axis_context=axis_context, platforms=platforms, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a61600402..807dd0afe 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -871,7 +871,7 @@ def lower_parallel_callable( module_name, closed_jaxpr, ordered_effects=ordered_effects, - backend_or_name=backend, + backend=backend, 
platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, @@ -1954,7 +1954,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, module_name, closed_jaxpr, ordered_effects=ordered_effects, - backend_or_name=backend, + backend=backend, platforms=platforms, axis_context=axis_ctx, name_stack=name_stack,