mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[jax2tf] Clean up the support for cross-lowering.
In a previous CL we introduced cross-lowering support without any changes to JAX core, but at the expense of some overly complex code in jax2tf, along with overriding a JAX core function. Those changes were also not enough to handle some xmap and pmap cases.

Here we instead introduce an `_experimental_lowering_platform: Optional[str]` parameter on the `.lower()` methods and thread the `lowering_platform` all the way to the calls to `mlir.lower_jaxpr_to_module`. That's it. Note that this parameter to `.lower()` is experimental and not meant to be used outside jax2tf; it may also swallow a user kwarg of the same name.
This commit is contained in:
parent 713bc2687d
commit 9a424aabbd
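For illustration only (not part of this commit): a minimal sketch of how a caller such as jax2tf could use the new keyword-only parameter. The helper function, shapes, and the "tpu" platform string below are made-up assumptions; only the `_experimental_lowering_platform` name comes from the change itself.

# Sketch only: exercising the new keyword-only parameter added by this commit.
# `cumsum_fn` and the "tpu" platform string are illustrative assumptions.
import jax
import jax.numpy as jnp

def cumsum_fn(x):
  return jnp.cumsum(x, axis=0)

# Default behavior: lower for the platform of the current backend.
lowered_default = jax.jit(cumsum_fn).lower(jnp.ones((4, 6)))

# What jax2tf can now request: lowering for a different platform (e.g. TPU)
# even when running on a CPU-only host.
lowered_tpu = jax.jit(cumsum_fn).lower(
    jnp.ones((4, 6)), _experimental_lowering_platform="tpu")
print(lowered_tpu.as_text())  # IR text lowered for the requested platform

jax2tf passes the platform it targets (the first entry of `experimental_native_lowering_platforms`) through this parameter, as the jax2tf.py hunks below show.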
@@ -789,7 +789,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
     return aval, device

   @api_boundary
-  def lower(*args, **kwargs) -> stages.Lowered:
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs) -> stages.Lowered:
     """Lower this function for the given arguments.

     A lowered function is staged out of Python and translated to a
@@ -823,13 +824,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
     if jax.config.jax_array:
       computation = dispatch.sharded_lowering(
           flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-          keep_unused, *arg_specs_and_devices)
+          keep_unused, lowering_platform=_experimental_lowering_platform,
+          *arg_specs_and_devices)
       return stages.Lowered.from_flat_info(
           computation, in_tree, in_avals, donate_argnums, out_tree())
     else:
       computation = dispatch.lower_xla_callable(
           flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-          keep_unused, *arg_specs_and_devices)
+          keep_unused, lowering_platform=_experimental_lowering_platform,
+          *arg_specs_and_devices)
       return stages.Lowered.from_flat_info(
           computation, in_tree, in_avals, donate_argnums, out_tree())

@@ -2471,7 +2474,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
   # this might naturally be a method, with ``fun`` as a ``self`` and
   # all the other arguments stored as attributes.
   @api_boundary
-  def lower(*args, **kwargs) -> stages.Lowered:
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs) -> stages.Lowered:
     """Lower a parallel-mapped form of this function for the given arguments.

     A parallel-mapped and lowered function is staged out of Python and
@@ -2497,7 +2501,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
         donated_invars=p.donated_invars,
         global_arg_shapes=p.global_arg_shapes_flat,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
-        avals=abstract_args)
+        avals=abstract_args,
+        lowering_platform=_experimental_lowering_platform)
     return stages.Lowered.from_flat_info(
         computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

@@ -326,7 +326,8 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):


 def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
-                     keep_unused, *arg_specs):
+                     keep_unused, *arg_specs,
+                     lowering_platform: Optional[str]):
   in_avals, in_shardings = util.unzip2(arg_specs)

   da = None
@@ -334,7 +335,7 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
     da, in_shardings = not_none_device_or_backend_on_jit(
         backend, device, len(in_shardings))

-  in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]
+  in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]  # type: ignore

   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
@@ -342,19 +343,22 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
       in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      always_lower=always_lower, devices_from_context=da)
+      always_lower=always_lower, devices_from_context=da,
+      lowering_platform=lowering_platform)


 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
   if config.jax_array:
     computation = sharded_lowering(fun, device, backend, name, donated_invars,
-                                   False, keep_unused, *arg_specs)
+                                   False, keep_unused, *arg_specs,
+                                   lowering_platform=None)
     allow_prop = [True] * len(computation.compile_args['global_out_avals'])
     return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
   else:
     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
-                              keep_unused, *arg_specs).compile().unsafe_call
+                              keep_unused, *arg_specs,
+                              lowering_platform=None).compile().unsafe_call

 xla_callable = lu.cache(_xla_callable_uncached)

@@ -414,7 +418,8 @@ def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
 @profiler.annotate_function
 def lower_xla_callable(
     fun: lu.WrappedFun, device, backend, name, donated_invars,
-    always_lower: bool, keep_unused: bool, *arg_specs):
+    always_lower: bool, keep_unused: bool, *arg_specs,
+    lowering_platform: Optional[str]):
   """Lower into XLA.

   Args:
@@ -512,7 +517,8 @@ def lower_xla_callable(
       effects.ordered_effects.filter_in(closed_jaxpr.effects))
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name, closed_jaxpr, unordered_effects,
-      ordered_effects, backend, backend.platform,
+      ordered_effects, backend,
+      lowering_platform or backend.platform,
       mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
@@ -1273,7 +1273,7 @@ def parallel_callable(fun: lu.WrappedFun,
   pmap_computation = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
-      is_explicit_global_axis_size, avals)
+      is_explicit_global_axis_size, avals, lowering_platform=None)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])

@@ -1397,7 +1397,9 @@ def lower_parallel_callable(
     donated_invars: Sequence[bool],
     global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
     is_explicit_global_axis_size: bool,
-    avals: Sequence[core.AbstractValue]):
+    avals: Sequence[core.AbstractValue],
+    *,
+    lowering_platform: Optional[str]):
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -1502,7 +1504,7 @@ def lower_parallel_callable(
         unordered_effects,
         ordered_effects,
         backend,
-        backend.platform,
+        lowering_platform or backend.platform,
         mlir.ReplicaAxisContext(axis_env),
         name_stack,
         donated_invars,
@@ -2876,10 +2878,12 @@ def lower_sharding_computation(
     out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]], UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     *,
     in_is_global: Sequence[bool],
     keep_unused: bool,
     always_lower: bool,
-    devices_from_context: Optional[Sequence[xc.Device]] = None
+    devices_from_context: Optional[Sequence[xc.Device]] = None,
+    lowering_platform: Optional[str],
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

@@ -3047,7 +3051,8 @@ def lower_sharding_computation(
         unordered_effects,
         ordered_effects,
         backend,
-        backend.platform,
+        # Optionally, override the lowering platform
+        lowering_platform or backend.platform,
         axis_ctx,
         name_stack,
         donated_invars,
@@ -3102,7 +3107,8 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: Optional[TilingMethod],
-    in_is_global: Sequence[bool]) -> MeshComputation:
+    in_is_global: Sequence[bool],
+    lowering_platform: Optional[str]) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@@ -3221,7 +3227,7 @@ def lower_mesh_computation(
         unordered_effects,
         ordered_effects,
         backend,
-        backend.platform,
+        lowering_platform or backend.platform,
         axis_ctx,
         name_stack,
         donated_invars,
@@ -603,7 +603,7 @@ def xmap(fun: Callable,
     return verify_outputs(out_flat, out_tree, params)

   @decorate_serial
-  def lower(*args):
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None):
     fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
     avals_flat = [shaped_abstractify(arg) for arg in args_flat]
     computation = make_xmap_callable(
@@ -611,12 +611,13 @@ def xmap(fun: Callable,
        params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
        params['resource_env'], params['backend'], params['spmd_in_axes'],
        params['spmd_out_axes_thunk'], params['in_positional_semantics'],
-       params['out_positional_semantics'], *avals_flat)
+       params['out_positional_semantics'],
+       _experimental_lowering_platform, *avals_flat)

     in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
     in_avals = in_tree.unflatten(avals_flat)
     return stages.Lowered.from_flat_info(
-        computation, in_tree, in_avals, donate_argnums, out_tree(),
+        computation, in_tree, in_avals, donate_argnums, out_tree(),  # type: ignore
         no_kwargs=True)

   fun_mapped.lower = lower
@@ -631,7 +632,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
       fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
       axis_resources, resource_env, backend,
       spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics,
-      *in_avals).compile().unsafe_call
+      None, *in_avals).compile().unsafe_call
   distributed_debug_log(("Running xmapped function", name),
                         ("python function", fun.f),
                         ("mesh", resource_env.physical_mesh),
@@ -644,7 +645,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
                        in_axes, out_axes_thunk, donated_invars,
                        global_axis_sizes, axis_resources, resource_env, backend,
                        spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
-                       out_positional_semantics, *in_avals):
+                       out_positional_semantics,
+                       lowering_platform: Optional[str],
+                       *in_avals):
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes, in_positional_semantics)

@@ -702,16 +705,17 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh,
         in_shardings, out_shardings, donated_invars,
         use_spmd_lowering, global_in_avals,
-        tiling_method=tiling_method, in_is_global=in_is_global)
+        tiling_method=tiling_method, in_is_global=in_is_global,
+        lowering_platform=lowering_platform)
   else:
     if config.jax_array:
       return dispatch.sharded_lowering(
           f, None, backend, name, donated_invars, False, True,
-          *[(a, None) for a in in_avals])
+          *[(a, None) for a in in_avals], lowering_platform=lowering_platform)
     else:
       return dispatch.lower_xla_callable(
           f, None, backend, name, donated_invars, False, True,
-          *[(a, None) for a in in_avals])
+          *[(a, None) for a in in_avals], lowering_platform=lowering_platform)

 class EvaluationPlan(NamedTuple):
   """Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
@@ -364,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
     wrapped = _python_pjit(fun, infer_params_fn)

   @api_boundary
-  def lower(*args, **kwargs):
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs):
     (args_flat, flat_local_in_avals, params, in_tree, out_tree,
      donate_argnums) = infer_params_fn(*args, **kwargs)
     if jax.config.jax_array:
@@ -379,7 +380,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
     lowering = _pjit_lower(
         params['jaxpr'], in_shardings, params['out_shardings'],
         params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
+        in_is_global, params['keep_unused'], always_lower=True,
+        lowering_platform=_experimental_lowering_platform)

     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -1289,7 +1291,7 @@ def _pjit_call_impl(*args, jaxpr,
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
       donated_invars, name, in_is_global, keep_unused,
-      always_lower=False).compile(
+      always_lower=False, lowering_platform=None).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
   _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1385,7 +1387,9 @@ def _pjit_lower_cached(
     name: str,
     in_is_global: Sequence[bool],
     keep_unused: bool,
-    always_lower: bool):
+    always_lower: bool,
+    *,
+    lowering_platform: Optional[str]):
   in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
       Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1431,7 +1435,8 @@ def _pjit_lower_cached(
     return pxla.lower_mesh_computation(
         fun, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
-        True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
+        True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global,
+        lowering_platform=lowering_platform)
   else:
     # Pass `in_is_global` here because this path is taken by both host local
    # avals and global avals.
@@ -1442,7 +1447,8 @@ def _pjit_lower_cached(
         jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
         always_lower=always_lower,
         devices_from_context=(
-            None if mesh is None or mesh.empty else list(mesh.devices.flat)))
+            None if mesh is None or mesh.empty else list(mesh.devices.flat)),
+        lowering_platform=lowering_platform)


 def pjit_staging_rule(trace, *args, **params):
@@ -1657,7 +1663,8 @@ def _pjit_partial_eval(trace, *in_tracers,
         known_params["jaxpr"], known_params["in_shardings"],
         known_params["out_shardings"], known_params["resource_env"],
         known_params["donated_invars"], known_params["name"],
-        in_is_global, known_params['keep_unused'], always_lower=False).compile(
+        in_is_global, known_params['keep_unused'], always_lower=False,
+        lowering_platform=None).compile(
             _allow_propagation_to_outputs=[True] * len(known_params['out_shardings']),
             _allow_compile_replicated=False)
     da = compiled._device_assignment
@@ -497,9 +497,12 @@ def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
   # preserve the lowering function. This will be used in the _lower_native_and_run.
   # We rely on the fact that the lowering is the same for the function
   # taking pytrees, and the one taking flat args.
-  def fun_flat_jax_lower(*args_flat_jax):
+  def fun_flat_jax_lower(*args_flat_jax, _experimental_lowering_platform):
     tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
-    lowered = fun_jax.lower(*tree_args, **tree_kwargs)
+    lowered = fun_jax.lower(
+        *tree_args,
+        _experimental_lowering_platform=_experimental_lowering_platform,
+        **tree_kwargs)
     out_tree = lowered.out_tree
     nonlocal out_tree_ref
     assert out_tree_ref is None or out_tree_ref == out_tree
@@ -678,23 +681,24 @@ def _lower_native_and_run(fun_jax: Callable,
   ]

   if lowering_params.experimental_native_lowering_platforms:
-    lowered = cross_platform_lowering(
-        fun_jax, arg_specs_jax,  # type: ignore[arg-type]
-        platforms=lowering_params.experimental_native_lowering_platforms
-    )._lowering  # type: ignore
+    lowering_platform = lowering_params.experimental_native_lowering_platforms[0]
   else:
-    if not hasattr(fun_jax, "lower") or abstracted_axes:
-      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
-      # convert(f_jax), in which case a "jit" is implied. We also add a jit when
-      # we need to pass the abstracted axes.
-      # TODO(necula): Will clean this when we clean the native lowering jax2tf API
-      fun_jax_lower = jax.jit(fun_jax,
-                              abstracted_axes=abstracted_axes).lower
-    else:
-      # If we have a pjit or pmap already we do not wrap with another
-      fun_jax_lower = fun_jax.lower
+    lowering_platform = None

-    lowered = fun_jax_lower(*arg_specs_jax)._lowering  # type: ignore
+  if not hasattr(fun_jax, "lower") or abstracted_axes:
+    # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
+    # convert(f_jax), in which case a "jit" is implied. We also add a jit when
+    # we need to pass the abstracted axes.
+    # TODO(necula): Will clean this when we clean the native lowering jax2tf API
+    fun_jax_lower = jax.jit(fun_jax,
+                            abstracted_axes=abstracted_axes).lower
+  else:
+    # If we have a pjit or pmap already we do not wrap with another
+    fun_jax_lower = fun_jax.lower
+
+  lowered = fun_jax_lower(
+      *arg_specs_jax,
+      _experimental_lowering_platform=lowering_platform)._lowering  # type: ignore

   if config.jax2tf_use_stablehlo:
     mlir_module = lowered.stablehlo()
@@ -708,6 +712,8 @@ def _lower_native_and_run(fun_jax: Callable,
   if "global_out_avals" in lowered.compile_args:
     # This is currently the case for pjit
     out_avals = lowered.compile_args["global_out_avals"]
+  elif "shards" in lowered.compile_args:  # for PmapComputation
+    out_avals = lowered.compile_args["shards"].out_sharded_avals
   else:
     out_avals = lowered.compile_args["out_avals"]
   if lowered.compile_args["host_callbacks"]:
@@ -834,136 +840,6 @@ def _lower_native_and_run(fun_jax: Callable,
         for res_val, out_aval in zip(res, out_avals))
   return res, out_avals

-def cross_platform_lowering(fun_jax, arg_specs: Sequence[jax.Array],
-                            *,
-                            platforms: Sequence[str] = ()):
-
-  context_mesh = pxla.thread_resources.env.physical_mesh
-  if not context_mesh.empty:
-    # What devices we need
-    if context_mesh.is_multi_process:
-      raise NotImplementedError("cross_platform lowering is not supported for multi-host lowering")
-    devices = np.array(context_mesh.devices).reshape((-1,))
-    devices_shape = np.shape(context_mesh.devices)
-    axis_names = context_mesh.axis_names
-  else:
-    devices = [config.jax_default_device or jax.local_devices()[0]]  # type: ignore
-    devices_shape = (1,)
-    axis_names = ("_no_axis",)
-
-  lowering_client = LoweringOnlyClient(platforms[0],
-                                       1 + max(d.id for d in devices))
-  lowering_devices = [lowering_client.devices[d.id] for d in devices]
-  lowering_mesh = sharding.Mesh(
-      np.array(lowering_devices).reshape(devices_shape),  # type: ignore
-      axis_names)
-
-  try:
-    orig_jax_default_device = config.jax_default_device
-    config.update("jax_default_device", lowering_devices[0])  # For nullary functions
-    prev_get_and_check_device_assignment = pxla._get_and_check_device_assignment
-    pxla._get_and_check_device_assignment = partial(_get_and_check_device_assignment,
-                                                    lowering_client)
-    with lowering_mesh:
-      if not hasattr(fun_jax, "lower"):
-        # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
-        # convert(f_jax), in which case a "jit" is implied. We also add a jit when
-        # we need to pass the abstracted axes or shardings.
-        # TODO(necula): Will clean this when we clean the native lowering jax2tf API
-        fun_jax_lower = jax.jit(fun_jax).lower
-      else:
-        fun_jax_lower = fun_jax.lower
-      lowered = fun_jax_lower(*arg_specs)
-      return lowered
-  finally:
-    config.update("jax_default_device", orig_jax_default_device)
-    pxla._get_and_check_device_assignment = prev_get_and_check_device_assignment
-
-class LoweringOnlyClient:
-  """A Client that overrides the platform, for cross-platform lowering only."""
-  def __init__(self, platform: str, nr_devices: int):
-    self.platform = platform
-    self._process_index = 0
-    self.devices = [LoweringOnlyDevice(self, i) for i in range(nr_devices)]
-    self.lowering_only_client = True
-
-  def __str__(self):
-    return f"LoweringOnlyClient({self.platform})"
-
-  def process_index(self):
-    return self._process_index
-
-  def device_count(self):
-    return len(self.devices)
-
-class LoweringOnlyDevice:
-  """A Device that overrides the platform, for cross-platform lowering only."""
-  def __init__(self, client: LoweringOnlyClient, id: int):
-    self.client = client
-    self.process_index = client.process_index()
-    self.id = id
-
-  def __str__(self):
-    return f"LoweringOnlyDevice({self.platform}, id={self.id})"
-
-
-# This is a copy of pxla._get_and_check_device_assignment, because we need
-# to change its behavior for cross-platform lowering.
-# The changes are marked below with "CHANGED:".
-# This function reconciles the device assignment from shardings and from
-# the mesh context. Some JAX primitives (xmap, shard_map) carry their own
-# mesh of devices, instead of relying on the mesh context manager, which would
-# conflict with the lowering-only devices. We must now only avoid raising
-# errors in the case, but we must also pick the lowering devices.
-def _get_and_check_device_assignment(
-    lowering_client: LoweringOnlyClient,  # CHANGED: we pass the overriding client
-    shardings: Iterable[pxla.ShardingInfo],
-    devices: Optional[Sequence[xla_client.Device]],
-) -> Tuple[xla_client.Client, Sequence[xla_client.Device]]:
-  from jax._src.api import local_devices
-
-  first_sharding_info = None
-  if devices is None:
-    devices = []
-  else:
-    devices = list(devices)
-
-  for i, s_type, source_info in shardings:
-    if pxla.is_auto(i) or pxla._is_unspecified(i):
-      continue
-    # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
-    # skipped.
-    if first_sharding_info is None:
-      first_sharding_info = (list(i._device_assignment), s_type, source_info)  # type: ignore
-    arr_device_assignment = list(i._device_assignment)  # type: ignore
-    if not devices:
-      if first_sharding_info[0] != arr_device_assignment:
-        # CHANGED: do not error if the only difference is in lowering_only_client
-        if not all((d1.id == d2.id and
-                    (hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
-                   for d1, d2 in zip(first_sharding_info[0], arr_device_assignment)):
-          raise pxla.DeviceAssignmentMismatchError([
-              pxla.DeviceAssignmentMismatch(*first_sharding_info),
-              pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
-    else:
-      if devices != arr_device_assignment:
-        # CHANGED: do not error if the only difference is in lowering_only_client
-        if not all((d1.id == d2.id and
-                    (hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
-                   for d1, d2 in zip(devices, arr_device_assignment)):
-          raise pxla.DeviceAssignmentMismatchError([
-              pxla.DeviceAssignmentMismatch(devices, pxla.MismatchType.CONTEXT_DEVICES, None),
-              pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
-  if first_sharding_info is None and devices:
-    final_device_assignment = devices
-  elif first_sharding_info is None:
-    final_device_assignment = [config.jax_default_device or local_devices()[0]]
-  else:
-    final_device_assignment = first_sharding_info[0]  # type: ignore
-
-  # CHANGED: override the device assignment
-  final_device_assignment = tuple(lowering_client.devices[d.id] for d in final_device_assignment)  # type: ignore
-  return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

 def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                           in_vals: Sequence[TfVal],
@@ -1375,32 +1375,40 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
                       f"See op.name = : {op.name}")

   @parameterized.named_parameters(
-      dict(testcase_name=f"{'with_mesh_' if with_mesh else ''}{'nullary_' if nullary else ''}{transform}_pjit_sharding={pjit_sharding}",
-           with_mesh=with_mesh, transform=transform, nullary=nullary, pjit_sharding=pjit_sharding)
-      # The inner transformation to apply to the lowered function
-      for transform in ["base",
-                        "jit",
-                        "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P", "pjit_in_shardings_Sharding",
-                        "shard_map", "xmap", "pmap"]
-      # The sharding to be used for the outer pjit
-      for pjit_sharding in (
-          ["unspecified"] if transform == "pmap" else
-          ["unspecified", "none", "P", "Sharding"])
+      dict(testcase_name=(
+          f"{'with_mesh_' if with_mesh else ''}"
+          f"2={transform2 if transform2 != 'none' else ''}"
+          f"_1={transform1 if transform1 != 'none' else ''}"
+          f"{'_nullary' if nullary else ''}"),
+        with_mesh=with_mesh, transform1=transform1,
+        transform2=transform2, nullary=nullary)
+      # Test transform2(transform1(func)
+      for transform1 in [
+          "none",
+          "jit",
+          "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
+          "pjit_in_shardings_Sharding",
+          "shard_map", "xmap", "pmap"]
+      for transform2 in (
+          ["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
+           "pjit_in_shardings_Sharding"]
+      )
       # Whether the function can be nullary
       for nullary in (
-          [False] if (pjit_sharding != "unspecified") else
-          [True, False]
-      )
+          # To reduce the number of tests
+          [True, False] if transform2 == "none" else
+          [False])
       # Whether we use a "with mesh"
       for with_mesh in (
-          [True] if (transform not in ["base", "jit", "pjit"] or
-                     pjit_sharding != "unspecified") else
+          [True] if (transform1 not in ["base", "jit", "pjit"] or
+                     transform2 != "none") else
          [False, True])
   )
-  def test_cross_platform(self, with_mesh=False, transform="jit", nullary=False, pjit_sharding="unspecified"):
+  def test_cross_platform(self, with_mesh=True, transform1="xmap",
+                          transform2="none", nullary=True):
     # Tests cross-lowering for
     #  with mesh:
-    #   pjit(transform(func), in_sharding=pjit_sharding)
+    #   transform2(transform1(func))
     if not config.jax_array:
       raise unittest.SkipTest("cross_platform test work only with jax.Array")
     if not config.jax_jit_pjit_api_merge:
@@ -1409,49 +1417,60 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     mesh = sharding.Mesh(jax.devices()[:1], ("a",))
     # cummax has distinctive lowering for TPU, using a reduce-window op
     func = lambda x: lax.cummax(x, axis=0, reverse=False)
-    # For shard_map we cannot use cummax :-( because it does not have a replication rule
-    # But we use lax.all_gather which on TPU is lowered with a all-gather op
+    # For shard_map we cannot use cummax :-( because it does not have a
+    # replication rule. But we use lax.all_gather which on TPU is lowered with
+    # an all-gather op
     func_shard_map = lambda x: lax.all_gather(x, 'a', axis=1, tiled=True)

-    transformed_func = dict(
-        base=func,
-        jit=jax.jit(func),
-        jit_in_shardings_None=jax.jit(func, in_shardings=None),
-        jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),
-        jit_in_shardings_Sharding=jax.jit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-        pjit=pjit.pjit(func),
-        pjit_in_shardings_None=pjit.pjit(func, in_shardings=None),
-        pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),)),
-        pjit_in_shardings_Sharding=pjit.pjit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-        shard_map=(
-            shard_map(func_shard_map, mesh, in_specs=(P("a", None),), out_specs=P("a", None))),
-        xmap=xmap(func, in_axes=({0: 'axis'},), out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
-        pmap=jax.pmap(func, in_axes=0, out_axes=0),
-    )[transform]
-    pjit_transformed_func = dict(
-        unspecified=pjit.pjit(transformed_func),
-        none=pjit.pjit(transformed_func, in_shardings=None),
-        P=pjit.pjit(transformed_func, in_shardings=(P("a"),)),
-        Sharding=pjit.pjit(transformed_func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-    )[pjit_sharding]
-    if pjit_sharding == "unspecified":
-      if transform == "xmap":
-        raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
+    def apply_transform(func, transform: str):
+      transformed_func = dict(
+          none=func,
+          jit=jax.jit(func),
+          jit_in_shardings_None=jax.jit(func, in_shardings=None),  # type: ignore
+          jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),  # type: ignore
+          jit_in_shardings_Sharding=jax.jit(
+              func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),  # type: ignore
+          pjit=pjit.pjit(func),
+          pjit_in_shardings_None=pjit.pjit(func, in_shardings=None,
+                                           out_shardings=None),
+          pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),),
+                                        out_shardings=P("a")),
+          pjit_in_shardings_Sharding=pjit.pjit(
+              func,
+              in_shardings=(sharding.NamedSharding(mesh, P("a")),),
+              out_shardings=sharding.NamedSharding(mesh, P("a"))),
+          shard_map=(
+              shard_map(func, mesh, in_specs=(P("a", None),),
+                        out_specs=P("a", None))),
+          xmap=xmap(func, in_axes=({0: 'axis'},),
+                    out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
+          pmap=jax.pmap(func, in_axes=0, out_axes=0),
+      )[transform]
+      return transformed_func
+
+    transformed1_func = apply_transform(
+        (func_shard_map if transform1 == "shard_map" else func),
+        transform1)
+    assert transform2 not in ["xmap", "shard_map"]
+    transformed2_func = apply_transform(transformed1_func, transform2)
+
+    if transform1 == "xmap" and transform2 in ["pjit", "none"]:
+      raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")

+    if transform1 == "pmap":
+      x = x.reshape((1, -1))  # Since we use 1 device
     if not nullary:
-      func_to_convert = pjit_transformed_func
+      func_to_convert = transformed2_func
       args = [x]
     else:
-      func_to_convert = lambda: pjit_transformed_func(jnp.ones(x.shape, dtype=x.dtype))
+      func_to_convert = lambda: transformed2_func(jnp.ones(x.shape,
                                                            dtype=x.dtype))
       args = []

-    if transform == "pmap":
+    if transform1 == "pmap":
       if nullary:
         raise unittest.SkipTest("Cannot lower nested pmap: jit-of-pmap warning")
-      raise unittest.SkipTest("TODO: pmap picks the devices from jax.devices() and will lower for CPU")
-
-    if transform == "xmap":
-      raise unittest.SkipTest("TODO: xmap does not pick up the overriden mesh and will lower for CPU")
+      raise unittest.SkipTest("TODO: figure out how to invoke pmap from TF")

     f_tf = jax2tf.convert(func_to_convert,
                           experimental_native_lowering=True,
@@ -1464,10 +1483,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     _ = func_to_convert(*args)
     tf_hlo = f_tf.experimental_get_compiler_ir(*args)(stage="hlo")

-    if transform == "shard_map":
-      self.assertIn("all-gather(f32[4,6]", tf_hlo)
+    if transform1 == "shard_map":
+      self.assertIn(" all-gather(f32[4,6]", tf_hlo)
     else:
-      self.assertIn("reduce-window(f32[4,6]", tf_hlo)
+      self.assertIn(" reduce-window(", tf_hlo)


 def get_serialized_computation(