[jax2tf] Clean up the support for cross-lowering.

In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf and of having to override a JAX core function. Moreover,
those changes were not enough to handle some xmap and pmap cases.

Here we introduce an `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and thread it, as `lowering_platform`, all the way
down to the calls to `mlir.lower_jaxpr_to_module`. That's it.
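As a concrete illustration, here is a minimal sketch (not part of this change;
the platform string and shapes are hypothetical example values). The caller
picks a target platform and passes it through `.lower()`; when the keyword is
omitted, lowering falls back to `backend.platform` as before.

import jax
import jax.numpy as jnp

def f(x):
  return jax.lax.cummax(x, axis=0)

# Lower for TPU even on a host that has only CPU devices attached.
# The keyword is threaded through as `lowering_platform` down to
# `mlir.lower_jaxpr_to_module`.
lowered = jax.jit(f).lower(
    jnp.ones((4, 6), dtype=jnp.float32),
    _experimental_lowering_platform="tpu")
print(lowered.as_text())  # MLIR text of the TPU lowering

jax2tf uses exactly this hook from `_lower_native_and_run` when
`experimental_native_lowering_platforms` is set (see the hunk below).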

Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
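To illustrate the kwarg caveat, a contrived sketch (hypothetical user code):
because the new keyword-only parameter sits in front of `**kwargs`, a
same-named user keyword argument is consumed by `.lower()` itself and never
reaches the wrapped function.

import jax
import jax.numpy as jnp

def g(x, _experimental_lowering_platform=None):
  # A user function that happens to use the same keyword name.
  assert _experimental_lowering_platform is None  # always sees its default
  return x * 2

# "cpu" is taken to be the lowering platform; it is NOT forwarded to `g`.
lowered = jax.jit(g).lower(
    jnp.ones((3,), dtype=jnp.float32),
    _experimental_lowering_platform="cpu")
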
George Necula 2023-02-28 11:30:23 +01:00
parent 713bc2687d
commit 9a424aabbd
7 changed files with 158 additions and 235 deletions

View File

@@ -789,7 +789,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
return aval, device
@api_boundary
def lower(*args, **kwargs) -> stages.Lowered:
def lower(*args, _experimental_lowering_platform: Optional[str] = None,
**kwargs) -> stages.Lowered:
"""Lower this function for the given arguments.
A lowered function is staged out of Python and translated to a
@@ -823,13 +824,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
if jax.config.jax_array:
computation = dispatch.sharded_lowering(
flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
keep_unused, *arg_specs_and_devices)
keep_unused, lowering_platform=_experimental_lowering_platform,
*arg_specs_and_devices)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree())
else:
computation = dispatch.lower_xla_callable(
flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
keep_unused, *arg_specs_and_devices)
keep_unused, lowering_platform=_experimental_lowering_platform,
*arg_specs_and_devices)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree())
@@ -2471,7 +2474,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
# this might naturally be a method, with ``fun`` as a ``self`` and
# all the other arguments stored as attributes.
@api_boundary
def lower(*args, **kwargs) -> stages.Lowered:
def lower(*args, _experimental_lowering_platform: Optional[str] = None,
**kwargs) -> stages.Lowered:
"""Lower a parallel-mapped form of this function for the given arguments.
A parallel-mapped and lowered function is staged out of Python and
@@ -2497,7 +2501,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args)
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

View File

@@ -326,7 +326,8 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
keep_unused, *arg_specs):
keep_unused, *arg_specs,
lowering_platform: Optional[str]):
in_avals, in_shardings = util.unzip2(arg_specs)
da = None
@@ -334,7 +335,7 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
da, in_shardings = not_none_device_or_backend_on_jit(
backend, device, len(in_shardings))
in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]
in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore
# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
# the number of output avals at this stage. lower_sharding_computation will
@@ -342,19 +343,22 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
always_lower=always_lower, devices_from_context=da)
always_lower=always_lower, devices_from_context=da,
lowering_platform=lowering_platform)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, keep_unused, *arg_specs):
if config.jax_array:
computation = sharded_lowering(fun, device, backend, name, donated_invars,
False, keep_unused, *arg_specs)
False, keep_unused, *arg_specs,
lowering_platform=None)
allow_prop = [True] * len(computation.compile_args['global_out_avals'])
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
else:
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
keep_unused, *arg_specs).compile().unsafe_call
keep_unused, *arg_specs,
lowering_platform=None).compile().unsafe_call
xla_callable = lu.cache(_xla_callable_uncached)
@@ -414,7 +418,8 @@ def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
@profiler.annotate_function
def lower_xla_callable(
fun: lu.WrappedFun, device, backend, name, donated_invars,
always_lower: bool, keep_unused: bool, *arg_specs):
always_lower: bool, keep_unused: bool, *arg_specs,
lowering_platform: Optional[str]):
"""Lower into XLA.
Args:
@@ -512,7 +517,8 @@ def lower_xla_callable(
effects.ordered_effects.filter_in(closed_jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, unordered_effects,
ordered_effects, backend, backend.platform,
ordered_effects, backend,
lowering_platform or backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,

View File

@@ -1273,7 +1273,7 @@ def parallel_callable(fun: lu.WrappedFun,
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
is_explicit_global_axis_size, avals)
is_explicit_global_axis_size, avals, lowering_platform=None)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -1397,7 +1397,9 @@ def lower_parallel_callable(
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue]):
avals: Sequence[core.AbstractValue],
*,
lowering_platform: Optional[str]):
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -1502,7 +1504,7 @@ def lower_parallel_callable(
unordered_effects,
ordered_effects,
backend,
backend.platform,
lowering_platform or backend.platform,
mlir.ReplicaAxisContext(axis_env),
name_stack,
donated_invars,
@@ -2876,10 +2878,12 @@ def lower_sharding_computation(
out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding, UnspecifiedValue]], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
in_is_global: Sequence[bool],
keep_unused: bool,
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@@ -3047,7 +3051,8 @@ def lower_sharding_computation(
unordered_effects,
ordered_effects,
backend,
backend.platform,
# Optionally, override the lowering platform
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
@@ -3102,7 +3107,8 @@ def lower_mesh_computation(
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
tiling_method: Optional[TilingMethod],
in_is_global: Sequence[bool]) -> MeshComputation:
in_is_global: Sequence[bool],
lowering_platform: Optional[str]) -> MeshComputation:
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@@ -3221,7 +3227,7 @@ def lower_mesh_computation(
unordered_effects,
ordered_effects,
backend,
backend.platform,
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,

View File

@@ -603,7 +603,7 @@ def xmap(fun: Callable,
return verify_outputs(out_flat, out_tree, params)
@decorate_serial
def lower(*args):
def lower(*args, _experimental_lowering_platform: Optional[str] = None):
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
computation = make_xmap_callable(
@@ -611,12 +611,13 @@ def xmap(fun: Callable,
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
params['spmd_out_axes_thunk'], params['in_positional_semantics'],
params['out_positional_semantics'], *avals_flat)
params['out_positional_semantics'],
_experimental_lowering_platform, *avals_flat)
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree(),
computation, in_tree, in_avals, donate_argnums, out_tree(), # type: ignore
no_kwargs=True)
fun_mapped.lower = lower
@@ -631,7 +632,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics, out_positional_semantics,
*in_avals).compile().unsafe_call
None, *in_avals).compile().unsafe_call
distributed_debug_log(("Running xmapped function", name),
("python function", fun.f),
("mesh", resource_env.physical_mesh),
@@ -644,7 +645,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_axes, out_axes_thunk, donated_invars,
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
out_positional_semantics, *in_avals):
out_positional_semantics,
lowering_platform: Optional[str],
*in_avals):
plan = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
@@ -702,16 +705,17 @@ def make_xmap_callable(fun: lu.WrappedFun,
f, 'xmap', name, mesh,
in_shardings, out_shardings, donated_invars,
use_spmd_lowering, global_in_avals,
tiling_method=tiling_method, in_is_global=in_is_global)
tiling_method=tiling_method, in_is_global=in_is_global,
lowering_platform=lowering_platform)
else:
if config.jax_array:
return dispatch.sharded_lowering(
f, None, backend, name, donated_invars, False, True,
*[(a, None) for a in in_avals])
*[(a, None) for a in in_avals], lowering_platform=lowering_platform)
else:
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, False, True,
*[(a, None) for a in in_avals])
*[(a, None) for a in in_avals], lowering_platform=lowering_platform)
class EvaluationPlan(NamedTuple):
"""Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""

View File

@@ -364,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
wrapped = _python_pjit(fun, infer_params_fn)
@api_boundary
def lower(*args, **kwargs):
def lower(*args, _experimental_lowering_platform: Optional[str] = None,
**kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params_fn(*args, **kwargs)
if jax.config.jax_array:
@@ -379,7 +380,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
lowering = _pjit_lower(
params['jaxpr'], in_shardings, params['out_shardings'],
params['resource_env'], params['donated_invars'], params['name'],
in_is_global, params['keep_unused'], always_lower=True)
in_is_global, params['keep_unused'], always_lower=True,
lowering_platform=_experimental_lowering_platform)
if kwargs:
args_kwargs_in_tree = in_tree
@@ -1289,7 +1291,7 @@ def _pjit_call_impl(*args, jaxpr,
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, in_is_global, keep_unused,
always_lower=False).compile(
always_lower=False, lowering_platform=None).compile(
_allow_propagation_to_outputs=_allow_propagation_to_outputs)
_most_recent_pjit_call_executable.value = compiled
# This check is expensive so only do it if enable_checks is on.
@@ -1385,7 +1387,9 @@ def _pjit_lower_cached(
name: str,
in_is_global: Sequence[bool],
keep_unused: bool,
always_lower: bool):
always_lower: bool,
*,
lowering_platform: Optional[str]):
in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1431,7 +1435,8 @@ def _pjit_lower_cached(
return pxla.lower_mesh_computation(
fun, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global,
lowering_platform=lowering_platform)
else:
# Pass `in_is_global` here because this path is taken by both host local
# avals and global avals.
@@ -1442,7 +1447,8 @@ def _pjit_lower_cached(
jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
always_lower=always_lower,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)))
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_platform=lowering_platform)
def pjit_staging_rule(trace, *args, **params):
@@ -1657,7 +1663,8 @@ def _pjit_partial_eval(trace, *in_tracers,
known_params["jaxpr"], known_params["in_shardings"],
known_params["out_shardings"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global, known_params['keep_unused'], always_lower=False).compile(
in_is_global, known_params['keep_unused'], always_lower=False,
lowering_platform=None).compile(
_allow_propagation_to_outputs=[True] * len(known_params['out_shardings']),
_allow_compile_replicated=False)
da = compiled._device_assignment

View File

@@ -497,9 +497,12 @@ def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
# preserve the lowering function. This will be used in the _lower_native_and_run.
# We rely on the fact that the lowering is the same for the function
# taking pytrees, and the one taking flat args.
def fun_flat_jax_lower(*args_flat_jax):
def fun_flat_jax_lower(*args_flat_jax, _experimental_lowering_platform):
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
lowered = fun_jax.lower(*tree_args, **tree_kwargs)
lowered = fun_jax.lower(
*tree_args,
_experimental_lowering_platform=_experimental_lowering_platform,
**tree_kwargs)
out_tree = lowered.out_tree
nonlocal out_tree_ref
assert out_tree_ref is None or out_tree_ref == out_tree
@@ -678,23 +681,24 @@ def _lower_native_and_run(fun_jax: Callable,
]
if lowering_params.experimental_native_lowering_platforms:
lowered = cross_platform_lowering(
fun_jax, arg_specs_jax, # type: ignore[arg-type]
platforms=lowering_params.experimental_native_lowering_platforms
)._lowering # type: ignore
lowering_platform = lowering_params.experimental_native_lowering_platforms[0]
else:
if not hasattr(fun_jax, "lower") or abstracted_axes:
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
# TODO(necula): Will clean this when we clean the native lowering jax2tf API
fun_jax_lower = jax.jit(fun_jax,
abstracted_axes=abstracted_axes).lower
else:
# If we have a pjit or pmap already we do not wrap with another
fun_jax_lower = fun_jax.lower
lowering_platform = None
lowered = fun_jax_lower(*arg_specs_jax)._lowering # type: ignore
if not hasattr(fun_jax, "lower") or abstracted_axes:
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
# TODO(necula): Will clean this when we clean the native lowering jax2tf API
fun_jax_lower = jax.jit(fun_jax,
abstracted_axes=abstracted_axes).lower
else:
# If we have a pjit or pmap already we do not wrap with another
fun_jax_lower = fun_jax.lower
lowered = fun_jax_lower(
*arg_specs_jax,
_experimental_lowering_platform=lowering_platform)._lowering # type: ignore
if config.jax2tf_use_stablehlo:
mlir_module = lowered.stablehlo()
@@ -708,6 +712,8 @@ def _lower_native_and_run(fun_jax: Callable,
if "global_out_avals" in lowered.compile_args:
# This is currently the case for pjit
out_avals = lowered.compile_args["global_out_avals"]
elif "shards" in lowered.compile_args: # for PmapComputation
out_avals = lowered.compile_args["shards"].out_sharded_avals
else:
out_avals = lowered.compile_args["out_avals"]
if lowered.compile_args["host_callbacks"]:
@@ -834,136 +840,6 @@ def _lower_native_and_run(fun_jax: Callable,
for res_val, out_aval in zip(res, out_avals))
return res, out_avals
def cross_platform_lowering(fun_jax, arg_specs: Sequence[jax.Array],
*,
platforms: Sequence[str] = ()):
context_mesh = pxla.thread_resources.env.physical_mesh
if not context_mesh.empty:
# What devices we need
if context_mesh.is_multi_process:
raise NotImplementedError("cross_platform lowering is not supported for multi-host lowering")
devices = np.array(context_mesh.devices).reshape((-1,))
devices_shape = np.shape(context_mesh.devices)
axis_names = context_mesh.axis_names
else:
devices = [config.jax_default_device or jax.local_devices()[0]] # type: ignore
devices_shape = (1,)
axis_names = ("_no_axis",)
lowering_client = LoweringOnlyClient(platforms[0],
1 + max(d.id for d in devices))
lowering_devices = [lowering_client.devices[d.id] for d in devices]
lowering_mesh = sharding.Mesh(
np.array(lowering_devices).reshape(devices_shape), # type: ignore
axis_names)
try:
orig_jax_default_device = config.jax_default_device
config.update("jax_default_device", lowering_devices[0]) # For nullary functions
prev_get_and_check_device_assignment = pxla._get_and_check_device_assignment
pxla._get_and_check_device_assignment = partial(_get_and_check_device_assignment,
lowering_client)
with lowering_mesh:
if not hasattr(fun_jax, "lower"):
# We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes or shardings.
# TODO(necula): Will clean this when we clean the native lowering jax2tf API
fun_jax_lower = jax.jit(fun_jax).lower
else:
fun_jax_lower = fun_jax.lower
lowered = fun_jax_lower(*arg_specs)
return lowered
finally:
config.update("jax_default_device", orig_jax_default_device)
pxla._get_and_check_device_assignment = prev_get_and_check_device_assignment
class LoweringOnlyClient:
"""A Client that overrides the platform, for cross-platform lowering only."""
def __init__(self, platform: str, nr_devices: int):
self.platform = platform
self._process_index = 0
self.devices = [LoweringOnlyDevice(self, i) for i in range(nr_devices)]
self.lowering_only_client = True
def __str__(self):
return f"LoweringOnlyClient({self.platform})"
def process_index(self):
return self._process_index
def device_count(self):
return len(self.devices)
class LoweringOnlyDevice:
"""A Device that overrides the platform, for cross-platform lowering only."""
def __init__(self, client: LoweringOnlyClient, id: int):
self.client = client
self.process_index = client.process_index()
self.id = id
def __str__(self):
return f"LoweringOnlyDevice({self.platform}, id={self.id})"
# This is a copy of pxla._get_and_check_device_assignment, because we need
# to change its behavior for cross-platform lowering.
# The changes are marked below with "CHANGED:".
# This function reconciles the device assignment from shardings and from
# the mesh context. Some JAX primitives (xmap, shard_map) carry their own
# mesh of devices, instead of relying on the mesh context manager, which would
# conflict with the lowering-only devices. We must not only avoid raising
# errors in that case, but we must also pick the lowering devices.
def _get_and_check_device_assignment(
lowering_client: LoweringOnlyClient, # CHANGED: we pass the overriding client
shardings: Iterable[pxla.ShardingInfo],
devices: Optional[Sequence[xla_client.Device]],
) -> Tuple[xla_client.Client, Sequence[xla_client.Device]]:
from jax._src.api import local_devices
first_sharding_info = None
if devices is None:
devices = []
else:
devices = list(devices)
for i, s_type, source_info in shardings:
if pxla.is_auto(i) or pxla._is_unspecified(i):
continue
# Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
# skipped.
if first_sharding_info is None:
first_sharding_info = (list(i._device_assignment), s_type, source_info) # type: ignore
arr_device_assignment = list(i._device_assignment) # type: ignore
if not devices:
if first_sharding_info[0] != arr_device_assignment:
# CHANGED: do not error if the only difference is in lowering_only_client
if not all((d1.id == d2.id and
(hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
for d1, d2 in zip(first_sharding_info[0], arr_device_assignment)):
raise pxla.DeviceAssignmentMismatchError([
pxla.DeviceAssignmentMismatch(*first_sharding_info),
pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
else:
if devices != arr_device_assignment:
# CHANGED: do not error if the only difference is in lowering_only_client
if not all((d1.id == d2.id and
(hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
for d1, d2 in zip(devices, arr_device_assignment)):
raise pxla.DeviceAssignmentMismatchError([
pxla.DeviceAssignmentMismatch(devices, pxla.MismatchType.CONTEXT_DEVICES, None),
pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
if first_sharding_info is None and devices:
final_device_assignment = devices
elif first_sharding_info is None:
final_device_assignment = [config.jax_default_device or local_devices()[0]]
else:
final_device_assignment = first_sharding_info[0] # type: ignore
# CHANGED: override the device assignment
final_device_assignment = tuple(lowering_client.devices[d.id] for d in final_device_assignment) # type: ignore
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],

View File

@@ -1375,32 +1375,40 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f"See op.name = : {op.name}")
@parameterized.named_parameters(
dict(testcase_name=f"{'with_mesh_' if with_mesh else ''}{'nullary_' if nullary else ''}{transform}_pjit_sharding={pjit_sharding}",
with_mesh=with_mesh, transform=transform, nullary=nullary, pjit_sharding=pjit_sharding)
# The inner transformation to apply to the lowered function
for transform in ["base",
"jit",
"pjit", "pjit_in_shardings_None", "pjit_in_shardings_P", "pjit_in_shardings_Sharding",
"shard_map", "xmap", "pmap"]
# The sharding to be used for the outer pjit
for pjit_sharding in (
["unspecified"] if transform == "pmap" else
["unspecified", "none", "P", "Sharding"])
dict(testcase_name=(
f"{'with_mesh_' if with_mesh else ''}"
f"2={transform2 if transform2 != 'none' else ''}"
f"_1={transform1 if transform1 != 'none' else ''}"
f"{'_nullary' if nullary else ''}"),
with_mesh=with_mesh, transform1=transform1,
transform2=transform2, nullary=nullary)
# Test transform2(transform1(func))
for transform1 in [
"none",
"jit",
"pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
"pjit_in_shardings_Sharding",
"shard_map", "xmap", "pmap"]
for transform2 in (
["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
"pjit_in_shardings_Sharding"]
)
# Whether the function can be nullary
for nullary in (
[False] if (pjit_sharding != "unspecified") else
[True, False]
)
# To reduce the number of tests
[True, False] if transform2 == "none" else
[False])
# Whether we use a "with mesh"
for with_mesh in (
[True] if (transform not in ["base", "jit", "pjit"] or
pjit_sharding != "unspecified") else
[True] if (transform1 not in ["base", "jit", "pjit"] or
transform2 != "none") else
[False, True])
)
def test_cross_platform(self, with_mesh=False, transform="jit", nullary=False, pjit_sharding="unspecified"):
def test_cross_platform(self, with_mesh=True, transform1="xmap",
transform2="none", nullary=True):
# Tests cross-lowering for
# with mesh:
# pjit(transform(func), in_sharding=pjit_sharding)
# transform2(transform1(func))
if not config.jax_array:
raise unittest.SkipTest("cross_platform test work only with jax.Array")
if not config.jax_jit_pjit_api_merge:
@@ -1409,49 +1417,60 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
mesh = sharding.Mesh(jax.devices()[:1], ("a",))
# cummax has distinctive lowering for TPU, using a reduce-window op
func = lambda x: lax.cummax(x, axis=0, reverse=False)
# For shard_map we cannot use cummax :-( because it does not have a replication rule
# But we use lax.all_gather which on TPU is lowered with a all-gather op
# For shard_map we cannot use cummax :-( because it does not have a
# replication rule. But we use lax.all_gather which on TPU is lowered with
# an all-gather op
func_shard_map = lambda x: lax.all_gather(x, 'a', axis=1, tiled=True)
transformed_func = dict(
base=func,
jit=jax.jit(func),
jit_in_shardings_None=jax.jit(func, in_shardings=None),
jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),
jit_in_shardings_Sharding=jax.jit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
pjit=pjit.pjit(func),
pjit_in_shardings_None=pjit.pjit(func, in_shardings=None),
pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),)),
pjit_in_shardings_Sharding=pjit.pjit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
shard_map=(
shard_map(func_shard_map, mesh, in_specs=(P("a", None),), out_specs=P("a", None))),
xmap=xmap(func, in_axes=({0: 'axis'},), out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
pmap=jax.pmap(func, in_axes=0, out_axes=0),
)[transform]
pjit_transformed_func = dict(
unspecified=pjit.pjit(transformed_func),
none=pjit.pjit(transformed_func, in_shardings=None),
P=pjit.pjit(transformed_func, in_shardings=(P("a"),)),
Sharding=pjit.pjit(transformed_func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
)[pjit_sharding]
if pjit_sharding == "unspecified":
if transform == "xmap":
raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
def apply_transform(func, transform: str):
transformed_func = dict(
none=func,
jit=jax.jit(func),
jit_in_shardings_None=jax.jit(func, in_shardings=None), # type: ignore
jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)), # type: ignore
jit_in_shardings_Sharding=jax.jit(
func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)), # type: ignore
pjit=pjit.pjit(func),
pjit_in_shardings_None=pjit.pjit(func, in_shardings=None,
out_shardings=None),
pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),),
out_shardings=P("a")),
pjit_in_shardings_Sharding=pjit.pjit(
func,
in_shardings=(sharding.NamedSharding(mesh, P("a")),),
out_shardings=sharding.NamedSharding(mesh, P("a"))),
shard_map=(
shard_map(func, mesh, in_specs=(P("a", None),),
out_specs=P("a", None))),
xmap=xmap(func, in_axes=({0: 'axis'},),
out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
pmap=jax.pmap(func, in_axes=0, out_axes=0),
)[transform]
return transformed_func
transformed1_func = apply_transform(
(func_shard_map if transform1 == "shard_map" else func),
transform1)
assert transform2 not in ["xmap", "shard_map"]
transformed2_func = apply_transform(transformed1_func, transform2)
if transform1 == "xmap" and transform2 in ["pjit", "none"]:
raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
if transform1 == "pmap":
x = x.reshape((1, -1)) # Since we use 1 device
if not nullary:
func_to_convert = pjit_transformed_func
func_to_convert = transformed2_func
args = [x]
else:
func_to_convert = lambda: pjit_transformed_func(jnp.ones(x.shape, dtype=x.dtype))
func_to_convert = lambda: transformed2_func(jnp.ones(x.shape,
dtype=x.dtype))
args = []
if transform == "pmap":
if transform1 == "pmap":
if nullary:
raise unittest.SkipTest("Cannot lower nested pmap: jit-of-pmap warning")
raise unittest.SkipTest("TODO: pmap picks the devices from jax.devices() and will lower for CPU")
if transform == "xmap":
raise unittest.SkipTest("TODO: xmap does not pick up the overriden mesh and will lower for CPU")
raise unittest.SkipTest("TODO: figure out how to invoke pmap from TF")
f_tf = jax2tf.convert(func_to_convert,
experimental_native_lowering=True,
@@ -1464,10 +1483,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
_ = func_to_convert(*args)
tf_hlo = f_tf.experimental_get_compiler_ir(*args)(stage="hlo")
if transform == "shard_map":
self.assertIn("all-gather(f32[4,6]", tf_hlo)
if transform1 == "shard_map":
self.assertIn(" all-gather(f32[4,6]", tf_hlo)
else:
self.assertIn("reduce-window(f32[4,6]", tf_hlo)
self.assertIn(" reduce-window(", tf_hlo)
def get_serialized_computation(