diff --git a/jax/_src/api.py b/jax/_src/api.py index 41c230c7f..a3d99eda3 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1811,9 +1811,41 @@ def _cpp_pmap( pmap_f = wraps(fun)(cpp_mapped_f) + @api_boundary + def specialize(*args, **kwargs): + lowering_parameters = kwargs.pop( + '_experimental_lowering_parameters', mlir.LoweringParameters()) + p = _prepare_pmap( + fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, + devices, backend, axis_size, args, kwargs) + abstract_args = list(map(shaped_abstractify, p.flat_args)) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, backend, axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, + lowering_parameters=lowering_parameters) + jaxpr, _, _, _, _ = pxla.get_pmap_jaxpr( + p.flat_fun, backend, axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) + return stages.Specialized(jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) + pmap_f.lower = _pmap_lower( fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, backend, axis_size, donate_tuple) + pmap_f.specialize = specialize return pmap_f @@ -1845,7 +1877,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, devices, backend, axis_size, args, kwargs) abstract_args = list(map(shaped_abstractify, p.flat_args)) - computation, closed_jaxpr = pxla.lower_parallel_callable( + computation = pxla.lower_parallel_callable( p.flat_fun, backend, axis_name, axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, devices=p.devices, @@ -1857,8 +1889,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, avals=abstract_args, lowering_parameters=lowering_parameters) return stages.Lowered.from_flat_info( - computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(), - fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr) + computation, p.in_tree, abstract_args, donate_tuple, p.out_tree()) return lower diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 04cdf8e5d..686ee0ced 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -557,7 +557,7 @@ def parallel_callable(fun: lu.WrappedFun, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, *avals): - pmap_computation, _ = lower_parallel_callable( + pmap_computation = lower_parallel_callable( fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size, avals, @@ -665,6 +665,31 @@ def stage_parallel_callable( return jaxpr, consts, replicas, shards +def get_pmap_jaxpr( + fun: lu.WrappedFun, + backend_name: str | None, + axis_name: core.AxisName, + axis_size: int, + global_axis_size: int, + devices: Sequence[xc.Device] | None, + name: str, + in_axes: Iterable[int | None], + out_axes_thunk: Callable[[], Sequence[int | None]], + avals: Sequence[core.AbstractValue]): + if devices is not None and backend_name is None: + backend = 
xb.get_device_backend(devices[0]) + else: + backend = xb.get_backend(backend_name) + + pci = ParallelCallableInfo( + name, backend, axis_name, axis_size, global_axis_size, devices, + in_axes, out_axes_thunk, avals) + jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) + jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + return closed_jaxpr, backend, replicas, shards, pci + + @profiler.annotate_function def lower_parallel_callable( fun: lu.WrappedFun, @@ -680,7 +705,7 @@ def lower_parallel_callable( is_explicit_global_axis_size: bool, avals: Sequence[core.AbstractValue], *, - lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]: + lowering_parameters: mlir.LoweringParameters) -> PmapComputation: # Determine global_axis_size for use in AxisEnv. # TODO(mattjj,skyewm): revive this check (inner_pmap always False now) # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: @@ -691,10 +716,10 @@ def lower_parallel_callable( f"Specified axis_size {global_axis_size} doesn't match received " f"axis_size {axis_size}.") - if devices is not None and backend_name is None: - backend = xb.get_device_backend(devices[0]) - else: - backend = xb.get_backend(backend_name) + closed_jaxpr, backend, replicas, shards, pci = get_pmap_jaxpr( + fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, + in_axes, out_axes_thunk, avals) + jaxpr = closed_jaxpr.jaxpr no_nested_sharding = False must_run_on_all_devices = False @@ -711,10 +736,6 @@ def lower_parallel_callable( # devices). Nested sharding is ok in this case. must_run_on_all_devices = True - pci = ParallelCallableInfo( - name, backend, axis_name, axis_size, global_axis_size, devices, - in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun) if logger.isEnabledFor(logging.DEBUG): logger.debug("sharded_avals: %s", shards.sharded_avals) logger.debug("global_sharded_avals: %s", shards.global_sharded_avals) @@ -756,8 +777,6 @@ def lower_parallel_callable( axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) - jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), backend.platform) @@ -798,7 +817,7 @@ def lower_parallel_callable( keepalive=lowering_result.keepalive, host_callbacks=lowering_result.host_callbacks, jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info, - shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr + shape_poly_state=lowering_result.shape_poly_state) def _pmap_unmap_shaped_array( diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 5487f5699..77896d7f9 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -617,7 +617,7 @@ def xmap(fun: Callable, '_experimental_lowering_platform', mlir.LoweringParameters()) fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args) avals_flat = [shaped_abstractify(arg) for arg in args_flat] - computation, jaxpr = make_xmap_callable( + computation = make_xmap_callable( fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'], params['donated_invars'], params['global_axis_sizes'], params['axis_resources'], params['resource_env'], params['backend'], params['spmd_in_axes'], @@ -627,8 +627,7 @@ def 
xmap(fun: Callable, in_tree = treedef_tuple([in_tree, tree_flatten({})[1]]) in_avals = in_tree.unflatten(avals_flat) return stages.Lowered.from_flat_info( - computation, in_tree, in_avals, donate_argnums, out_tree(), - no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr) + computation, in_tree, in_avals, donate_argnums, out_tree()) fun_mapped.lower = lower return type_cast(stages.Wrapped, fun_mapped) @@ -637,7 +636,7 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_ global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk): in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args] - computation, _ = make_xmap_callable( + computation = make_xmap_callable( fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, @@ -709,7 +708,7 @@ def make_xmap_callable(fun: lu.WrappedFun, in_shardings, out_shardings, donated_invars, use_spmd_lowering, in_avals, tiling_method=tiling_method, - lowering_parameters=lowering_parameters), jaxpr + lowering_parameters=lowering_parameters) else: jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals) return pxla.lower_sharding_computation( @@ -717,7 +716,7 @@ def make_xmap_callable(fun: lu.WrappedFun, (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals), (None,) * len(in_avals), (None,) * len(out_avals), donated_invars, keep_unused=True, inline=False, - devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr + devices_from_context=None, lowering_parameters=lowering_parameters) class EvaluationPlan(NamedTuple): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 504048339..4a3d8400a 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -498,8 +498,7 @@ def _make_jit_wrapper(jit_info: PjitInfo): donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) jaxpr = params["jaxpr"] return stages.Lowered.from_flat_info( - lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree, - fun_name=params["name"], jaxpr=jaxpr) + lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree) @api_boundary def eval_shape(*args, **kwargs): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 81e2edfd7..ba71ca655 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -643,30 +643,23 @@ class Lowered(Stage): querying properties of lowered computations across JAX's various lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ - __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"] + __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] _lowering: XlaLowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef _no_kwargs: bool - _fun_name: str - _jaxpr: core.ClosedJaxpr | None # Can be None when this class is constructed - # outside of JAX core. 
def __init__( self, lowering: XlaLowering, args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False, - fun_name: str = "", - jaxpr: core.ClosedJaxpr | None = None): + no_kwargs: bool = False): self._lowering = lowering self.args_info = args_info self.out_tree = out_tree self._no_kwargs = no_kwargs - self._fun_name = fun_name - self._jaxpr = jaxpr @classmethod def from_flat_info(cls, @@ -675,9 +668,7 @@ class Lowered(Stage): in_avals, donate_argnums: tuple[int, ...], out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False, - fun_name: str = "", - jaxpr: core.ClosedJaxpr | None = None): + no_kwargs: bool = False): """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef. Args: @@ -686,15 +677,12 @@ class Lowered(Stage): no_kwargs: If ``True`` the transformation, and the ``Compiled`` returned from this object will not support keyword arguments (an error will be raised if some are provided). - fun_name: the name of the lowered function. - jaxpr: the Jaxpr of the lowered function. The value `None` is for - backwards compatibility, and is used only outside JAX core. """ return cls( lowering, make_args_info(in_tree, in_avals, donate_argnums), out_tree, - no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr) + no_kwargs=no_kwargs) @property def out_info(self): # PyTree of OutInfo diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py index b87399059..774d5c830 100644 --- a/jax/experimental/export/_export.py +++ b/jax/experimental/export/_export.py @@ -433,15 +433,17 @@ def export(fun_jax: Callable, """ def do_export(*args_specs, **kwargs_specs) -> Exported: - if not hasattr(fun_jax, "lower"): + if hasattr(fun_jax, "lower"): + # If we have a pjit or pmap already we do not wrap with another, and we + # allow shardings. + wrapped_fun_jax = fun_jax + else: # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also # convert(f_jax), in which case a "jit" is implied. In that case we raise # an error if the lowered function contains non-replicated sharding annotations. wrapped_fun_jax = jax.jit(fun_jax) - else: - # If we have a pjit or pmap already we do not wrap with another, and we - # allow shardings. 
- wrapped_fun_jax = fun_jax # type: ignore + + has_specialize = hasattr(wrapped_fun_jax, "specialize") if lowering_platforms is not None: actual_lowering_platforms = tuple(lowering_platforms) @@ -464,19 +466,32 @@ def export(fun_jax: Callable, self_descr=f"current (from {shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", other_descr=shape_poly.args_kwargs_path_to_str(k_path)) - lowered = wrapped_fun_jax.lower( - *args_specs, **kwargs_specs, - _experimental_lowering_parameters=mlir.LoweringParameters( + if has_specialize: + specialized = wrapped_fun_jax.specialize( + *args_specs, **kwargs_specs, + _experimental_lowering_parameters=mlir.LoweringParameters( platforms=actual_lowering_platforms, for_export=True, - )) + )) + jaxpr, fun_name = specialized.jaxpr, specialized.fun_name + lowered = specialized.lower() + else: + lowered = wrapped_fun_jax.lower( + *args_specs, **kwargs_specs, + _experimental_lowering_parameters=mlir.LoweringParameters( + platforms=actual_lowering_platforms, + for_export=True, + )) + jaxpr, fun_name = None, util.fun_name(wrapped_fun_jax) return _export_lowered( - lowered, disabled_checks=disabled_checks, + lowered, jaxpr, fun_name, + disabled_checks=disabled_checks, _device_assignment_for_internal_jax2tf_use_only=_device_assignment_for_internal_jax2tf_use_only) return do_export def _export_lowered( lowered: stages.Lowered, + jaxpr: core.ClosedJaxpr, fun_name: str, disabled_checks: Sequence[DisabledSafetyCheck] = (), _device_assignment_for_internal_jax2tf_use_only = None, ) -> Exported: @@ -563,8 +578,8 @@ def _export_lowered( def _get_exported_vjp(exp_primal: Exported) -> Exported: # Turn the primal jaxpr into a function, in preparation for exporting # the VJP. Note that jaxpr_as_fun produces a function with flat arguments - assert(lowered._jaxpr is not None) # None only when the lowered was created outside JAX - fun_jax = core.jaxpr_as_fun(lowered._jaxpr) + assert(jaxpr is not None) # None only when the lowered was created outside JAX + fun_jax = core.jaxpr_as_fun(jaxpr) fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax, in_tree=exp_primal.in_tree, @@ -580,7 +595,7 @@ def _export_lowered( disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) return Exported( - fun_name=lowered._fun_name, + fun_name=fun_name, in_tree=lowered.in_tree, out_tree=lowered.out_tree, in_avals=tuple(args_avals_flat), diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index fb36b95c9..99d1e1c84 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -34,7 +34,6 @@ import jax from jax._src import compiler from jax._src import config from jax._src import maps -from jax._src.maps import xmap from jax._src import test_util as jtu from jax._src import xla_bridge from jax import lax @@ -56,7 +55,6 @@ config.parse_flags_with_absl() from jax.experimental.jax2tf.tests import tf_test_util prev_xla_flags = None -prev_spmd_lowering_flag = None topology = None @@ -78,9 +76,6 @@ def setUpModule(): " --xla_force_host_platform_device_count=8") # Clear any cached backends so new CPU backend will pick up the env var. 
xla_bridge.get_backend.cache_clear() - global prev_spmd_lowering_flag - prev_spmd_lowering_flag = maps.SPMD_LOWERING.value - config.update('experimental_xmap_spmd_lowering', True) def tearDownModule(): @@ -89,7 +84,6 @@ def tearDownModule(): else: os.environ["XLA_FLAGS"] = prev_xla_flags xla_bridge.get_backend.cache_clear() - config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag) class ShardingTest(tf_test_util.JaxToTfTestCase): @@ -536,115 +530,6 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): "function with sharded arguments or results must be used under a `tf.function` context"): jax2tf.convert(f_jax)(a) - def test_xmap_basic(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - - # f_jax: f32[16,8,5], f32[2,7] -> f32[16,8,10], f32[2,28] - # lambda ...: f32[5], f32[7] -> f32[10], f32[28] - f_jax = xmap(lambda a, b: (jnp.concatenate([a, a], axis=0) * 2., - jnp.concatenate([b, b, b, b], axis=0) * 4.), - in_axes=({0: 'a', 1: 'b'}, ['c', ...]), - out_axes=({0: 'a', 1: 'b'}, ['c', ...]), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - # xmap works only with native serialization - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, (jnp.concatenate([a, a], axis=2) * 2., - jnp.concatenate([b, b, b, b], axis=1) * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - # The output sharding - (r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - (r"f32\[2,28\].*custom_call_target.*Sharding.*sharding.*replicated", 1), - ]) - - def test_xmap_collective_reduce(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - bshape = (2, 7) - b = np.arange(np.prod(bshape), dtype=np.float32).reshape(bshape) - f_jax = xmap(lambda a, b: (lax.psum(a * 2., 'a'), b * 4.), - in_axes=(['a', 'b', ...], {0: 'c'}), - out_axes=(['b', ...], {0: 'c'}), - axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) - - @tf.function(autograph=False, jit_compile=True) - def f_tf(a, b): - f_converted = jax2tf.convert(f_jax, native_serialization=True) - if jtu.test_device_matches(["tpu"]): - res = tf.compat.v1.tpu.rewrite( - f_converted, [tf.convert_to_tensor(a), tf.convert_to_tensor(b)], - device_assignment=self.device_assignment( - computation_shape=[1, 1, 1, 2]) - ) - return (res[0], res[1]) - else: - return f_converted(a, b) - - with Mesh(devices, ('x', 'y')): - res_jax = f_jax(a, b) - self.assertAllClose(res_jax, ((a * 2.).sum(0), b * 4.)) - res_tf = f_tf(a, b) - self.assertAllClose(res_tf, res_jax) - self.check_sharding( - jax2tf.convert(f_jax, native_serialization=True), [a, b], - checks=[ - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", 1), - 
(r"f32\[2,7\].*custom_call_target.*Sharding.*sharding.*replicated", 2), - (r"f32\[8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), - ]) - - def test_grad_xmap(self): - devices = np.reshape(self.devices, (1, 2)) - ashape = (16, 8, 5) - a = np.arange(np.prod(ashape), dtype=np.float32).reshape(ashape) - - # f_jax: f32[16,8,5]-> f32[16,8,10] - # lambda ...: f32[5]-> f32[10] - f_jax = xmap(lambda a: jnp.concatenate([a, a], axis=0) * 2., - in_axes=({0: 'a', 1: 'b'}), - out_axes={0: 'a', 1: 'b'}, - axis_resources={'a': 'x', 'b': 'y'}) - - def f_grad_tf(a, res_ct): - with tf.GradientTape(persistent=True) as tape: - tape.watch(a) - res_tf = jax2tf.convert(f_jax, native_serialization=True)(a) - return tape.gradient(res_tf, a, output_gradients=res_ct) - - with Mesh(devices, ('x', 'y')): - self.check_sharding(f_grad_tf, [a, np.concatenate([a, a], axis=2)], - checks=[ - # Primal input and grad output - (r"f32\[16,8,5\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(2)), - # Input cotangent - (r"f32\[16,8,10\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2,1\]", self.GEQ(1)), - ]) - @jtu.ignore_warning(category=UserWarning, message="all_to_all .* are only implemented properly for TPUs and GPUs .*") def test_shmap_all_to_all(self): diff --git a/tests/export_test.py b/tests/export_test.py index d1bc17fe7..ed42e36d0 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -556,12 +556,16 @@ class JaxExportTest(jtu.JaxTestCase): return jnp.sin(x) # This makes it look like a jitted-function - def lower(self, x, - _experimental_lowering_parameters=None): + def lower(self, x, _experimental_lowering_parameters=None): return jax.jit(self.__call__).lower( x, _experimental_lowering_parameters=_experimental_lowering_parameters) + def specialize(self, x, _experimental_lowering_parameters=None): + return jax.jit(self.__call__).specialize( + x, + _experimental_lowering_parameters=_experimental_lowering_parameters) + a, = export.symbolic_shape("a,") # No error _ = get_exported(MyCallable())(