From 9a424aabbd1b0dc5e28115e4d833226b3703053f Mon Sep 17 00:00:00 2001
From: George Necula
Date: Tue, 28 Feb 2023 11:30:23 +0100
Subject: [PATCH] [jax2tf] Clean up the support for cross-lowering.

In a previous CL we introduced cross-lowering support without any changes
in JAX core, but at the expense of some overly complex code in jax2tf,
along with overriding a JAX core function. Moreover, those changes were not
enough to handle some xmap and pmap cases.

Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform` all the
way to the calls to `mlir.lower_jaxpr_to_module`. That's it.

Note that this parameter to `.lower()` is experimental and not supposed to be
used outside jax2tf. Because it is passed by keyword, it may also gobble a
user kwarg of the same name.
---
 jax/_src/api.py                              |  15 +-
 jax/_src/dispatch.py                         |  20 ++-
 jax/_src/interpreters/pxla.py                |  20 ++-
 jax/_src/maps.py                             |  20 ++-
 jax/_src/pjit.py                             |  21 ++-
 jax/experimental/jax2tf/jax2tf.py            | 170 +++----------------
 jax/experimental/jax2tf/tests/jax2tf_test.py | 127 ++++++++------
 7 files changed, 158 insertions(+), 235 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 236eec28f..c10dfa36a 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -789,7 +789,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
     return aval, device

   @api_boundary
-  def lower(*args, **kwargs) -> stages.Lowered:
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs) -> stages.Lowered:
     """Lower this function for the given arguments.

     A lowered function is staged out of Python and translated to a
@@ -823,13 +824,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
     if jax.config.jax_array:
       computation = dispatch.sharded_lowering(
           flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-          keep_unused, *arg_specs_and_devices)
+          keep_unused, lowering_platform=_experimental_lowering_platform,
+          *arg_specs_and_devices)
       return stages.Lowered.from_flat_info(
           computation, in_tree, in_avals, donate_argnums, out_tree())
     else:
       computation = dispatch.lower_xla_callable(
           flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-          keep_unused, *arg_specs_and_devices)
+          keep_unused, lowering_platform=_experimental_lowering_platform,
+          *arg_specs_and_devices)
       return stages.Lowered.from_flat_info(
           computation, in_tree, in_avals, donate_argnums, out_tree())

@@ -2471,7 +2474,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
   # this might naturally be a method, with ``fun`` as a ``self`` and
   # all the other arguments stored as attributes.
   @api_boundary
-  def lower(*args, **kwargs) -> stages.Lowered:
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs) -> stages.Lowered:
     """Lower a parallel-mapped form of this function for the given arguments.

     A parallel-mapped and lowered function is staged out of Python and
@@ -2497,7 +2501,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
         donated_invars=p.donated_invars,
         global_arg_shapes=p.global_arg_shapes_flat,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
-        avals=abstract_args)
+        avals=abstract_args,
+        lowering_platform=_experimental_lowering_platform)
     return stages.Lowered.from_flat_info(
         computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())

diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index c840273ec..fb6021f40 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -326,7 +326,8 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):


 def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
-                     keep_unused, *arg_specs):
+                     keep_unused, *arg_specs,
+                     lowering_platform: Optional[str]):
   in_avals, in_shardings = util.unzip2(arg_specs)

   da = None
@@ -334,7 +335,7 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
     da, in_shardings = not_none_device_or_backend_on_jit(
         backend, device, len(in_shardings))

-  in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]
+  in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]  # type: ignore

   # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
   # the number of output avals at this stage. lower_sharding_computation will
@@ -342,19 +343,22 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
       in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      always_lower=always_lower, devices_from_context=da)
+      always_lower=always_lower, devices_from_context=da,
+      lowering_platform=lowering_platform)


 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
   if config.jax_array:
     computation = sharded_lowering(fun, device, backend, name, donated_invars,
-                                   False, keep_unused, *arg_specs)
+                                   False, keep_unused, *arg_specs,
+                                   lowering_platform=None)
     allow_prop = [True] * len(computation.compile_args['global_out_avals'])
     return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
   else:
     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
-                              keep_unused, *arg_specs).compile().unsafe_call
+                              keep_unused, *arg_specs,
+                              lowering_platform=None).compile().unsafe_call

 xla_callable = lu.cache(_xla_callable_uncached)

@@ -414,7 +418,8 @@ def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
 @profiler.annotate_function
 def lower_xla_callable(
     fun: lu.WrappedFun, device, backend, name, donated_invars,
-    always_lower: bool, keep_unused: bool, *arg_specs):
+    always_lower: bool, keep_unused: bool, *arg_specs,
+    lowering_platform: Optional[str]):
   """Lower into XLA.

   Args:
@@ -512,7 +517,8 @@ def lower_xla_callable(
         effects.ordered_effects.filter_in(closed_jaxpr.effects))
     lowering_result = mlir.lower_jaxpr_to_module(
         module_name, closed_jaxpr, unordered_effects,
-        ordered_effects, backend, backend.platform,
+        ordered_effects, backend,
+        lowering_platform or backend.platform,
         mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
     module, keepalive, host_callbacks = (
         lowering_result.module, lowering_result.keepalive,
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index f776580a3..13e39ea15 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1273,7 +1273,7 @@ def parallel_callable(fun: lu.WrappedFun,
   pmap_computation = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
-      is_explicit_global_axis_size, avals)
+      is_explicit_global_axis_size, avals, lowering_platform=None)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])

@@ -1397,7 +1397,9 @@ def lower_parallel_callable(
     donated_invars: Sequence[bool],
     global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
     is_explicit_global_axis_size: bool,
-    avals: Sequence[core.AbstractValue]):
+    avals: Sequence[core.AbstractValue],
+    *,
+    lowering_platform: Optional[str]):
   # Determine global_axis_size for use in AxisEnv.
   # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
   # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -1502,7 +1504,7 @@ def lower_parallel_callable(
           unordered_effects,
           ordered_effects,
           backend,
-          backend.platform,
+          lowering_platform or backend.platform,
           mlir.ReplicaAxisContext(axis_env),
           name_stack,
           donated_invars,
@@ -2876,10 +2878,12 @@ def lower_sharding_computation(
     out_shardings: Union[Sequence[Union[sharding_internal.XLACompatibleSharding,
                                         UnspecifiedValue]], UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
+    *,
     in_is_global: Sequence[bool],
     keep_unused: bool,
     always_lower: bool,
-    devices_from_context: Optional[Sequence[xc.Device]] = None
+    devices_from_context: Optional[Sequence[xc.Device]] = None,
+    lowering_platform: Optional[str],
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
@@ -3047,7 +3051,8 @@ def lower_sharding_computation(
         unordered_effects,
         ordered_effects,
         backend,
-        backend.platform,
+        # Optionally, override the lowering platform
+        lowering_platform or backend.platform,
         axis_ctx,
         name_stack,
         donated_invars,
@@ -3102,7 +3107,8 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: Optional[TilingMethod],
-    in_is_global: Sequence[bool]) -> MeshComputation:
+    in_is_global: Sequence[bool],
+    lowering_platform: Optional[str]) -> MeshComputation:
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
@@ -3221,7 +3227,7 @@ def lower_mesh_computation(
         unordered_effects,
         ordered_effects,
         backend,
-        backend.platform,
+        lowering_platform or backend.platform,
         axis_ctx,
         name_stack,
         donated_invars,
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index 436f9417c..9675febdd 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -603,7 +603,7 @@ def xmap(fun: Callable,
     return verify_outputs(out_flat, out_tree, params)

   @decorate_serial
-  def lower(*args):
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None):
     fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
     avals_flat = [shaped_abstractify(arg) for arg in args_flat]
     computation = make_xmap_callable(
@@ -611,12 +611,13 @@ def xmap(fun: Callable,
         params['donated_invars'], params['global_axis_sizes'],
         params['axis_resources'], params['resource_env'], params['backend'],
         params['spmd_in_axes'], params['spmd_out_axes_thunk'],
         params['in_positional_semantics'],
-        params['out_positional_semantics'], *avals_flat)
+        params['out_positional_semantics'],
+        _experimental_lowering_platform, *avals_flat)
     in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
     in_avals = in_tree.unflatten(avals_flat)
     return stages.Lowered.from_flat_info(
-        computation, in_tree, in_avals, donate_argnums, out_tree(),
+        computation, in_tree, in_avals, donate_argnums, out_tree(),  # type: ignore
         no_kwargs=True)

   fun_mapped.lower = lower
@@ -631,7 +632,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
       fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
       axis_resources, resource_env, backend,
       spmd_in_axes, spmd_out_axes_thunk,
       in_positional_semantics, out_positional_semantics,
-      *in_avals).compile().unsafe_call
+      None, *in_avals).compile().unsafe_call
   distributed_debug_log(("Running xmapped function", name),
                         ("python function", fun.f),
                         ("mesh", resource_env.physical_mesh),
@@ -644,7 +645,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
                        in_axes, out_axes_thunk, donated_invars,
                        global_axis_sizes, axis_resources, resource_env, backend,
                        spmd_in_axes, spmd_out_axes_thunk, in_positional_semantics,
-                       out_positional_semantics, *in_avals):
+                       out_positional_semantics,
+                       lowering_platform: Optional[str],
+                       *in_avals):
   plan = EvaluationPlan.from_axis_resources(
       axis_resources, resource_env, global_axis_sizes, in_positional_semantics)
@@ -702,16 +705,17 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh, in_shardings, out_shardings,
         donated_invars, use_spmd_lowering, global_in_avals,
-        tiling_method=tiling_method, in_is_global=in_is_global)
+        tiling_method=tiling_method, in_is_global=in_is_global,
+        lowering_platform=lowering_platform)
   else:
     if config.jax_array:
       return dispatch.sharded_lowering(
           f, None, backend, name, donated_invars, False, True,
-          *[(a, None) for a in in_avals])
+          *[(a, None) for a in in_avals], lowering_platform=lowering_platform)
     else:
       return dispatch.lower_xla_callable(
           f, None, backend, name, donated_invars, False, True,
-          *[(a, None) for a in in_avals])
+          *[(a, None) for a in in_avals], lowering_platform=lowering_platform)


 class EvaluationPlan(NamedTuple):
   """Encapsulates preprocessing common to top-level xmap invocations and its translation rule."""
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 8d481975a..c17831fd5 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -364,7 +364,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
   wrapped = _python_pjit(fun, infer_params_fn)

   @api_boundary
-  def lower(*args, **kwargs):
+  def lower(*args, _experimental_lowering_platform: Optional[str] = None,
+            **kwargs):
     (args_flat, flat_local_in_avals, params, in_tree, out_tree,
      donate_argnums) = infer_params_fn(*args, **kwargs)
     if jax.config.jax_array:
@@ -379,7 +380,8 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
     lowering = _pjit_lower(
         params['jaxpr'], in_shardings, params['out_shardings'],
         params['resource_env'], params['donated_invars'], params['name'],
-        in_is_global, params['keep_unused'], always_lower=True)
+        in_is_global, params['keep_unused'], always_lower=True,
+        lowering_platform=_experimental_lowering_platform)

     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -1289,7 +1291,7 @@ def _pjit_call_impl(*args, jaxpr,
   compiled = _pjit_lower(
       jaxpr, in_shardings, out_shardings, resource_env,
       donated_invars, name, in_is_global, keep_unused,
-      always_lower=False).compile(
+      always_lower=False, lowering_platform=None).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
   _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1385,7 +1387,9 @@ def _pjit_lower_cached(
     name: str,
     in_is_global: Sequence[bool],
     keep_unused: bool,
-    always_lower: bool):
+    always_lower: bool,
+    *,
+    lowering_platform: Optional[str]):
   in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
       Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
   out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
@@ -1431,7 +1435,8 @@ def _pjit_lower_cached(
     return pxla.lower_mesh_computation(
         fun, api_name, name, mesh, in_shardings, out_shardings,
         donated_invars,
-        True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
+        True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global,
+        lowering_platform=lowering_platform)
   else:
     # Pass `in_is_global` here because this path is taken by both host local
     # avals and global avals.
@@ -1442,7 +1447,8 @@ def _pjit_lower_cached(
       jaxpr.in_avals, in_is_global=in_is_global, keep_unused=keep_unused,
       always_lower=always_lower,
       devices_from_context=(
-          None if mesh is None or mesh.empty else list(mesh.devices.flat)))
+          None if mesh is None or mesh.empty else list(mesh.devices.flat)),
+      lowering_platform=lowering_platform)


 def pjit_staging_rule(trace, *args, **params):
@@ -1657,7 +1663,8 @@ def _pjit_partial_eval(trace, *in_tracers,
       known_params["jaxpr"], known_params["in_shardings"],
       known_params["out_shardings"], known_params["resource_env"],
       known_params["donated_invars"], known_params["name"],
-      in_is_global, known_params['keep_unused'], always_lower=False).compile(
+      in_is_global, known_params['keep_unused'], always_lower=False,
+      lowering_platform=None).compile(
          _allow_propagation_to_outputs=[True] * len(known_params['out_shardings']),
          _allow_compile_replicated=False)
   da = compiled._device_assignment
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index cfe572cb5..c2a2759fc 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -497,9 +497,12 @@ def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
   # preserve the lowering function. This will be used in the _lower_native_and_run.
   # We rely on the fact that the lowering is the same for the function
   # taking pytrees, and the one taking flat args.
-  def fun_flat_jax_lower(*args_flat_jax):
+  def fun_flat_jax_lower(*args_flat_jax, _experimental_lowering_platform):
     tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
-    lowered = fun_jax.lower(*tree_args, **tree_kwargs)
+    lowered = fun_jax.lower(
+        *tree_args,
+        _experimental_lowering_platform=_experimental_lowering_platform,
+        **tree_kwargs)
     out_tree = lowered.out_tree
     nonlocal out_tree_ref
     assert out_tree_ref is None or out_tree_ref == out_tree
@@ -678,23 +681,24 @@ def _lower_native_and_run(fun_jax: Callable,
   ]

   if lowering_params.experimental_native_lowering_platforms:
-    lowered = cross_platform_lowering(
-        fun_jax, arg_specs_jax,  # type: ignore[arg-type]
-        platforms=lowering_params.experimental_native_lowering_platforms
-    )._lowering  # type: ignore
+    lowering_platform = lowering_params.experimental_native_lowering_platforms[0]
   else:
-    if not hasattr(fun_jax, "lower") or abstracted_axes:
-      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
-      # convert(f_jax), in which case a "jit" is implied. We also add a jit when
-      # we need to pass the abstracted axes.
-      # TODO(necula): Will clean this when we clean the native lowering jax2tf API
-      fun_jax_lower = jax.jit(fun_jax,
-                              abstracted_axes=abstracted_axes).lower
-    else:
-      # If we have a pjit or pmap already we do not wrap with another
-      fun_jax_lower = fun_jax.lower
+    lowering_platform = None

-    lowered = fun_jax_lower(*arg_specs_jax)._lowering  # type: ignore
+  if not hasattr(fun_jax, "lower") or abstracted_axes:
+    # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
+    # convert(f_jax), in which case a "jit" is implied. We also add a jit when
+    # we need to pass the abstracted axes.
+    # TODO(necula): Will clean this when we clean the native lowering jax2tf API
+    fun_jax_lower = jax.jit(fun_jax,
+                            abstracted_axes=abstracted_axes).lower
+  else:
+    # If we have a pjit or pmap already we do not wrap with another
+    fun_jax_lower = fun_jax.lower
+
+  lowered = fun_jax_lower(
+      *arg_specs_jax,
+      _experimental_lowering_platform=lowering_platform)._lowering  # type: ignore

   if config.jax2tf_use_stablehlo:
     mlir_module = lowered.stablehlo()
@@ -708,6 +712,8 @@ def _lower_native_and_run(fun_jax: Callable,
   if "global_out_avals" in lowered.compile_args:
     # This is currently the case for pjit
     out_avals = lowered.compile_args["global_out_avals"]
+  elif "shards" in lowered.compile_args:  # for PmapComputation
+    out_avals = lowered.compile_args["shards"].out_sharded_avals
   else:
     out_avals = lowered.compile_args["out_avals"]
   if lowered.compile_args["host_callbacks"]:
@@ -834,136 +840,6 @@ def _lower_native_and_run(fun_jax: Callable,
                 for res_val, out_aval in zip(res, out_avals))
   return res, out_avals

-def cross_platform_lowering(fun_jax, arg_specs: Sequence[jax.Array],
-                            *,
-                            platforms: Sequence[str] = ()):
-
-  context_mesh = pxla.thread_resources.env.physical_mesh
-  if not context_mesh.empty:
-    # What devices we need
-    if context_mesh.is_multi_process:
-      raise NotImplementedError("cross_platform lowering is not supported for multi-host lowering")
-    devices = np.array(context_mesh.devices).reshape((-1,))
-    devices_shape = np.shape(context_mesh.devices)
-    axis_names = context_mesh.axis_names
-  else:
-    devices = [config.jax_default_device or jax.local_devices()[0]]  # type: ignore
-    devices_shape = (1,)
-    axis_names = ("_no_axis",)
-
-  lowering_client = LoweringOnlyClient(platforms[0],
-                                       1 + max(d.id for d in devices))
-  lowering_devices = [lowering_client.devices[d.id] for d in devices]
-  lowering_mesh = sharding.Mesh(
-      np.array(lowering_devices).reshape(devices_shape),  # type: ignore
-      axis_names)
-
-  try:
-    orig_jax_default_device = config.jax_default_device
-    config.update("jax_default_device", lowering_devices[0])  # For nullary functions
-    prev_get_and_check_device_assignment = pxla._get_and_check_device_assignment
-    pxla._get_and_check_device_assignment = partial(_get_and_check_device_assignment,
-                                                    lowering_client)
-    with lowering_mesh:
-      if not hasattr(fun_jax, "lower"):
-        # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
-        # convert(f_jax), in which case a "jit" is implied. We also add a jit when
-        # we need to pass the abstracted axes or shardings.
-        # TODO(necula): Will clean this when we clean the native lowering jax2tf API
-        fun_jax_lower = jax.jit(fun_jax).lower
-      else:
-        fun_jax_lower = fun_jax.lower
-      lowered = fun_jax_lower(*arg_specs)
-      return lowered
-  finally:
-    config.update("jax_default_device", orig_jax_default_device)
-    pxla._get_and_check_device_assignment = prev_get_and_check_device_assignment
-
-class LoweringOnlyClient:
-  """A Client that overrides the platform, for cross-platform lowering only."""
-  def __init__(self, platform: str, nr_devices: int):
-    self.platform = platform
-    self._process_index = 0
-    self.devices = [LoweringOnlyDevice(self, i) for i in range(nr_devices)]
-    self.lowering_only_client = True
-
-  def __str__(self):
-    return f"LoweringOnlyClient({self.platform})"
-
-  def process_index(self):
-    return self._process_index
-
-  def device_count(self):
-    return len(self.devices)
-
-class LoweringOnlyDevice:
-  """A Device that overrides the platform, for cross-platform lowering only."""
-  def __init__(self, client: LoweringOnlyClient, id: int):
-    self.client = client
-    self.process_index = client.process_index()
-    self.id = id
-
-  def __str__(self):
-    return f"LoweringOnlyDevice({self.platform}, id={self.id})"
-
-
-# This is a copy of pxla._get_and_check_device_assignment, because we need
-# to change its behavior for cross-platform lowering.
-# The changes are marked below with "CHANGED:".
-# This function reconciles the device assignment from shardings and from
-# the mesh context. Some JAX primitives (xmap, shard_map) carry their own
-# mesh of devices, instead of relying on the mesh context manager, which would
-# conflict with the lowering-only devices. We must now only avoid raising
-# errors in the case, but we must also pick the lowering devices.
-def _get_and_check_device_assignment(
-    lowering_client: LoweringOnlyClient,  # CHANGED: we pass the overriding client
-    shardings: Iterable[pxla.ShardingInfo],
-    devices: Optional[Sequence[xla_client.Device]],
-) -> Tuple[xla_client.Client, Sequence[xla_client.Device]]:
-  from jax._src.api import local_devices
-
-  first_sharding_info = None
-  if devices is None:
-    devices = []
-  else:
-    devices = list(devices)
-
-  for i, s_type, source_info in shardings:
-    if pxla.is_auto(i) or pxla._is_unspecified(i):
-      continue
-    # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
-    # skipped.
-    if first_sharding_info is None:
-      first_sharding_info = (list(i._device_assignment), s_type, source_info)  # type: ignore
-    arr_device_assignment = list(i._device_assignment)  # type: ignore
-    if not devices:
-      if first_sharding_info[0] != arr_device_assignment:
-        # CHANGED: do not error if the only difference is in lowering_only_client
-        if not all((d1.id == d2.id and
-                    (hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
-                   for d1, d2 in zip(first_sharding_info[0], arr_device_assignment)):
-          raise pxla.DeviceAssignmentMismatchError([
-              pxla.DeviceAssignmentMismatch(*first_sharding_info),
-              pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
-    else:
-      if devices != arr_device_assignment:
-        # CHANGED: do not error if the only difference is in lowering_only_client
-        if not all((d1.id == d2.id and
-                    (hasattr(d1.client, "lowering_only_client") or hasattr(d2.client, "lowering_only_client")))
-                   for d1, d2 in zip(devices, arr_device_assignment)):
-          raise pxla.DeviceAssignmentMismatchError([
-              pxla.DeviceAssignmentMismatch(devices, pxla.MismatchType.CONTEXT_DEVICES, None),
-              pxla.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
-  if first_sharding_info is None and devices:
-    final_device_assignment = devices
-  elif first_sharding_info is None:
-    final_device_assignment = [config.jax_default_device or local_devices()[0]]
-  else:
-    final_device_assignment = first_sharding_info[0]  # type: ignore
-
-  # CHANGED: override the device assignment
-  final_device_assignment = tuple(lowering_client.devices[d.id] for d in final_device_assignment)  # type: ignore
-  return xb.get_device_backend(final_device_assignment[0]), final_device_assignment

 def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
                                           in_vals: Sequence[TfVal],
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 82cf7f7bc..ba17ae95d 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -1375,32 +1375,40 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
               f"See op.name = : {op.name}")

   @parameterized.named_parameters(
-      dict(testcase_name=f"{'with_mesh_' if with_mesh else ''}{'nullary_' if nullary else ''}{transform}_pjit_sharding={pjit_sharding}",
-           with_mesh=with_mesh, transform=transform, nullary=nullary, pjit_sharding=pjit_sharding)
-      # The inner transformation to apply to the lowered function
-      for transform in ["base",
-                        "jit",
-                        "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P", "pjit_in_shardings_Sharding",
-                        "shard_map", "xmap", "pmap"]
-      # The sharding to be used for the outer pjit
-      for pjit_sharding in (
-          ["unspecified"] if transform == "pmap" else
-          ["unspecified", "none", "P", "Sharding"])
+      dict(testcase_name=(
+          f"{'with_mesh_' if with_mesh else ''}"
+          f"2={transform2 if transform2 != 'none' else ''}"
+          f"_1={transform1 if transform1 != 'none' else ''}"
+          f"{'_nullary' if nullary else ''}"),
+           with_mesh=with_mesh, transform1=transform1,
+           transform2=transform2, nullary=nullary)
+      # Test transform2(transform1(func))
+      for transform1 in [
+          "none",
+          "jit",
+          "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
+          "pjit_in_shardings_Sharding",
+          "shard_map", "xmap", "pmap"]
+      for transform2 in (
+          ["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
+           "pjit_in_shardings_Sharding"]
+      )
      # Whether the function can be nullary
      for nullary in (
-          [False] if (pjit_sharding != "unspecified") else
-          [True, False]
-      )
+          # To reduce the number of tests
+          [True, False] if transform2 == "none" else
+          [False])
      # Whether we use a "with mesh"
      for with_mesh in (
-          [True] if (transform not in ["base", "jit", "pjit"] or
-                     pjit_sharding != "unspecified") else
+          [True] if (transform1 not in ["base", "jit", "pjit"] or
+                     transform2 != "none") else
          [False, True])
  )
-  def test_cross_platform(self, with_mesh=False, transform="jit", nullary=False, pjit_sharding="unspecified"):
+  def test_cross_platform(self, with_mesh=True, transform1="xmap",
+                          transform2="none", nullary=True):
     # Tests cross-lowering for
     #   with mesh:
-    #     pjit(transform(func), in_sharding=pjit_sharding)
+    #     transform2(transform1(func))
     if not config.jax_array:
       raise unittest.SkipTest("cross_platform test work only with jax.Array")
     if not config.jax_jit_pjit_api_merge:
@@ -1409,49 +1417,60 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     mesh = sharding.Mesh(jax.devices()[:1], ("a",))
     # cummax has distinctive lowering for TPU, using a reduce-window op
     func = lambda x: lax.cummax(x, axis=0, reverse=False)
-    # For shard_map we cannot use cummax :-( because it does not have a replication rule
-    # But we use lax.all_gather which on TPU is lowered with a all-gather op
+    # For shard_map we cannot use cummax :-( because it does not have a
+    # replication rule. But we use lax.all_gather which on TPU is lowered with
+    # an all-gather op
     func_shard_map = lambda x: lax.all_gather(x, 'a', axis=1, tiled=True)

-    transformed_func = dict(
-        base=func,
-        jit=jax.jit(func),
-        jit_in_shardings_None=jax.jit(func, in_shardings=None),
-        jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),
-        jit_in_shardings_Sharding=jax.jit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-        pjit=pjit.pjit(func),
-        pjit_in_shardings_None=pjit.pjit(func, in_shardings=None),
-        pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),)),
-        pjit_in_shardings_Sharding=pjit.pjit(func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-        shard_map=(
-            shard_map(func_shard_map, mesh, in_specs=(P("a", None),), out_specs=P("a", None))),
-        xmap=xmap(func, in_axes=({0: 'axis'},), out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
-        pmap=jax.pmap(func, in_axes=0, out_axes=0),
-    )[transform]
-    pjit_transformed_func = dict(
-        unspecified=pjit.pjit(transformed_func),
-        none=pjit.pjit(transformed_func, in_shardings=None),
-        P=pjit.pjit(transformed_func, in_shardings=(P("a"),)),
-        Sharding=pjit.pjit(transformed_func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),
-    )[pjit_sharding]
-    if pjit_sharding == "unspecified":
-      if transform == "xmap":
-        raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
+    def apply_transform(func, transform: str):
+      transformed_func = dict(
+          none=func,
+          jit=jax.jit(func),
+          jit_in_shardings_None=jax.jit(func, in_shardings=None),  # type: ignore
+          jit_in_shardings_P=jax.jit(func, in_shardings=(P("a"),)),  # type: ignore
+          jit_in_shardings_Sharding=jax.jit(
+              func, in_shardings=(sharding.NamedSharding(mesh, P("a")),)),  # type: ignore
+          pjit=pjit.pjit(func),
+          pjit_in_shardings_None=pjit.pjit(func, in_shardings=None,
+                                           out_shardings=None),
+          pjit_in_shardings_P=pjit.pjit(func, in_shardings=(P("a"),),
+                                        out_shardings=P("a")),
+          pjit_in_shardings_Sharding=pjit.pjit(
+              func,
+              in_shardings=(sharding.NamedSharding(mesh, P("a")),),
+              out_shardings=sharding.NamedSharding(mesh, P("a"))),
+          shard_map=(
+              shard_map(func, mesh, in_specs=(P("a", None),),
+                        out_specs=P("a", None))),
+          xmap=xmap(func, in_axes=({0: 'axis'},),
+                    out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
+          pmap=jax.pmap(func, in_axes=0, out_axes=0),
+      )[transform]
+      return transformed_func
+
+    transformed1_func = apply_transform(
+        (func_shard_map if transform1 == "shard_map" else func),
+        transform1)
+    assert transform2 not in ["xmap", "shard_map"]
+    transformed2_func = apply_transform(transformed1_func, transform2)
+
+    if transform1 == "xmap" and transform2 in ["pjit", "none"]:
+      raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
+
+    if transform1 == "pmap":
+      x = x.reshape((1, -1))  # Since we use 1 device
     if not nullary:
-      func_to_convert = pjit_transformed_func
+      func_to_convert = transformed2_func
       args = [x]
     else:
-      func_to_convert = lambda: pjit_transformed_func(jnp.ones(x.shape, dtype=x.dtype))
+      func_to_convert = lambda: transformed2_func(jnp.ones(x.shape,
+                                                           dtype=x.dtype))
       args = []

-    if transform == "pmap":
+    if transform1 == "pmap":
       if nullary:
         raise unittest.SkipTest("Cannot lower nested pmap: jit-of-pmap warning")
-      raise unittest.SkipTest("TODO: pmap picks the devices from jax.devices() and will lower for CPU")
-
-    if transform == "xmap":
-      raise unittest.SkipTest("TODO: xmap does not pick up the overriden mesh and will lower for CPU")
+      raise unittest.SkipTest("TODO: figure out how to invoke pmap from TF")

     f_tf = jax2tf.convert(func_to_convert,
                           experimental_native_lowering=True,
@@ -1464,10 +1483,10 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     _ = func_to_convert(*args)
     tf_hlo = f_tf.experimental_get_compiler_ir(*args)(stage="hlo")
-    if transform == "shard_map":
-      self.assertIn("all-gather(f32[4,6]", tf_hlo)
+    if transform1 == "shard_map":
+      self.assertIn(" all-gather(f32[4,6]", tf_hlo)
     else:
-      self.assertIn("reduce-window(f32[4,6]", tf_hlo)
+      self.assertIn(" reduce-window(", tf_hlo)


 def get_serialized_computation(