Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 20:36:05 +00:00
Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed during lowering, e.g., for lowering callbacks. Whenever we do need the backend for lowering, we must be in single-platform lowering mode (`len(platforms) == 1`), so we can look it up from `platforms[0]`. However, in some rare cases the caller supplies a custom `XlaBackend` whose platform matches `platforms[0]`. We therefore rename `backend_or_name` to just `backend` and restrict its type to an optional `XlaBackend` (a platform string is no longer accepted).

PiperOrigin-RevId: 712926140
This commit is contained in:
parent a7f384cc6e
commit fdb6af82d2
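To make the new contract concrete before the hunks, here is a minimal standalone sketch of the resolution rule the message describes. The helper name `resolve_backend` is hypothetical and not part of JAX; the real logic lives in the new `ModuleContext.get_backend` shown below, while `xb.get_backend`, `xb.canonicalize_platform`, and `xb.XlaBackend` are existing `jax._src.xla_bridge` names.

# Minimal sketch of the backend-resolution rule described above.
# `resolve_backend` is a hypothetical helper for illustration only; JAX
# implements this inside ModuleContext.get_backend (see the hunks below).
from collections.abc import Sequence

from jax._src import xla_bridge as xb


def resolve_backend(platforms: Sequence[str],
                    backend: xb.XlaBackend | None = None) -> xb.XlaBackend:
  if len(platforms) > 1:
    # Multi-platform lowering (used only when exporting) never has a live client.
    raise NotImplementedError("backend lookup during multi-platform lowering")
  if backend is not None:
    # A custom backend is accepted only if it targets the single lowering platform.
    if xb.canonicalize_platform(backend.platform) != platforms[0]:
      raise ValueError("backend platform differs from the lowering platform")
    return backend
  # Common case: the backend is derived from the single lowering platform.
  return xb.get_backend(platforms[0])


# Example (CPU is always available):
cpu = xb.get_backend("cpu")
assert resolve_backend(["cpu"]).platform == "cpu"
assert resolve_backend(["cpu"], cpu) is cpu

With this rule, callers that used to pass a platform string simply pass `backend=None` and let the single entry in `platforms` drive the lookup.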
@@ -190,7 +190,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
       closed_jaxpr,
       name="tmp_xla_computation",
       platforms=module_context.platforms,
-      backend_or_name=module_context.backend_or_name,
+      backend=module_context.backend,
       axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
   )
   result_sharding = _pack_result_sharding(result_shape, result_shardings)
@@ -896,7 +896,7 @@ def _wrap_main_func(
     with ir.InsertionPoint(entry_block):
       # Make a context just for lowering the dimension value computations
       module_context = mlir.ModuleContext(
-          backend_or_name="cpu", platforms=["cpu"],
+          backend=None, platforms=["cpu"],
           axis_context=sharding_impls.ShardingContext(0),
           keepalives=[], channel_iterator=itertools.count(1),
           host_callbacks=[], module=wrapped_module, context=context,
@@ -698,10 +698,11 @@ class ModuleContext:
   module: ir.Module
   ip: ir.InsertionPoint
   symbol_table: ir.SymbolTable
-  backend_or_name: str | xb.XlaBackend | None
   # The lowering platforms for the module. Can be more than one only when
   # exporting.
   platforms: Sequence[str]
+  # See ModuleContext.get_backend() for backend and platforms usage.
+  backend: xb.XlaBackend | None
   axis_context: AxisContext
   keepalives: list[Any]
   channel_iterator: Iterator[int]
@@ -725,8 +726,8 @@ class ModuleContext:
   def __init__(
       self,
       *,
-      backend_or_name: str | xb.XlaBackend | None,
       platforms: Sequence[str],
+      backend: xb.XlaBackend | None,
       axis_context: AxisContext,
       keepalives: list[Any],
       channel_iterator: Iterator[int],
@@ -745,7 +746,7 @@ class ModuleContext:
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.body)
     self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
-    self.backend_or_name = backend_or_name
+    self.backend = backend
     self.platforms = platforms
     self.axis_context = axis_context
     self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
@@ -760,17 +761,20 @@ class ModuleContext:
     self.all_default_mem_kind = all_default_mem_kind
     self.lowering_parameters = lowering_parameters

-  @property
-  def backend(self) -> xb.XlaBackend:
+  # TODO(necula): clean the use of backend and backend_or_name vs. platforms
+  def get_backend(self) -> xb.XlaBackend:
     if len(self.platforms) > 1:
       raise NotImplementedError(
         "accessing .backend in multi-lowering setting. This can occur when "
         "lowering a primitive that has not been adapted to multi-platform "
         "lowering")
-    if self.backend_or_name is None or isinstance(self.backend_or_name, str):
-      return xb.get_backend(self.backend_or_name)
-    return self.backend_or_name
+    if self.backend is not None:
+      if xb.canonicalize_platform(self.backend.platform) != self.platforms[0]:
+        raise ValueError(
+            "the platform for the specified backend "
+            f"{xb.canonicalize_platform(self.backend.platform)} is different "
+            f"from the lowering platform {self.platforms[0]}")
+      return self.backend
+    return xb.get_backend(self.platforms[0])

   def new_channel(self) -> int:
     channel = next(self.channel_iterator)
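For orientation between the hunks: a hedged sketch of how a lowering rule consumes the renamed API. The rule below is hypothetical and not part of this commit; the real call-site update appears in the `emit_python_callback` hunk further down.

# Hypothetical lowering-rule fragment, for illustration only: code that truly
# needs a live client now calls ctx.module_context.get_backend(), which raises
# NotImplementedError when more than one lowering platform is in effect.
def _example_rule_lowering(ctx, *operands):
  backend = ctx.module_context.get_backend()   # single-platform lowering only
  num_devices = backend.device_count()         # e.g. query the live client
  del num_devices, operands                    # placeholder body for the sketch
  return []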
@@ -1072,14 +1076,14 @@ def _get_unconstrained_dimensions(s, aval):
   return (us, all_unconstrained(s, aval),
           ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))


 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,
     *,
     ordered_effects: list[core.Effect],
-    backend_or_name: str | xb.XlaBackend | None,
+    # See ModuleContext.get_backend() for backend and platforms usage.
     platforms: Sequence[str],
+    backend: xb.XlaBackend | None,
     axis_context: AxisContext,
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
@@ -1170,7 +1174,7 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()

-  ctx = ModuleContext(backend_or_name=backend_or_name,
+  ctx = ModuleContext(backend=backend,
                       platforms=platforms, axis_context=axis_context,
                       keepalives=keepalives,
                       channel_iterator=channel_iter,
@@ -2892,7 +2896,7 @@ def emit_python_callback(
   if platform not in {"cpu", "cuda", "rocm", "tpu"}:
     raise ValueError(
         f"`EmitPythonCallback` not supported on {platform} backend.")
-  backend = ctx.module_context.backend
+  backend = ctx.module_context.get_backend()
   result_shapes = util.flatten(
       [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
   operand_shapes = util.flatten(
@@ -3012,13 +3016,14 @@ def emit_python_callback(
 def build_mlir_module_helper(
     closed_jaxpr: core.ClosedJaxpr, *, name: str,
     platforms: Sequence[str],
-    backend_or_name: str, axis_context: AxisContext) -> ir.Module:
+    backend: xb.XlaBackend | None,
+    axis_context: AxisContext) -> ir.Module:
   """Helper to generate pmap-style XLA computations for custom partitioners."""
   unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
   if unlowerable_effects:
     raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
   lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
-      backend_or_name=backend_or_name, ordered_effects=[],
+      backend=backend, ordered_effects=[],
       name_stack=source_info_util.NameStack(),
       donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
       axis_context=axis_context, platforms=platforms,
@@ -871,7 +871,7 @@ def lower_parallel_callable(
       module_name,
       closed_jaxpr,
       ordered_effects=ordered_effects,
-      backend_or_name=backend,
+      backend=backend,
       platforms=platforms,
       axis_context=sharding_impls.ReplicaAxisContext(axis_env),
       name_stack=name_stack,
@@ -1954,7 +1954,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       module_name,
       closed_jaxpr,
       ordered_effects=ordered_effects,
-      backend_or_name=backend,
+      backend=backend,
       platforms=platforms,
       axis_context=axis_ctx,
       name_stack=name_stack,