diff --git a/jax/_src/api.py b/jax/_src/api.py
index 3b1fdd6b3..ef18f3eba 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1819,8 +1819,6 @@ def _cpp_pmap(
   @api_boundary
   def trace(*args, **kwargs):
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
@@ -1842,7 +1840,6 @@ def _cpp_pmap(
         donated_invars=p.donated_invars,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
         avals=abstract_args,
-        lowering_parameters=lowering_parameters,
         closed_jaxpr=closed_jaxpr,
         backend=xc_backend,
         replicas=replicas,
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index d5bddac94..0689f98b4 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -424,7 +424,7 @@ def export_back_compat(
   """
   def do_export(*args_specs, **kwargs_specs) -> Exported:
-    if hasattr(fun_jax, "lower"):
+    if hasattr(fun_jax, "trace"):
       # If we have a pjit or pmap already we do not wrap with another, and we
       # allow shardings.
       wrapped_fun_jax = fun_jax
@@ -434,8 +434,6 @@ def export_back_compat(
       # an error if the lowered function contains non-replicated sharding annotations.
       wrapped_fun_jax = jax.jit(fun_jax)
 
-    has_trace = hasattr(wrapped_fun_jax, "trace")
-
     if lowering_platforms is not None:
       actual_lowering_platforms = tuple(lowering_platforms)
     else:
@@ -457,25 +455,12 @@ def export_back_compat(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))
 
-    if has_trace:
-      traced = wrapped_fun_jax.trace(  # type: ignore
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = traced.jaxpr, traced.fun_name
-      lowered = traced.lower()
-    else:
-      lowered = wrapped_fun_jax.lower(
-          *args_specs, **kwargs_specs,
-          _experimental_lowering_parameters=mlir.LoweringParameters(
-              platforms=actual_lowering_platforms,
-              for_export=True,
-          ))
-      jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax)
+    traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-        lowered, jaxpr, fun_name,
+        lowered, traced.jaxpr, traced.fun_name,
         disabled_checks=disabled_checks,
         _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only)
   return do_export
@@ -553,16 +538,12 @@ def export(
           self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
           other_descr=shape_poly.args_kwargs_path_to_str(k_path))
 
-    traced = fun_jit.trace(
-        *args_specs, **kwargs_specs,
-        _experimental_lowering_parameters=mlir.LoweringParameters(
-            platforms=actual_lowering_platforms,
-            for_export=True,
-        ))
-    jaxpr, fun_name = traced.jaxpr, traced.fun_name
-    lowered = traced.lower()
+    traced = fun_jit.trace(*args_specs, **kwargs_specs)
+    lowered = traced.lower(
+        lowering_platforms=actual_lowering_platforms,
+        _private_parameters=mlir.LoweringParameters(for_export=True))
     return _export_lowered(
-        lowered, jaxpr, fun_name,
+        lowered, traced.jaxpr, traced.fun_name,
         disabled_checks=disabled_checks)
 
   return do_export
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index ba429dfe4..a3e42bce3 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -546,13 +546,6 @@ class LoweringParameters:
   # existing Jax rules.
   override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None
 
-  # The current lowering platforms, a non-empty tuple containing some of
-  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
-  # doing multi-platform lowering, otherwise it can specify cross-platform
-  # lowering. The value None specifies the default lowering platform.
-  # This is used only in export and jax2tf.
-  platforms: tuple[str, ...] | None = None
-
   # Signals that the entire computation being lowered operates on global
   # constants. This will result in adding jax.global_constant attributes
   # to the arguments of all functions that are created, e.g., floor_divide.
@@ -621,8 +614,7 @@ class ModuleContext:
                module: ir.Module | None = None,
                ip: ir.InsertionPoint | None = None,
                symbol_table: ir.SymbolTable | None = None,
-               cached_primitive_lowerings: None | (dict[Any,
-                                                        func_dialect.FuncOp]) = None,
+               cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None,
                traceback_caches: None | TracebackCaches = None,
                shape_poly_state = None):
@@ -948,8 +940,7 @@ def lower_jaxpr_to_module(
       channel_iterator=channel_iter,
       host_callbacks=host_callbacks,
       lowering_parameters=lowering_parameters,
-      shape_poly_state=ShapePolyLoweringState(
-          dim_vars, lowering_parameters.platforms))
+      shape_poly_state=ShapePolyLoweringState(dim_vars, platforms))
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index d795abcde..fe50230fc 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -592,8 +592,9 @@ def parallel_callable(fun: lu.WrappedFun,
       fun, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, donated_invars, is_explicit_global_axis_size, avals,
-      lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr,
-      backend=xc_backend, replicas=replicas, shards=shards, pci=pci)
+      lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(),
+      closed_jaxpr=closed_jaxpr, backend=xc_backend, replicas=replicas,
+      shards=shards, pci=pci)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -735,6 +736,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     closed_jaxpr: core.ClosedJaxpr,
     backend: xc.Client,
@@ -813,7 +815,7 @@ def lower_parallel_callable(
   tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                           backend.platform)
   module_name = f"pmap_{fun.__name__}"
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):
     ordered_effects = list(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
@@ -1956,6 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             donated_invars, name_stack, all_default_mem_kind,
                             inout_aliases: None | tuple[None | int, ...],
                             propagated_out_mem_kinds: tuple[None | str, ...],
+                            platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings._gspmd_shardings
@@ -2016,8 +2019,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       closed_jaxpr,
       ordered_effects=ordered_effects,
       backend_or_name=backend,
-      # Optionally, override the lowering platform
-      platforms=lowering_parameters.platforms or (backend.platform,),
+      platforms=platforms,
       axis_context=axis_ctx,
       name_stack=name_stack,
       donated_args=donated_invars,
@@ -2166,9 +2168,10 @@ def lower_sharding_computation(
     *,
     keep_unused: bool,
     inline: bool,
-    devices_from_context: Sequence[xc.Device] | None = None,
+    devices_from_context: Sequence[xc.Device] | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
-    pgle_profiler: profiler.PGLEProfiler | None = None,
+    pgle_profiler: profiler.PGLEProfiler | None,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
 
@@ -2212,7 +2215,7 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
       devices_from_context)
 
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
 
   # TODO(yashkatariya): Enable this when offload APIs are stable.
   # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
@@ -2252,7 +2255,8 @@ def lower_sharding_computation(
       semantic_out_shardings, in_layouts, out_layouts, len(da_object),
       tuple(da_object) if prim_requires_devices else None, donated_invars,
       name_stack, all_default_mem_kind, inout_aliases,
-      propagated_out_mem_kinds, lowering_parameters=lowering_parameters)
+      propagated_out_mem_kinds, platforms,
+      lowering_parameters=lowering_parameters)
 
   # backend and device_assignment is passed through to MeshExecutable because
   # if keep_unused=False and all in_shardings are pruned, then there is no way
@@ -2316,10 +2320,11 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: TilingMethod | None,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
-  platforms = lowering_parameters.platforms or (backend.platform,)
+  platforms = lowering_platforms or (backend.platform,)
 
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
   global_axis_sizes = mesh.shape
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index 77896d7f9..10713b624 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -707,7 +707,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh, in_shardings, out_shardings,
         donated_invars, use_spmd_lowering, in_avals,
-        tiling_method=tiling_method,
+        tiling_method=tiling_method, lowering_platforms=None,
         lowering_parameters=lowering_parameters)
   else:
     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
@@ -716,7 +716,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
         (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
         (None,) * len(in_avals), (None,) * len(out_avals),
         donated_invars, keep_unused=True, inline=False,
-        devices_from_context=None, lowering_parameters=lowering_parameters)
+        devices_from_context=None, lowering_platforms=None,
+        lowering_parameters=lowering_parameters, pgle_profiler=None)
 
 
 class EvaluationPlan(NamedTuple):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 9e5b54ce7..2169e2f04 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -500,16 +500,13 @@ def _make_jit_wrapper(jit_info: PjitInfo):
   @api_boundary
   def trace(*args, **kwargs) -> stages.Traced:
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
-
     (args_flat, params, in_avals, in_tree, out_tree, donated_invars,
      arg_names, num_consts, _) = _infer_params(jit_info, args, kwargs)
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     args_info = stages.make_args_info(in_tree, in_avals, donate_argnums)
     lower_callable = partial(_resolve_and_lower, args_flat, **params,
-                             lowering_parameters=lowering_parameters)
+                             pgle_profiler=None)
     return stages.Traced(params['jaxpr'], args_info, params["name"], out_tree,
                          lower_callable, args_flat, arg_names, num_consts)
@@ -1497,7 +1494,7 @@ def _resolve_in_shardings(
 def _resolve_and_lower(
     args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
     resource_env, donated_invars, name, keep_unused, inline,
-    lowering_parameters, pgle_profiler=None):
+    lowering_platforms, lowering_parameters, pgle_profiler):
   in_shardings = _resolve_in_shardings(
       args, in_shardings, out_shardings,
       resource_env.physical_mesh if resource_env is not None else None)
@@ -1506,6 +1503,7 @@ def _resolve_and_lower(
   lowered = _pjit_lower(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
       resource_env, donated_invars, name, keep_unused, inline,
+      lowering_platforms=lowering_platforms,
       lowering_parameters=lowering_parameters,
       pgle_profiler=pgle_profiler)
   return lowered
@@ -1540,7 +1538,8 @@ def _pjit_call_impl_python(
       out_shardings=out_shardings, in_layouts=in_layouts,
       out_layouts=out_layouts, resource_env=resource_env,
       donated_invars=donated_invars, name=name, keep_unused=keep_unused,
-      inline=inline, lowering_parameters=mlir.LoweringParameters(),
+      inline=inline, lowering_platforms=None,
+      lowering_parameters=mlir.LoweringParameters(),
       pgle_profiler=pgle_profiler
   ).compile(compile_options)
@@ -1659,6 +1658,7 @@ def _pjit_lower_cached(
     keep_unused: bool,
     inline: bool,
     *,
+    lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None):
   if resource_env is not None:
@@ -1679,6 +1679,7 @@ def _pjit_lower_cached(
         jaxpr, api_name, name, mesh, in_shardings, out_shardings,
         donated_invars, True, jaxpr.in_avals, tiling_method=None,
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters)
   else:
     return pxla.lower_sharding_computation(
@@ -1687,6 +1688,7 @@ def _pjit_lower_cached(
         keep_unused=keep_unused, inline=inline,
         devices_from_context=(
             None if mesh is None or mesh.empty else list(mesh.devices.flat)),
+        lowering_platforms=lowering_platforms,
         lowering_parameters=lowering_parameters,
         pgle_profiler=pgle_profiler)
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 73373912e..93e30be45 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -30,6 +30,7 @@ executable protocols described above.
 """
 from __future__ import annotations
 
+import functools
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
@@ -446,9 +447,14 @@ class Traced(Stage):
     return self._out_tree.unflatten(
         [OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
 
-  def lower(self):
-    lowering = self._lower_callable()
-    return Lowered(lowering, self.args_info, self._out_tree)
+  def lower(self, lowering_platforms: tuple[str, ...] | None = None,
+            _private_parameters: mlir.LoweringParameters | None = None):
+    if _private_parameters is None:
+      _private_parameters = mlir.LoweringParameters()
+    new_callable = functools.partial(
+        self._lower_callable, lowering_platforms=lowering_platforms,
+        lowering_parameters=_private_parameters)
+    return Lowered(new_callable(), self.args_info, self._out_tree)
 
 
 class Compiled(Stage):
diff --git a/tests/api_test.py b/tests/api_test.py
index 3167bc93c..0bd58fe23 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4728,6 +4728,13 @@ class APITest(jtu.JaxTestCase):
     out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))()  # don't crash
     self.assertEqual(out, 3)
 
+  def test_lowering_platform_aot(self):
+    @jax.jit
+    def f(x):
+      return x * 2
+
+    f.trace(jnp.arange(8)).lower(lowering_platforms=('tpu',))  # doesn't crash
+
 
 class RematTest(jtu.JaxTestCase):
@@ -10731,9 +10738,11 @@ class OverrideLoweringTest(jtu.JaxTestCase):
     rules = ((jax.lax.sharding_constraint_p, wsc_as_noop),)
     lowered_ir = (
         jax.jit(f)
-        .lower(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16),
-               _experimental_lowering_parameters=mlir.LoweringParameters(
-                   override_lowering_rules=rules)).as_text())
+        .trace(jax.ShapeDtypeStruct((2, 4), dtype=jnp.bfloat16))
+        .lower(_private_parameters=mlir.LoweringParameters(
+            override_lowering_rules=rules))
+        .as_text()
+    )
     self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
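
Usage illustration (not part of the patch): after this change, target platforms are chosen at lowering time via Traced.lower() rather than threaded through the removed _experimental_lowering_parameters kwarg. A minimal sketch of the resulting AOT flow, with illustrative names (f, traced, lowered, exported_style); the mlir import below is the internal jax._src.interpreters.mlir module, needed only to construct LoweringParameters:

import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir  # internal module, used only for LoweringParameters

@jax.jit
def f(x):
  return x * 2

# Trace once, then pick the target platform(s) when lowering.
traced = f.trace(jnp.arange(8))
lowered = traced.lower(lowering_platforms=('tpu',))  # cross-platform lowering
print(lowered.as_text()[:200])  # StableHLO for the requested platform

# Internal callers such as jax.export now pass extra lowering options through
# the private keyword instead of the removed LoweringParameters.platforms field.
exported_style = traced.lower(
    lowering_platforms=('cpu', 'tpu'),
    _private_parameters=mlir.LoweringParameters(for_export=True))

The new test_lowering_platform_aot in tests/api_test.py exercises the same pattern.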