From aee62e4874ec2e9079adfccc5281c85f57f25c54 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 6 Jun 2024 13:38:16 -0700 Subject: [PATCH] Implement `lower` in terms of `specialize` PiperOrigin-RevId: 641005643 --- jax/_src/api.py | 79 ++++++++++------------------------- jax/_src/interpreters/pxla.py | 24 +++++++---- jax/_src/pjit.py | 24 +++-------- jax/_src/stages.py | 7 +++- 4 files changed, 48 insertions(+), 86 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index a3d99eda3..aa4579475 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1811,6 +1811,10 @@ def _cpp_pmap( pmap_f = wraps(fun)(cpp_mapped_f) + @api_boundary + def lower(*args, **kwargs): + return specialize(*args, **kwargs).lower() + @api_boundary def specialize(*args, **kwargs): lowering_parameters = kwargs.pop( @@ -1819,18 +1823,7 @@ def _cpp_pmap( fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, devices, backend, axis_size, args, kwargs) abstract_args = list(map(shaped_abstractify, p.flat_args)) - lower_callable = partial( - pxla.lower_parallel_callable, p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - lowering_parameters=lowering_parameters) - jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr( + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( p.flat_fun, backend, axis_name, axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, devices=p.devices, @@ -1838,13 +1831,26 @@ def _cpp_pmap( in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, + lowering_parameters=lowering_parameters, + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) - return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__, + return stages.Specialized(closed_jaxpr, args_info, p.flat_fun.__name__, p.out_tree(), lower_callable) - pmap_f.lower = _pmap_lower( - fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, donate_tuple) + pmap_f.lower = lower pmap_f.specialize = specialize return pmap_f @@ -1852,47 +1858,6 @@ def _cpp_pmap( _pmap_cache_clears = weakref.WeakSet() # type: ignore -def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, - devices, backend, axis_size, donate_tuple): # noqa: F811 - """Make a ``lower`` method for pmapped functions.""" - # If the function we returned from ``pmap`` were a class instance, - # this might naturally be a method, with ``fun`` as a ``self`` and - # all the other arguments stored as attributes. - @api_boundary - def lower(*args, **kwargs) -> stages.Lowered: - """Lower a parallel-mapped form of this function for the given arguments. - - A parallel-mapped and lowered function is staged out of Python and - translated to a compiler's input language, possibly in a - backend-dependent manner. 
It is ready for compilation but is not yet - compiled. It represents a function intended for SPMD execution on - multiple devices. - - Returns: - A ``Lowered`` instance representing the post-map lowering. - """ - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - abstract_args = list(map(shaped_abstractify, p.flat_args)) - computation = pxla.lower_parallel_callable( - p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - lowering_parameters=lowering_parameters) - return stages.Lowered.from_flat_info( - computation, p.in_tree, abstract_args, donate_tuple, p.out_tree()) - - return lower - def jvp( fun: Callable, primals, tangents, has_aux: bool = False ) -> tuple[Any, ...]: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 686ee0ced..7ddac0e6f 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -557,11 +557,17 @@ def parallel_callable(fun: lu.WrappedFun, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, *avals): + closed_jaxpr, xc_backend, replicas, shards, pci = get_pmap_jaxpr( + fun, backend_name, axis_name, + axis_size=axis_size, global_axis_size=global_axis_size, + devices=devices, name=fun.__name__, in_axes=in_axes, + out_axes_thunk=out_axes_thunk, avals=avals) pmap_computation = lower_parallel_callable( - fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, donated_invars, + fun, axis_name, axis_size, global_axis_size, devices, name, + in_axes, donated_invars, is_explicit_global_axis_size, avals, - lowering_parameters=mlir.LoweringParameters()) + lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr, + backend=xc_backend, replicas=replicas, shards=shards, pci=pci) pmap_executable = pmap_computation.compile() return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint]) @@ -693,19 +699,22 @@ def get_pmap_jaxpr( @profiler.annotate_function def lower_parallel_callable( fun: lu.WrappedFun, - backend_name: str | None, axis_name: core.AxisName, axis_size: int, global_axis_size: int, devices: Sequence[xc.Device] | None, name: str, in_axes: Iterable[int | None], - out_axes_thunk: Callable[[], Sequence[int | None]], donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, avals: Sequence[core.AbstractValue], *, - lowering_parameters: mlir.LoweringParameters) -> PmapComputation: + lowering_parameters: mlir.LoweringParameters, + closed_jaxpr: core.ClosedJaxpr, + backend: xc.Client, + replicas: ReplicaInfo, + shards: ShardInfo, + pci: ParallelCallableInfo) -> PmapComputation: # Determine global_axis_size for use in AxisEnv. 
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now) # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: @@ -716,9 +725,6 @@ def lower_parallel_callable( f"Specified axis_size {global_axis_size} doesn't match received " f"axis_size {axis_size}.") - closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr( - fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, avals) jaxpr = closed_jaxpr.jaxpr no_nested_sharding = False diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 4a3d8400a..a139227ad 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -477,29 +477,18 @@ def _make_jit_wrapper(jit_info: PjitInfo): @api_boundary def lower(*args, **kwargs): - lowering_parameters = kwargs.pop( - '_experimental_lowering_parameters', mlir.LoweringParameters()) - - (args_flat, params, in_tree, out_tree, - donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs) + specialized = specialize(*args, **kwargs) try: - lowering = _resolve_and_lower( - args_flat, **params, lowering_parameters=lowering_parameters) + return specialized.lower() except pxla.DeviceAssignmentMismatchError as e: fails, = e.args - api_name = 'jit' if params['resource_env'] is None else 'pjit' fun = jit_info.fun fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) msg = _device_assignment_mismatch_error( - fun_name, fails, args_flat, api_name, arg_names) + fun_name, fails, specialized._args_flat, 'jit', specialized._arg_names) raise ValueError(msg) from None - donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - jaxpr = params["jaxpr"] - return stages.Lowered.from_flat_info( - lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree) - @api_boundary def eval_shape(*args, **kwargs): _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs) @@ -514,8 +503,8 @@ def _make_jit_wrapper(jit_info: PjitInfo): lowering_parameters = kwargs.pop( '_experimental_lowering_parameters', mlir.LoweringParameters()) - args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params( - jit_info, args, kwargs) + (args_flat, params, in_tree, out_tree, donated_invars, + arg_names, _) = _infer_params(jit_info, args, kwargs) donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) jaxpr = params['jaxpr'] @@ -523,7 +512,7 @@ def _make_jit_wrapper(jit_info: PjitInfo): lower_callable = partial(_resolve_and_lower, args_flat, **params, lowering_parameters=lowering_parameters) return stages.Specialized(jaxpr, args_info, params["name"], out_tree, - lower_callable) + lower_callable, args_flat, arg_names) wrapped = _cpp_pjit(jit_info) wrapped.lower = lower @@ -654,7 +643,6 @@ def _infer_params(jit_info, args, kwargs): args_flat = [*implicit_args, *explicit_args] num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) - num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in attrs_tracked) num_extra_args = len(implicit_args) + num_states_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat diff --git a/jax/_src/stages.py b/jax/_src/stages.py index ba71ca655..90ad765a9 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -426,15 +426,18 @@ class CompiledCallParams(NamedTuple): class Specialized(Stage): - __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable"] + __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", 
"_lower_callable", + "_args_flat", "_arg_names"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable): + lower_callable, args_flat=None, arg_names=None): self.jaxpr = jaxpr self.args_info = args_info self.fun_name = fun_name self._out_tree = out_tree self._lower_callable = lower_callable + self._args_flat = args_flat + self._arg_names = arg_names @property def out_info(self):