Clean up backend_or_name vs. platforms in lowering code.

It turns out that the backend is rarely needed when lowering; one case
where it is needed is lowering callbacks. Whenever we need the backend for
lowering, we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases the caller supplies a custom `XlaBackend`, whose
platform must match `platforms[0]`. To support this we keep an explicit
backend argument: we rename `backend_or_name` to just `backend` and we
restrict its type to be an optional `XlaBackend` (not a platform string).

PiperOrigin-RevId: 712926140
This commit is contained in:
George Necula 2025-01-07 08:42:23 -08:00 committed by jax authors
parent a7f384cc6e
commit fdb6af82d2
4 changed files with 24 additions and 19 deletions

View File

@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
closed_jaxpr,
name="tmp_xla_computation",
platforms=module_context.platforms,
backend_or_name=module_context.backend_or_name,
backend=module_context.backend,
axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
)
result_sharding = _pack_result_sharding(result_shape, result_shardings)

View File

@ -896,7 +896,7 @@ def _wrap_main_func(
with ir.InsertionPoint(entry_block):
# Make a context just for lowering the dimension value computations
module_context = mlir.ModuleContext(
backend_or_name="cpu", platforms=["cpu"],
backend=None, platforms=["cpu"],
axis_context=sharding_impls.ShardingContext(0),
keepalives=[], channel_iterator=itertools.count(1),
host_callbacks=[], module=wrapped_module, context=context,

View File

@ -698,10 +698,11 @@ class ModuleContext:
module: ir.Module
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: str | xb.XlaBackend | None
# The lowering platforms for the module. Can be more than one only when
# exporting.
platforms: Sequence[str]
# See ModuleContext.get_backend() for backend and platforms usage.
backend: xb.XlaBackend | None
axis_context: AxisContext
keepalives: list[Any]
channel_iterator: Iterator[int]
@ -725,8 +726,8 @@ class ModuleContext:
def __init__(
self,
*,
backend_or_name: str | xb.XlaBackend | None,
platforms: Sequence[str],
backend: xb.XlaBackend | None,
axis_context: AxisContext,
keepalives: list[Any],
channel_iterator: Iterator[int],
@ -745,7 +746,7 @@ class ModuleContext:
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
self.backend_or_name = backend_or_name
self.backend = backend
self.platforms = platforms
self.axis_context = axis_context
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
@ -760,17 +761,20 @@ class ModuleContext:
self.all_default_mem_kind = all_default_mem_kind
self.lowering_parameters = lowering_parameters
@property
def backend(self) -> xb.XlaBackend:
# TODO(necula): clean the use of backend and backend_or_name vs. platforms
def get_backend(self) -> xb.XlaBackend:
if len(self.platforms) > 1:
raise NotImplementedError(
"accessing .backend in multi-lowering setting. This can occur when "
"lowering a primitive that has not been adapted to multi-platform "
"lowering")
if self.backend_or_name is None or isinstance(self.backend_or_name, str):
return xb.get_backend(self.backend_or_name)
return self.backend_or_name
if self.backend is not None:
if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]:
raise ValueError(
"the platform for the specified backend "
f"{xb.canonicalize_platform(self.backend.platform)} is different "
f"from the lowering platform {self.platforms[0]}")
return self.backend
return xb.get_backend(self.platforms[0])
def new_channel(self) -> int:
channel = next(self.channel_iterator)
@ -1072,14 +1076,14 @@ def _get_unconstrained_dimensions(s, aval):
return (us, all_unconstrained(s, aval),
({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
*,
ordered_effects: list[core.Effect],
backend_or_name: str | xb.XlaBackend | None,
# See ModuleContext.get_backend() for backend and platforms usage.
platforms: Sequence[str],
backend: xb.XlaBackend | None,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
@ -1170,7 +1174,7 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()
ctx = ModuleContext(backend_or_name=backend_or_name,
ctx = ModuleContext(backend=backend,
platforms=platforms, axis_context=axis_context,
keepalives=keepalives,
channel_iterator=channel_iter,
@ -2892,7 +2896,7 @@ def emit_python_callback(
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
f"`EmitPythonCallback` not supported on {platform} backend.")
backend = ctx.module_context.backend
backend = ctx.module_context.get_backend()
result_shapes = util.flatten(
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
operand_shapes = util.flatten(
@ -3012,13 +3016,14 @@ def emit_python_callback(
def build_mlir_module_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str,
platforms: Sequence[str],
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
backend: xb.XlaBackend | None,
axis_context: AxisContext) -> ir.Module:
"""Helper to generate pmap-style XLA computations for custom partitioners."""
unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
backend_or_name=backend_or_name, ordered_effects=[],
backend=backend, ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platforms=platforms,

View File

@ -871,7 +871,7 @@ def lower_parallel_callable(
module_name,
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
backend=backend,
platforms=platforms,
axis_context=sharding_impls.ReplicaAxisContext(axis_env),
name_stack=name_stack,
@ -1954,7 +1954,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
module_name,
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
backend=backend,
platforms=platforms,
axis_context=axis_ctx,
name_stack=name_stack,