Implement lower in terms of specialize

PiperOrigin-RevId: 641005643
Yash Katariya authored 2024-06-06 13:38:16 -07:00 · committed by jax authors
parent 90c83bb1e2
commit aee62e4874
4 changed files with 48 additions and 86 deletions
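Taken together, the four diffs below implement one pattern: lower stops doing its own tracing and instead asks specialize for a stages.Specialized object -- which carries the traced jaxpr, argument info, and a deferred lowering callable -- and then calls .lower() on it. A minimal, self-contained sketch of that pattern; SpecializedSketch and make_lower are hypothetical stand-ins, not the real jax internals:

class SpecializedSketch:
  def __init__(self, jaxpr, args_info, fun_name, out_tree, lower_callable):
    self.jaxpr = jaxpr                     # traced program
    self.args_info = args_info             # flat argument avals / donation info
    self.fun_name = fun_name
    self._out_tree = out_tree
    self._lower_callable = lower_callable  # deferred: no lowering work happens yet

  def lower(self):
    # Lowering runs only on demand, reusing everything captured at trace time.
    return self._lower_callable()

def make_lower(specialize):
  # After this commit, lower(*args, **kwargs) is just specialize(*args, **kwargs).lower().
  def lower(*args, **kwargs):
    return specialize(*args, **kwargs).lower()
  return lower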


@@ -1811,6 +1811,10 @@ def _cpp_pmap(
pmap_f = wraps(fun)(cpp_mapped_f)
@api_boundary
def lower(*args, **kwargs):
return specialize(*args, **kwargs).lower()
@api_boundary
def specialize(*args, **kwargs):
lowering_parameters = kwargs.pop(
@@ -1819,18 +1823,7 @@ def _cpp_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
lower_callable = partial(
pxla.lower_parallel_callable, p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
name=p.flat_fun.__name__,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_parameters=lowering_parameters)
jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr(
closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
@@ -1838,13 +1831,26 @@ def _cpp_pmap(
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
avals=abstract_args)
lower_callable = partial(
pxla.lower_parallel_callable, p.flat_fun, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
name=p.flat_fun.__name__,
in_axes=p.in_axes_flat,
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_parameters=lowering_parameters,
closed_jaxpr=closed_jaxpr,
backend=xc_backend,
replicas=replicas,
shards=shards,
pci=pci)
args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__,
return stages.Specialized(closed_jaxpr, args_info, p.flat_fun.__name__,
p.out_tree(), lower_callable)
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, donate_tuple)
pmap_f.lower = lower
pmap_f.specialize = specialize
return pmap_f
@@ -1852,47 +1858,6 @@ def _cpp_pmap(
_pmap_cache_clears = weakref.WeakSet() # type: ignore
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
devices, backend, axis_size, donate_tuple): # noqa: F811
"""Make a ``lower`` method for pmapped functions."""
# If the function we returned from ``pmap`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
# all the other arguments stored as attributes.
@api_boundary
def lower(*args, **kwargs) -> stages.Lowered:
"""Lower a parallel-mapped form of this function for the given arguments.
A parallel-mapped and lowered function is staged out of Python and
translated to a compiler's input language, possibly in a
backend-dependent manner. It is ready for compilation but is not yet
compiled. It represents a function intended for SPMD execution on
multiple devices.
Returns:
A ``Lowered`` instance representing the post-map lowering.
"""
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
name=p.flat_fun.__name__,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_parameters=lowering_parameters)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
return lower
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
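In the pmap changes above, lower and specialize are both closures inside _cpp_pmap: specialize traces eagerly through pxla.get_pmap_jaxpr and wraps the remaining lowering work in a partial handed to stages.Specialized, while lower simply delegates to specialize(...).lower(); the standalone _pmap_lower helper is removed. A usage sketch, assuming the entry points exactly as this commit wires them up (scale is a toy function; other jax versions may expose different method names):

import jax
import jax.numpy as jnp

@jax.pmap
def scale(x):
  return x * 2.0

xs = jnp.arange(jax.local_device_count(), dtype=jnp.float32)

# Both entry points exist on the pmapped callable after this change, and
# lower() routes through specialize() internally.
specialized = scale.specialize(xs)   # traces to a jaxpr; no lowering yet
lowered = scale.lower(xs)            # equivalent to scale.specialize(xs).lower()
print(specialized.jaxpr)             # the captured ClosedJaxpr
print(lowered.as_text()[:200])       # compiler input text for the SPMD program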


@@ -557,11 +557,17 @@ def parallel_callable(fun: lu.WrappedFun,
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
*avals):
closed_jaxpr, xc_backend, replicas, shards, pci = get_pmap_jaxpr(
fun, backend_name, axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=fun.__name__, in_axes=in_axes,
out_axes_thunk=out_axes_thunk, avals=avals)
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
fun, axis_name, axis_size, global_axis_size, devices, name,
in_axes, donated_invars,
is_explicit_global_axis_size, avals,
lowering_parameters=mlir.LoweringParameters())
lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr,
backend=xc_backend, replicas=replicas, shards=shards, pci=pci)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -693,19 +699,22 @@ def get_pmap_jaxpr(
@profiler.annotate_function
def lower_parallel_callable(
fun: lu.WrappedFun,
backend_name: str | None,
axis_name: core.AxisName,
axis_size: int,
global_axis_size: int,
devices: Sequence[xc.Device] | None,
name: str,
in_axes: Iterable[int | None],
out_axes_thunk: Callable[[], Sequence[int | None]],
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
lowering_parameters: mlir.LoweringParameters,
closed_jaxpr: core.ClosedJaxpr,
backend: xc.Client,
replicas: ReplicaInfo,
shards: ShardInfo,
pci: ParallelCallableInfo) -> PmapComputation:
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -716,9 +725,6 @@ def lower_parallel_callable(
f"Specified axis_size {global_axis_size} doesn't match received "
f"axis_size {axis_size}.")
closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, avals)
jaxpr = closed_jaxpr.jaxpr
no_nested_sharding = False
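The parallel-callable changes above split tracing from lowering: callers run get_pmap_jaxpr first and thread its results (closed_jaxpr, the resolved backend, replicas, shards, and the ParallelCallableInfo) into lower_parallel_callable as keyword arguments, which therefore drops its backend_name and out_axes_thunk parameters and no longer traces. A simplified sketch of that call structure; the *_sketch functions and dummy values are stand-ins, not the real pxla signatures:

from collections import namedtuple

Traced = namedtuple("Traced", ["closed_jaxpr", "backend", "replicas", "shards", "pci"])

def get_pmap_jaxpr_sketch(fun, avals):
  # Pretend-trace; the real function returns (closed_jaxpr, xc_backend, replicas, shards, pci).
  return Traced(closed_jaxpr=("jaxpr", avals), backend="cpu",
                replicas=1, shards=1, pci=None)

def lower_parallel_callable_sketch(fun, *, closed_jaxpr, backend, replicas, shards, pci):
  # Consumes the pre-traced jaxpr; after this commit the real function no
  # longer calls get_pmap_jaxpr itself.
  return {"lowered_from": closed_jaxpr, "backend": backend}

def parallel_callable_sketch(fun, avals):
  t = get_pmap_jaxpr_sketch(fun, avals)
  return lower_parallel_callable_sketch(
      fun, closed_jaxpr=t.closed_jaxpr, backend=t.backend,
      replicas=t.replicas, shards=t.shards, pci=t.pci)

print(parallel_callable_sketch(lambda x: x, avals=("f32[8]",)))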


@@ -477,29 +477,18 @@ def _make_jit_wrapper(jit_info: PjitInfo):
@api_boundary
def lower(*args, **kwargs):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
(args_flat, params, in_tree, out_tree,
donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
specialized = specialize(*args, **kwargs)
try:
lowering = _resolve_and_lower(
args_flat, **params, lowering_parameters=lowering_parameters)
return specialized.lower()
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun = jit_info.fun
fun_name = getattr(fun, '__qualname__',
getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
fun_name, fails, specialized._args_flat, 'jit', specialized._arg_names)
raise ValueError(msg) from None
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
jaxpr = params["jaxpr"]
return stages.Lowered.from_flat_info(
lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree)
@api_boundary
def eval_shape(*args, **kwargs):
_, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
@@ -514,8 +503,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
lowering_parameters = kwargs.pop(
'_experimental_lowering_parameters', mlir.LoweringParameters())
args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
jit_info, args, kwargs)
(args_flat, params, in_tree, out_tree, donated_invars,
arg_names, _) = _infer_params(jit_info, args, kwargs)
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
jaxpr = params['jaxpr']
@@ -523,7 +512,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
lower_callable = partial(_resolve_and_lower, args_flat, **params,
lowering_parameters=lowering_parameters)
return stages.Specialized(jaxpr, args_info, params["name"], out_tree,
lower_callable)
lower_callable, args_flat, arg_names)
wrapped = _cpp_pjit(jit_info)
wrapped.lower = lower
@@ -654,7 +643,6 @@ def _infer_params(jit_info, args, kwargs):
args_flat = [*implicit_args, *explicit_args]
num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked)
num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in attrs_tracked)
num_extra_args = len(implicit_args) + num_states_in + len(consts)
in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
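The jit wrapper follows the same recipe: lower builds a Specialized object and calls .lower() on it, and the DeviceAssignmentMismatchError handler now reads the flat arguments and argument names off that object (its new _args_flat and _arg_names fields) rather than re-running _infer_params, with the api name fixed to 'jit'. A usage sketch on the jit side, assuming the jitted callable exposes both lower and specialize in this version of jax (names may differ in other releases):

import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.dot(x, y) + 1.0

x = jnp.ones((4, 4))

jitted = jax.jit(f)
specialized = jitted.specialize(x, x)   # trace only; returns a stages.Specialized
lowered = jitted.lower(x, x)            # now implemented as specialize(x, x).lower()
print(specialized.fun_name)             # "f"
print(lowered.as_text()[:120])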


@@ -426,15 +426,18 @@ class CompiledCallParams(NamedTuple):
class Specialized(Stage):
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable"]
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
"_args_flat", "_arg_names"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable):
lower_callable, args_flat=None, arg_names=None):
self.jaxpr = jaxpr
self.args_info = args_info
self.fun_name = fun_name
self._out_tree = out_tree
self._lower_callable = lower_callable
self._args_flat = args_flat
self._arg_names = arg_names
@property
def out_info(self):
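Finally, Specialized itself is extended in a backward-compatible way: two new slots, _args_flat and _arg_names, default to None, so call sites that do not pass them (the pmap path above) keep working while jit's lower can populate them for its error message. A minimal illustration of that constructor pattern; SpecializedLike is a hypothetical class, not the real jax.stages code:

class SpecializedLike:
  __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
               "_args_flat", "_arg_names"]

  def __init__(self, jaxpr, args_info, fun_name, out_tree,
               lower_callable, args_flat=None, arg_names=None):
    self.jaxpr = jaxpr
    self.args_info = args_info
    self.fun_name = fun_name
    self._out_tree = out_tree
    self._lower_callable = lower_callable
    self._args_flat = args_flat   # only populated on the jit path, for error messages
    self._arg_names = arg_names

# Callers that predate the new fields still work unchanged:
old_style = SpecializedLike("jaxpr", {}, "f", None, lambda: "lowered")
assert old_style._args_flat is None and old_style._arg_names is None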