Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Implement lower in terms of specialize
PiperOrigin-RevId: 641005643
This commit is contained in:
parent 90c83bb1e2
commit aee62e4874
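The change below makes each API wrapper's lower() a thin shim over specialize(): tracing happens once in specialize(), which returns a stages.Specialized object, and lower() simply calls .lower() on that result. A minimal usage sketch of the resulting behavior (hedged: assumes the JAX version at this commit, where the pmap wrapper exposes both .lower and the then-experimental .specialize; the function and inputs are illustrative, not from this diff):

    import jax
    import jax.numpy as jnp

    f = jax.pmap(lambda x: x * 2)                           # toy function, for illustration
    x = jnp.arange(jax.device_count(), dtype=jnp.float32)   # one element per local device

    # After this commit the two paths below are equivalent:
    lowered_a = f.lower(x)                # internally: specialize(x).lower()
    lowered_b = f.specialize(x).lower()   # explicit two-step form
    print(lowered_b.as_text()[:200])      # StableHLO text of the lowered computation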
@@ -1811,6 +1811,10 @@ def _cpp_pmap(
 
   pmap_f = wraps(fun)(cpp_mapped_f)
 
+  @api_boundary
+  def lower(*args, **kwargs):
+    return specialize(*args, **kwargs).lower()
+
   @api_boundary
   def specialize(*args, **kwargs):
     lowering_parameters = kwargs.pop(
@@ -1819,18 +1823,7 @@ def _cpp_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
     abstract_args = list(map(shaped_abstractify, p.flat_args))
-    lower_callable = partial(
-        pxla.lower_parallel_callable, p.flat_fun, backend, axis_name,
-        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
-        devices=p.devices,
-        name=p.flat_fun.__name__,
-        in_axes=p.in_axes_flat,
-        out_axes_thunk=p.out_axes_thunk,
-        donated_invars=p.donated_invars,
-        is_explicit_global_axis_size=p.is_explicit_global_axis_size,
-        avals=abstract_args,
-        lowering_parameters=lowering_parameters)
-    jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr(
+    closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr(
         p.flat_fun, backend, axis_name,
         axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
         devices=p.devices,
@@ -1838,13 +1831,26 @@ def _cpp_pmap(
         in_axes=p.in_axes_flat,
         out_axes_thunk=p.out_axes_thunk,
         avals=abstract_args)
+    lower_callable = partial(
+        pxla.lower_parallel_callable, p.flat_fun, axis_name,
+        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
+        devices=p.devices,
+        name=p.flat_fun.__name__,
+        in_axes=p.in_axes_flat,
+        donated_invars=p.donated_invars,
+        is_explicit_global_axis_size=p.is_explicit_global_axis_size,
+        avals=abstract_args,
+        lowering_parameters=lowering_parameters,
+        closed_jaxpr=closed_jaxpr,
+        backend=xc_backend,
+        replicas=replicas,
+        shards=shards,
+        pci=pci)
     args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple)
-    return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__,
+    return stages.Specialized(closed_jaxpr, args_info, p.flat_fun.__name__,
                               p.out_tree(), lower_callable)
 
-  pmap_f.lower = _pmap_lower(
-      fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-      backend, axis_size, donate_tuple)
+  pmap_f.lower = lower
+  pmap_f.specialize = specialize
 
   return pmap_f
@@ -1852,47 +1858,6 @@ def _cpp_pmap(
 _pmap_cache_clears = weakref.WeakSet()  # type: ignore
 
 
-def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
-                devices, backend, axis_size, donate_tuple):  # noqa: F811
-  """Make a ``lower`` method for pmapped functions."""
-  # If the function we returned from ``pmap`` were a class instance,
-  # this might naturally be a method, with ``fun`` as a ``self`` and
-  # all the other arguments stored as attributes.
-  @api_boundary
-  def lower(*args, **kwargs) -> stages.Lowered:
-    """Lower a parallel-mapped form of this function for the given arguments.
-
-    A parallel-mapped and lowered function is staged out of Python and
-    translated to a compiler's input language, possibly in a
-    backend-dependent manner. It is ready for compilation but is not yet
-    compiled. It represents a function intended for SPMD execution on
-    multiple devices.
-
-    Returns:
-      A ``Lowered`` instance representing the post-map lowering.
-    """
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
-    p = _prepare_pmap(
-        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-        devices, backend, axis_size, args, kwargs)
-    abstract_args = list(map(shaped_abstractify, p.flat_args))
-    computation = pxla.lower_parallel_callable(
-        p.flat_fun, backend, axis_name,
-        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
-        devices=p.devices,
-        name=p.flat_fun.__name__,
-        in_axes=p.in_axes_flat,
-        out_axes_thunk=p.out_axes_thunk,
-        donated_invars=p.donated_invars,
-        is_explicit_global_axis_size=p.is_explicit_global_axis_size,
-        avals=abstract_args,
-        lowering_parameters=lowering_parameters)
-    return stages.Lowered.from_flat_info(
-        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
-
-  return lower
-
 def jvp(
     fun: Callable, primals, tangents, has_aux: bool = False
 ) -> tuple[Any, ...]:
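With _pmap_lower gone, the pmap path now mirrors the jit one: specialize() builds a stages.Specialized carrying the closed jaxpr plus a deferred lower_callable, and lower()/compile() proceed from there. A short sketch of that pipeline (hedged: assumes the JAX version at this commit and a toy function; the .jaxpr attribute comes from stages.Specialized as shown in the diff):

    import jax
    import jax.numpy as jnp

    g = jax.pmap(lambda x: jnp.sin(x) + 1.0)                 # illustrative function
    inp = jnp.zeros((jax.device_count(), 4), jnp.float32)

    specialized = g.specialize(inp)          # stages.Specialized (traced, not lowered)
    print(specialized.jaxpr)                 # the captured ClosedJaxpr
    compiled = specialized.lower().compile()
    print(compiled(inp).shape)               # (jax.device_count(), 4)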
@@ -557,11 +557,17 @@ def parallel_callable(fun: lu.WrappedFun,
                       donated_invars: Sequence[bool],
                       is_explicit_global_axis_size: bool,
                       *avals):
+  closed_jaxpr, xc_backend, replicas, shards, pci = get_pmap_jaxpr(
+      fun, backend_name, axis_name,
+      axis_size=axis_size, global_axis_size=global_axis_size,
+      devices=devices, name=fun.__name__, in_axes=in_axes,
+      out_axes_thunk=out_axes_thunk, avals=avals)
   pmap_computation = lower_parallel_callable(
-      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
-      in_axes, out_axes_thunk, donated_invars,
+      fun, axis_name, axis_size, global_axis_size, devices, name,
+      in_axes, donated_invars,
       is_explicit_global_axis_size, avals,
-      lowering_parameters=mlir.LoweringParameters())
+      lowering_parameters=mlir.LoweringParameters(), closed_jaxpr=closed_jaxpr,
+      backend=xc_backend, replicas=replicas, shards=shards, pci=pci)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
 
@@ -693,19 +699,22 @@ def get_pmap_jaxpr(
 @profiler.annotate_function
 def lower_parallel_callable(
     fun: lu.WrappedFun,
-    backend_name: str | None,
     axis_name: core.AxisName,
     axis_size: int,
     global_axis_size: int,
     devices: Sequence[xc.Device] | None,
     name: str,
     in_axes: Iterable[int | None],
-    out_axes_thunk: Callable[[], Sequence[int | None]],
     donated_invars: Sequence[bool],
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
+    lowering_parameters: mlir.LoweringParameters,
+    closed_jaxpr: core.ClosedJaxpr,
+    backend: xc.Client,
+    replicas: ReplicaInfo,
+    shards: ShardInfo,
+    pci: ParallelCallableInfo) -> PmapComputation:
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -716,9 +725,6 @@ def lower_parallel_callable(
         f"Specified axis_size {global_axis_size} doesn't match received "
         f"axis_size {axis_size}.")
 
-  closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr(
-      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
-      in_axes, out_axes_thunk, avals)
   jaxpr = closed_jaxpr.jaxpr
 
   no_nested_sharding = False
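The pxla change is a protocol change: lower_parallel_callable no longer traces internally; callers run get_pmap_jaxpr first and pass its results (closed_jaxpr, backend, replicas, shards, pci) in as keyword arguments, so specialize() can trace once and reuse the artifacts when it finally lowers. A self-contained sketch of that "trace once, defer lowering" pattern (hedged: stand-in functions only, not the real pxla internals):

    from functools import partial

    def trace(fn):
        # stand-in for get_pmap_jaxpr: produce the jaxpr plus trace metadata
        return f"jaxpr({fn.__name__})", "backend", "replicas", "shards", "pci"

    def lower(fn, *, closed_jaxpr, backend, replicas, shards, pci):
        # stand-in for lower_parallel_callable: consumes pre-traced artifacts
        return f"lowered[{closed_jaxpr} on {backend}]"

    def my_fn(x):
        return x + 1

    closed_jaxpr, backend, replicas, shards, pci = trace(my_fn)
    lower_callable = partial(lower, my_fn, closed_jaxpr=closed_jaxpr,
                             backend=backend, replicas=replicas,
                             shards=shards, pci=pci)
    print(lower_callable())   # deferred; invoked later by Specialized.lower()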
@@ -477,29 +477,18 @@ def _make_jit_wrapper(jit_info: PjitInfo):
 
   @api_boundary
   def lower(*args, **kwargs):
-    lowering_parameters = kwargs.pop(
-        '_experimental_lowering_parameters', mlir.LoweringParameters())
-
-    (args_flat, params, in_tree, out_tree,
-     donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
+    specialized = specialize(*args, **kwargs)
     try:
-      lowering = _resolve_and_lower(
-          args_flat, **params, lowering_parameters=lowering_parameters)
+      return specialized.lower()
     except pxla.DeviceAssignmentMismatchError as e:
       fails, = e.args
-      api_name = 'jit' if params['resource_env'] is None else 'pjit'
       fun = jit_info.fun
       fun_name = getattr(fun, '__qualname__',
                          getattr(fun, '__name__', str(fun)))
       msg = _device_assignment_mismatch_error(
-          fun_name, fails, args_flat, api_name, arg_names)
+          fun_name, fails, specialized._args_flat, 'jit', specialized._arg_names)
       raise ValueError(msg) from None
 
-    donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
-    jaxpr = params["jaxpr"]
-    return stages.Lowered.from_flat_info(
-        lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree)
-
   @api_boundary
   def eval_shape(*args, **kwargs):
     _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
@@ -514,8 +503,8 @@ def _make_jit_wrapper(jit_info: PjitInfo):
     lowering_parameters = kwargs.pop(
         '_experimental_lowering_parameters', mlir.LoweringParameters())
 
-    args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
-        jit_info, args, kwargs)
+    (args_flat, params, in_tree, out_tree, donated_invars,
+     arg_names, _) = _infer_params(jit_info, args, kwargs)
 
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     jaxpr = params['jaxpr']
@@ -523,7 +512,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
     lower_callable = partial(_resolve_and_lower, args_flat, **params,
                              lowering_parameters=lowering_parameters)
     return stages.Specialized(jaxpr, args_info, params["name"], out_tree,
-                              lower_callable)
+                              lower_callable, args_flat, arg_names)
 
   wrapped = _cpp_pjit(jit_info)
   wrapped.lower = lower
@@ -654,7 +643,6 @@ def _infer_params(jit_info, args, kwargs):
   args_flat = [*implicit_args, *explicit_args]
 
   num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked)
-  num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in attrs_tracked)
   num_extra_args = len(implicit_args) + num_states_in + len(consts)
   in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
   in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
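On the jit side the same inversion applies: lower() builds a Specialized via specialize() and returns specialized.lower(), and the Specialized object now carries the flat arguments and argument names so the DeviceAssignmentMismatchError message can still be constructed. A brief usage sketch (hedged: assumes the JAX version at this commit; the function and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    h = jax.jit(lambda a, b: a @ b)          # illustrative function
    a = jnp.ones((8, 8)); b = jnp.ones((8, 8))

    lowered = h.lower(a, b)                  # internally: specialize(a, b).lower()
    print(lowered.compile()(a, b).shape)     # (8, 8)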
@@ -426,15 +426,18 @@ class CompiledCallParams(NamedTuple):
 
 
 class Specialized(Stage):
-  __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable"]
+  __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
+               "_args_flat", "_arg_names"]
 
   def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
-               lower_callable):
+               lower_callable, args_flat=None, arg_names=None):
     self.jaxpr = jaxpr
     self.args_info = args_info
     self.fun_name = fun_name
     self._out_tree = out_tree
     self._lower_callable = lower_callable
+    self._args_flat = args_flat
+    self._arg_names = arg_names
 
   @property
   def out_info(self):
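Finally, stages.Specialized grows two optional fields so the pjit error path above can reach the flat arguments and names; defaulting them to None keeps the pmap call site (which does not pass them) working unchanged. A stripped-down sketch of that pattern (hedged: a toy stand-in, not the real class):

    class Specialized:
        __slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree",
                     "_lower_callable", "_args_flat", "_arg_names"]

        def __init__(self, jaxpr, args_info, fun_name, out_tree,
                     lower_callable, args_flat=None, arg_names=None):
            self.jaxpr = jaxpr
            self.args_info = args_info
            self.fun_name = fun_name
            self._out_tree = out_tree
            self._lower_callable = lower_callable
            self._args_flat = args_flat    # new: flat args for error messages
            self._arg_names = arg_names    # new: matching argument names

    # Existing callers that omit the new arguments still work:
    s = Specialized("jaxpr", {}, "f", None, lambda: None)
    print(s._args_flat, s._arg_names)      # None None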