Add `lowering_platforms` to `traced.lower()` to allow lowering to different backends, including multi-backend lowering. In other words, enable cross-lowering!

The motivation for doing this is two-fold:

1) It will help with deprecating and eventually deleting `jax.xla_computation`, which allows cross-backend lowering.

2) It allows cross-backend and multi-backend lowering via the jax AOT APIs, which will help clean up some hacks implemented for `jax.export`.

Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`; you cannot use `.lower(*args)` to do cross-lowering. If `.trace(*args).lower(lowering_platforms=...)` turns out to be cumbersome to write, we can introduce top-level APIs in the future that make these composable AOT APIs easier to use.
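For example, a minimal sketch of the new flow (the function, shapes, and platform tuples below are illustrative):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      return x * 2

    traced = f.trace(jnp.arange(8))
    # Cross-lowering: target TPU even from a host whose default backend is CPU.
    lowered_tpu = traced.lower(lowering_platforms=('tpu',))
    # Multi-backend lowering: one module lowered for both platforms at once.
    lowered_multi = traced.lower(lowering_platforms=('cpu', 'tpu'))
    print(lowered_tpu.as_text())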

Designed with @froystig!

PiperOrigin-RevId: 644087787
Yash Katariya 2024-06-17 11:58:18 -07:00 committed by jax authors
parent be1f4ba380
commit 6ba16e0348
8 changed files with 60 additions and 68 deletions


@@ -1819,8 +1819,6 @@ def _cpp_pmap(
   @api_boundary
   def trace(*args, **kwargs):
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
@@ -1842,7 +1840,6 @@ def _cpp_pmap(
         donated_invars=p.donated_invars,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
         avals=abstract_args,
-        lowering_parameters=lowering_parameters,
         closed_jaxpr=closed_jaxpr,
         backend=xc_backend,
         replicas=replicas,


@@ -424,7 +424,7 @@ def export_back_compat(
   """
   def do_export(*args_specs, **kwargs_specs) -> Exported:
-    if hasattr(fun_jax, "lower"):
+    if hasattr(fun_jax, "trace"):
       # If we have a pjit or pmap already we do not wrap with another, and we
       # allow shardings.
       wrapped_fun_jax = fun_jax
@@ -434,8 +434,6 @@ def export_back_compat(
       # an error if the lowered function contains non-replicated sharding annotations.
       wrapped_fun_jax = jax.jit(fun_jax)

-    has_trace = hasattr(wrapped_fun_jax, "trace")
-
     if lowering_platforms is not None:
       actual_lowering_platforms = tuple(lowering_platforms)
     else:
@@ -457,25 +455,12 @@ def export_back_compat(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))

-    if has_trace:
-      traced = wrapped_fun_jax.trace(  # type: ignore
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = traced.jaxpr, traced.fun_name
-      lowered = traced.lower()
-    else:
-      lowered = wrapped_fun_jax.lower(
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax)
+    traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-        lowered, jaxpr, fun_name,
+        lowered, traced.jaxpr, traced.fun_name,
         disabled_checks=disabled_checks,
         _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
   return do_export
@@ -553,16 +538,12 @@ def export(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))

-    traced = fun_jit.trace(
-        *args_specs, **kwargs_specs,
-        _experimental_lowering_parameters=mlir.LoweringParameters(
-            platforms=actual_lowering_platforms,
-            for_export=True,
-        ))
-    jaxpr, fun_name = traced.jaxpr, traced.fun_name
-    lowered = traced.lower()
+    traced = fun_jit.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-        lowered, jaxpr, fun_name,
+        lowered, traced.jaxpr, traced.fun_name,
         disabled_checks=disabled_checks)
   return do_export
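Both export paths above now share the same two-step recipe: trace with only the argument specs, then lower with the platforms passed explicitly. A minimal sketch of that recipe (the helper name export_sketch is hypothetical, not part of the real export internals):

    import jax
    from jax._src.interpreters import mlir

    def export_sketch(fun_jit, *args_specs, platforms=('tpu',)):
      # Trace once; no lowering parameters are smuggled through kwargs anymore.
      traced = fun_jit.trace(*args_specs)
      # Lower once, naming the target platforms explicitly.
      lowered = traced.lower(
          lowering_platforms=platforms,
          _private_parameters=mlir.LoweringParameters(for_export=True))
      return lowered, traced.jaxpr, traced.fun_name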


@@ -546,13 +546,6 @@ class LoweringParameters:
   # existing Jax rules.
   override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None

-  # The current lowering platforms, a non-empty tuple containing some of
-  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
-  # doing multi-platform lowering, otherwise it can specify cross-platform
-  # lowering. The value None specifies the default lowering platform.
-  # This is used only in export and jax2tf.
-  platforms: tuple[str, ...] | None = None
-
   # Signals that the entire computation being lowered operates on global
   # constants. This will result in adding jax.global_constant attributes
   # to the arguments of all functions that are created, e.g., floor_divide.
@@ -621,8 +614,7 @@ class ModuleContext:
       module: ir.Module | None = None,
       ip: ir.InsertionPoint | None = None,
       symbol_table: ir.SymbolTable | None = None,
-      cached_primitive_lowerings: None | (dict[Any,
-                                               func_dialect.FuncOp]) = None,
+      cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None,
       traceback_caches: None | TracebackCaches = None,
       shape_poly_state = None):
@@ -948,8 +940,7 @@ def lower_jaxpr_to_module(
       channel_iterator=channel_iter,
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
-      shape_poly_state=ShapePolyLoweringState(
-          dim_vars, lowering_parameters.platforms))
+      shape_poly_state=ShapePolyLoweringState(dim_vars, platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.


@@ -592,8 +592,9 @@ def parallel_callable(fun: lu.WrappedFun,
       fun, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, donated_invars,
       is_explicit_global_axis_size, avals,
-      lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr,
-      backend=xc_backend, replicas=replicas, shards=shards, pci=pci)
+      lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(),
+      closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas,
+      shards=shards, pci=pci)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -735,6 +736,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     closed_jaxpr: core.ClosedJaxpr,
     backend: xc.Client,
@@ -813,7 +815,7 @@ def lower_parallel_callable(
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):
     ordered_effects = list(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
@@ -1956,6 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             donated_invars, name_stack, all_default_mem_kind,
                             inout_aliases: None | tuple[None | int, ...],
                             propagated_out_mem_kinds: tuple[None | str, ...],
+                            platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings._gspmd_shardings
@@ -2016,8 +2019,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         closed_jaxpr,
         ordered_effects=ordered_effects,
         backend_or_name=backend,
-        # Optionally, override the lowering platform
-        platforms=lowering_parameters.platforms or (backend.platform,),
+        platforms=platforms,
         axis_context=axis_ctx,
         name_stack=name_stack,
         donated_args=donated_invars,
@@ -2166,9 +2168,10 @@ def lower_sharding_computation(
     *,
     keep_unused: bool,
     inline: bool,
-    devices_from_context: Sequence[xc.Device] | None = None,
+    devices_from_context: Sequence[xc.Device] | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
-    pgle_profiler: profiler.PGLEProfiler | None = None,
+    pgle_profiler: profiler.PGLEProfiler | None,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
@@ -2212,7 +2215,7 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
       devices_from_context)

-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)

   # TODO(yashkatariya): Enable this when offload APIs are stable.
   # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
@@ -2252,7 +2255,8 @@ def lower_sharding_computation(
       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
       tuple(da_object) if prim_requires_devices else None, donated_invars,
       name_stack, all_default_mem_kind, inout_aliases,
-      propagated_out_mem_kinds, lowering_parameters=lowering_parameters)
+      propagated_out_mem_kinds, platforms,
+      lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2316,10 +2320,11 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: TilingMethod | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
   global_axis_sizes = mesh.shape


@@ -707,7 +707,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, in_avals,
-        tiling_method=tiling_method,
+        tiling_method=tiling_method, lowering_platforms=None,
         lowering_parameters=lowering_parameters)
   else:
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
@@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
         (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
         (None,) * len(in_avals), (None,) * len(out_avals),
         donated_invars, keep_unused=True, inline=False,
-        devices_from_context=None, lowering_parameters=lowering_parameters)
+        devices_from_context=None, lowering_platforms=None,
+        lowering_parameters=lowering_parameters, pgle_profiler=None)

 class EvaluationPlan(NamedTuple):


@@ -500,16 +500,13 @@ def _make_jit_wrapper(jit_info: PjitInfo):
   @api_boundary
   def trace(*args, **kwargs) -> stages.Traced:
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
     (args_flat, params, in_avals, in_tree, out_tree, donated_invars,
      arg_names, num_consts, _) = _infer_params(jit_info, args, kwargs)
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     args_info = stages.make_args_info(in_tree, in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **params,
-                             lowering_parameters=lowering_parameters)
+                             pgle_profiler=None)
     return stages.Traced(params['jaxpr'], args_info, params["name"], out_tree,
                          lower_callable, args_flat, arg_names, num_consts)
@@ -1497,7 +1494,7 @@ def _resolve_in_shardings(
 def _resolve_and_lower(
     args, jaxpr, in_shardings, out_shardings, in_layouts,
     out_layouts, resource_env, donated_invars, name, keep_unused, inline,
-    lowering_parameters, pgle_profiler=None):
+    lowering_platforms, lowering_parameters, pgle_profiler):
   in_shardings = _resolve_in_shardings(
       args, in_shardings, out_shardings,
       resource_env.physical_mesh if resource_env is not None else None)
@@ -1506,6 +1503,7 @@ def _resolve_and_lower(
   lowered = _pjit_lower(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
       donated_invars, name, keep_unused, inline,
+      lowering_platforms=lowering_platforms,
       lowering_parameters=lowering_parameters,
       pgle_profiler=pgle_profiler)
   return lowered
@@ -1540,7 +1538,8 @@ def _pjit_call_impl_python(
       out_shardings=out_shardings, in_layouts=in_layouts,
       out_layouts=out_layouts, resource_env=resource_env,
      donated_invars=donated_invars, name=name, keep_unused=keep_unused,
-      inline=inline, lowering_parameters=mlir.LoweringParameters(),
+      inline=inline, lowering_platforms=None,
+      lowering_parameters=mlir.LoweringParameters(),
       pgle_profiler=pgle_profiler
   ).compile(compile_options)
@@ -1659,6 +1658,7 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
   if resource_env is not None:
@@ -1679,6 +1679,7 @@ def _pjit_lower_cached(
         jaxpr, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
         True, jaxpr.in_avals, tiling_method=None,
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters)
   else:
     return pxla.lower_sharding_computation(
@@ -1687,6 +1688,7 @@ def _pjit_lower_cached(
         keep_unused=keep_unused, inline=inline,
         devices_from_context=(
             None if mesh is None or mesh.empty else list(mesh.devices.flat)),
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters,
         pgle_profiler=pgle_profiler)


@@ -30,6 +30,7 @@ executable protocols described above.
 """
 from __future__ import annotations

+import functools
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
@@ -446,9 +447,14 @@ class Traced(Stage):
     return self._out_tree.unflatten(
         [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])

-  def lower(self):
-    lowering = self._lower_callable()
-    return Lowered(lowering, self.args_info, self._out_tree)
+  def lower(self, lowering_platforms: tuple[str, ...] | None = None,
+            _private_parameters: mlir.LoweringParameters | None = None):
+    if _private_parameters is None:
+      _private_parameters = mlir.LoweringParameters()
+    new_callable = functools.partial(
+        self._lower_callable, lowering_platforms=lowering_platforms,
+        lowering_parameters=_private_parameters)
+    return Lowered(new_callable(), self.args_info, self._out_tree)

 class Compiled(Stage):
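The new `Traced.lower` signature keeps the public knob (`lowering_platforms`) separate from the private escape hatch (`_private_parameters`) used by internal callers such as export; the `OverrideLoweringTest` change below exercises the latter. A short usage sketch (shapes and platforms are illustrative):

    import jax
    import jax.numpy as jnp
    from jax._src.interpreters import mlir

    traced = jax.jit(lambda x: x + 1).trace(jnp.zeros((4,), jnp.float32))
    # Public: pick the target platform(s) for this lowering.
    on_tpu = traced.lower(lowering_platforms=('tpu',))
    # Private: thread LoweringParameters through for internal use.
    for_export = traced.lower(
        _private_parameters=mlir.LoweringParameters(for_export=True))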


@@ -4728,6 +4728,13 @@ class APITest(jtu.JaxTestCase):
     out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))()  # don't crash
     self.assertEqual(out, 3)

+  def test_lowering_platform_aot(self):
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',))  # doesn't crash
+

 class RematTest(jtu.JaxTestCase):
@@ -10731,9 +10738,11 @@ class OverrideLoweringTest(jtu.JaxTestCase):
     rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
     lowered_ir = (
         jax.jit(f)
-        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
-               _experimental_lowering_parameters=mlir.LoweringParameters(
-                   override_lowering_rules=rules)).as_text())
+        .trace(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16))
+        .lower(_private_parameters=mlir.LoweringParameters(
+            override_lowering_rules=rules))
+        .as_text()
+    )
     self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)