Add lowering_platforms to traced.lower() to allow lowering to different backends, and to several backends at once. In other words, enable cross-lowering!

The motivation is two-fold:

1) It helps with deprecating and eventually deleting `jax.xla_computation`, which allows cross-backend lowering.

2) It allows cross-backend and multi-backend lowering via the JAX AOT APIs, which will help clean up some hacks implemented for `jax.export`.

Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`; you cannot use `.lower(*args)` directly for cross-lowering. If `.trace(*args).lower(lowering_platforms)` proves cumbersome to write, we can introduce top-level APIs later to make the composable AOT APIs easier to use.

Designed with @froystig!

PiperOrigin-RevId: 644087787
This commit is contained in: parent be1f4ba380, commit 6ba16e0348
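A minimal usage sketch of the new AOT path described above; it mirrors the test added in this change, and the 'tpu' platform name is only an example:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      return x * 2

    # Trace first, then request lowering for a platform other than the one we run on.
    lowered = f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',))
    print(lowered.as_text()[:80])  # StableHLO text for the requested platform

By contrast, `f.lower(jnp.arange(8))` still targets the default backend; the platform override is only accepted on the Traced object returned by `.trace`.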
@@ -1819,8 +1819,6 @@ def _cpp_pmap(

   @api_boundary
   def trace(*args, **kwargs):
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
@@ -1842,7 +1840,6 @@ def _cpp_pmap(
         donated_invars=p.donated_invars,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
         avals=abstract_args,
-        lowering_parameters=lowering_parameters,
         closed_jaxpr=closed_jaxpr,
         backend=xc_backend,
         replicas=replicas,
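With the two removals above, pmap's `trace` no longer pops `_experimental_lowering_parameters`; those options now arrive when `Traced.lower` is called (see the `Traced.lower` hunk near the end). A quick sketch of the user-facing call, assuming pmap exposes `.trace` the same way jit does at this revision:

    import jax
    import jax.numpy as jnp

    pmapped = jax.pmap(lambda x: x * 2)
    traced = pmapped.trace(jnp.arange(jax.device_count()))
    lowered = traced.lower()  # lowers for the default backend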
@@ -424,7 +424,7 @@ def export_back_compat(
   """

   def do_export(*args_specs, **kwargs_specs) -> Exported:
-    if hasattr(fun_jax, "lower"):
+    if hasattr(fun_jax, "trace"):
       # If we have a pjit or pmap already we do not wrap with another, and we
       # allow shardings.
       wrapped_fun_jax = fun_jax
@@ -434,8 +434,6 @@ def export_back_compat(
       # an error if the lowered function contains non-replicated sharding annotations.
       wrapped_fun_jax = jax.jit(fun_jax)

-    has_trace = hasattr(wrapped_fun_jax, "trace")
-
     if lowering_platforms is not None:
       actual_lowering_platforms = tuple(lowering_platforms)
     else:
@@ -457,25 +455,12 @@ def export_back_compat(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))

-    if has_trace:
-      traced = wrapped_fun_jax.trace(  # type: ignore
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = traced.jaxpr, traced.fun_name
-      lowered = traced.lower()
-    else:
-      lowered = wrapped_fun_jax.lower(
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax)
+    traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-      lowered, jaxpr, fun_name,
+      lowered, traced.jaxpr, traced.fun_name,
       disabled_checks=disabled_checks,
       _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
   return do_export
@@ -553,16 +538,12 @@ def export(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))

-    traced = fun_jit.trace(
-        *args_specs, **kwargs_specs,
-        _experimental_lowering_parameters=mlir.LoweringParameters(
-            platforms=actual_lowering_platforms,
-            for_export=True,
-        ))
-    jaxpr, fun_name = traced.jaxpr, traced.fun_name
-    lowered = traced.lower()
+    traced = fun_jit.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-      lowered, jaxpr, fun_name,
+      lowered, traced.jaxpr, traced.fun_name,
       disabled_checks=disabled_checks)
   return do_export

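Both export paths above now funnel into the same two calls. A standalone sketch of that sequence, using the internal LoweringParameters the same way the diff does; the `jax._src.interpreters.mlir` import path and the ('cpu', 'tpu') platform pair are assumptions for illustration:

    import jax
    import jax.numpy as jnp
    from jax._src.interpreters import mlir  # internal module, assumed import path

    @jax.jit
    def g(x):
      return x + 1

    spec = jax.ShapeDtypeStruct((4,), jnp.float32)
    traced = g.trace(spec)
    # Several platforms in one tuple means multi-platform lowering, as used for export.
    lowered = traced.lower(
        lowering_platforms=('cpu', 'tpu'),
        _private_parameters=mlir.LoweringParameters(for_export=True))
    print(lowered.as_text()[:80])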
@@ -546,13 +546,6 @@ class LoweringParameters:
   # existing Jax rules.
   override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None

-  # The current lowering platforms, a non-empty tuple containing some of
-  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
-  # doing multi-platform lowering, otherwise it can specify cross-platform
-  # lowering. The value None specifies the default lowering platform.
-  # This is used only in export and jax2tf.
-  platforms: tuple[str, ...] | None = None
-
   # Signals that the entire computation being lowered operates on global
   # constants. This will result in adding jax.global_constant attributes
   # to the arguments of all functions that are created, e.g., floor_divide.
@@ -621,8 +614,7 @@ class ModuleContext:
       module: ir.Module | None = None,
       ip: ir.InsertionPoint | None = None,
       symbol_table: ir.SymbolTable | None = None,
-      cached_primitive_lowerings: None | (dict[Any,
-                                                func_dialect.FuncOp]) = None,
+      cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None,
       traceback_caches: None | TracebackCaches = None,
       shape_poly_state = None):

@@ -948,8 +940,7 @@ def lower_jaxpr_to_module(
       channel_iterator=channel_iter,
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
-      shape_poly_state=ShapePolyLoweringState(
-          dim_vars, lowering_parameters.platforms))
+      shape_poly_state=ShapePolyLoweringState(dim_vars, platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
@@ -592,8 +592,9 @@ def parallel_callable(fun: lu.WrappedFun,
       fun, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, donated_invars,
       is_explicit_global_axis_size, avals,
-      lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr,
-      backend=xc_backend, replicas=replicas, shards=shards, pci=pci)
+      lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(),
+      closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas,
+      shards=shards, pci=pci)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])

@@ -735,6 +736,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     closed_jaxpr: core.ClosedJaxpr,
     backend: xc.Client,
@@ -813,7 +815,7 @@ def lower_parallel_callable(
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):
     ordered_effects = list(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
@@ -1956,6 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             donated_invars, name_stack, all_default_mem_kind,
                             inout_aliases: None | tuple[None | int, ...],
                             propagated_out_mem_kinds: tuple[None | str, ...],
+                            platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings._gspmd_shardings
@@ -2016,8 +2019,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       closed_jaxpr,
       ordered_effects=ordered_effects,
       backend_or_name=backend,
-      # Optionally, override the lowering platform
-      platforms=lowering_parameters.platforms or (backend.platform,),
+      platforms=platforms,
       axis_context=axis_ctx,
       name_stack=name_stack,
       donated_args=donated_invars,
@@ -2166,9 +2168,10 @@ def lower_sharding_computation(
     *,
     keep_unused: bool,
     inline: bool,
-    devices_from_context: Sequence[xc.Device] | None = None,
+    devices_from_context: Sequence[xc.Device] | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
-    pgle_profiler: profiler.PGLEProfiler | None = None,
+    pgle_profiler: profiler.PGLEProfiler | None,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

@@ -2212,7 +2215,7 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
       devices_from_context)

-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   # TODO(yashkatariya): Enable this when offload APIs are stable.
   # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

@@ -2252,7 +2255,8 @@ def lower_sharding_computation(
       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
       tuple(da_object) if prim_requires_devices else None, donated_invars,
       name_stack, all_default_mem_kind, inout_aliases,
-      propagated_out_mem_kinds, lowering_parameters=lowering_parameters)
+      propagated_out_mem_kinds, platforms,
+      lowering_parameters=lowering_parameters)

   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2316,10 +2320,11 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: TilingMethod | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

   global_axis_sizes = mesh.shape
@@ -707,7 +707,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, in_avals,
-        tiling_method=tiling_method,
+        tiling_method=tiling_method, lowering_platforms=None,
         lowering_parameters=lowering_parameters)
   else:
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
@@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
         (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
         (None,) * len(in_avals), (None,) * len(out_avals),
         donated_invars, keep_unused=True, inline=False,
-        devices_from_context=None, lowering_parameters=lowering_parameters)
+        devices_from_context=None, lowering_platforms=None,
+        lowering_parameters=lowering_parameters, pgle_profiler=None)


 class EvaluationPlan(NamedTuple):
@@ -500,16 +500,13 @@ def _make_jit_wrapper(jit_info: PjitInfo):

   @api_boundary
   def trace(*args, **kwargs) -> stages.Traced:
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
-
     (args_flat, params, in_avals, in_tree, out_tree, donated_invars,
      arg_names, num_consts, _) = _infer_params(jit_info, args, kwargs)

     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     args_info = stages.make_args_info(in_tree, in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **params,
-                             lowering_parameters=lowering_parameters)
+                             pgle_profiler=None)
     return stages.Traced(params['jaxpr'], args_info, params["name"], out_tree,
                          lower_callable, args_flat, arg_names, num_consts)

@@ -1497,7 +1494,7 @@ def _resolve_in_shardings(
 def _resolve_and_lower(
     args, jaxpr, in_shardings, out_shardings, in_layouts,
     out_layouts, resource_env, donated_invars, name, keep_unused, inline,
-    lowering_parameters, pgle_profiler=None):
+    lowering_platforms, lowering_parameters, pgle_profiler):
   in_shardings = _resolve_in_shardings(
       args, in_shardings, out_shardings,
       resource_env.physical_mesh if resource_env is not None else None)
@@ -1506,6 +1503,7 @@ def _resolve_and_lower(
   lowered = _pjit_lower(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
       donated_invars, name, keep_unused, inline,
+      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
       pgle_profiler=pgle_profiler)
   return lowered
@@ -1540,7 +1538,8 @@ def _pjit_call_impl_python(
       out_shardings=out_shardings, in_layouts=in_layouts,
       out_layouts=out_layouts, resource_env=resource_env,
       donated_invars=donated_invars, name=name, keep_unused=keep_unused,
-      inline=inline, lowering_parameters=mlir.LoweringParameters(),
+      inline=inline, lowering_platforms=None,
+      lowering_parameters=mlir.LoweringParameters(),
       pgle_profiler=pgle_profiler
   ).compile(compile_options)

@@ -1659,6 +1658,7 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
   if resource_env is not None:
@@ -1679,6 +1679,7 @@ def _pjit_lower_cached(
         jaxpr, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
         True, jaxpr.in_avals, tiling_method=None,
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters)
   else:
     return pxla.lower_sharding_computation(
@@ -1687,6 +1688,7 @@ def _pjit_lower_cached(
         keep_unused=keep_unused, inline=inline,
         devices_from_context=(
             None if mesh is None or mesh.empty else list(mesh.devices.flat)),
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters,
         pgle_profiler=pgle_profiler)

@@ -30,6 +30,7 @@ executable protocols described above.
 """
 from __future__ import annotations

+import functools
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
@@ -446,9 +447,14 @@ class Traced(Stage):
     return self._out_tree.unflatten(
         [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])

-  def lower(self):
-    lowering = self._lower_callable()
-    return Lowered(lowering, self.args_info, self._out_tree)
+  def lower(self, lowering_platforms: tuple[str, ...] | None = None,
+            _private_parameters: mlir.LoweringParameters | None = None):
+    if _private_parameters is None:
+      _private_parameters = mlir.LoweringParameters()
+    new_callable = functools.partial(
+        self._lower_callable, lowering_platforms=lowering_platforms,
+        lowering_parameters=_private_parameters)
+    return Lowered(new_callable(), self.args_info, self._out_tree)


 class Compiled(Stage):
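The new Traced.lower above defers the real work to `self._lower_callable`, which the tracing code builds with a partial application so that everything except the per-call lowering options is already bound. A minimal, self-contained sketch of that pattern with hypothetical names:

    import functools

    def resolve_and_lower_sketch(args, *, lowering_platforms, lowering_parameters):
      # Stand-in for the real lowering; just reports what it would lower for.
      return f"lowered {args} for {lowering_platforms or ('default',)}"

    # Bound once at trace time, with only the lowering options left open.
    lower_callable = functools.partial(resolve_and_lower_sketch, (1, 2, 3))
    print(lower_callable(lowering_platforms=('tpu',), lowering_parameters=None))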
@@ -4728,6 +4728,13 @@ class APITest(jtu.JaxTestCase):
     out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))()  # don't crash
     self.assertEqual(out, 3)

+  def test_lowering_platform_aot(self):
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',))  # doesn't crash
+

 class RematTest(jtu.JaxTestCase):

@@ -10731,9 +10738,11 @@ class OverrideLoweringTest(jtu.JaxTestCase):
     rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
     lowered_ir = (
         jax.jit(f)
-        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
-               _experimental_lowering_parameters=mlir.LoweringParameters(
-                   override_lowering_rules=rules)).as_text())
+        .trace(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16))
+        .lower(_private_parameters=mlir.LoweringParameters(
+            override_lowering_rules=rules))
+        .as_text()
+    )
     self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
