diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fedcc95f..8489c46fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.9 * Changes - * The flags experimental_cpp_jit, and experimental_cpp_pjit have been removed. + * The flags experimental_cpp_jit, experimental_cpp_pjit and + experimental_cpp_pmap have been removed. They are now always on. * Deprecations diff --git a/jax/_src/api.py b/jax/_src/api.py index 9bb51401c..604f65b8a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -56,7 +56,7 @@ from jax._src.api_util import ( argnums_partial, argnums_partial_except, flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, _ensure_str_tuple, - check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS) + check_callable, debug_info, result_paths, flat_out_axes, debug_info_final) from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc @@ -1552,12 +1552,7 @@ def pmap( " removed from JAX. Please migrate to pjit and remove global_arg_shapes" " from pmap.") - if FLAGS.experimental_cpp_pmap: - func = _cpp_pmap - else: - func = _python_pmap - - return func( + return _cpp_pmap( fun, axis_name, in_axes=in_axes, @@ -1687,35 +1682,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, global_axis_size=global_axis_size, is_explicit_global_axis_size=is_explicit_global_axis_size) -def _get_f_mapped( - *, - fun: Callable, - axis_name: Optional[AxisName], - in_axes=0, - out_axes=0, - static_broadcasted_tuple: Tuple[int, ...], - devices: Optional[Sequence[xc.Device]], # noqa: F811 - backend: Optional[str], - axis_size: Optional[int], - donate_tuple: Tuple[int, ...], - ): - def pmap_f(*args, **kwargs): - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - for arg in p.flat_args: - dispatch.check_arg(arg) - out = pxla.xla_pmap( - p.flat_fun, *p.flat_args, backend=backend, axis_name=axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, - name=p.flat_fun.__name__, donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size) - return p.out_tree, out - - return pmap_f - def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes): @@ -1738,47 +1704,6 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums, return axis_name, static_broadcasted_tuple, donate_tuple -def _python_pmap( - fun: Callable, - axis_name: Optional[AxisName] = None, - *, - in_axes=0, - out_axes=0, - static_broadcasted_argnums: Union[int, Iterable[int]] = (), - devices: Optional[Sequence[xc.Device]] = None, # noqa: F811 - backend: Optional[str] = None, - axis_size: Optional[int] = None, - donate_argnums: Union[int, Iterable[int]] = (), - ) -> stages.Wrapped: - """The Python only implementation.""" - axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( - fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, - out_axes) - - @wraps(fun) - @api_boundary - def pmap_f(*args, **kwargs): - f_pmapped_ = _get_f_mapped( - fun=fun, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_tuple=static_broadcasted_tuple, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_tuple=donate_tuple) - - out_tree, out_flat = f_pmapped_(*args, **kwargs) - return tree_unflatten(out_tree(), out_flat) - - pmap_f.lower = _pmap_lower( - fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, donate_tuple) - - return cast(stages.Wrapped, pmap_f) - - class _PmapFastpathData(NamedTuple): version: int # For forward and backward compatibility xla_executable: xc.LoadedExecutable diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 29f115a11..375bf3922 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -34,18 +34,9 @@ from jax._src import linear_util as lu from jax._src.linear_util import TracingDebugInfo from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, Unhashable) -from jax._src.config import flags, bool_env from jax._src import traceback_util traceback_util.register_exclusion(__file__) -FLAGS = flags.FLAGS - -flags.DEFINE_bool( - "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True), - "A flag enabling the C++ jax.pmap fast path. Until the default " - "is switched to True, the feature is not supported and possibly broken " - "(e.g. it may use unreleased code from jaxlib.") - map = safe_map def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]: diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 4357385b3..7dcd46e08 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -96,7 +96,7 @@ class DebugNaNsTest(jtu.JaxTestCase): f(1) def testPmap(self): - pmap_funcs = [api._python_pmap, api._cpp_pmap] + pmap_funcs = [api._cpp_pmap] for pmap in pmap_funcs: f = pmap(lambda x: 0. / x) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 725f8311b..01d2feb4e 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -138,7 +138,7 @@ class PythonPmapTest(jtu.JaxTestCase): @property def pmap(self): - return src_api._python_pmap + return src_api.pmap def testDeviceBufferToArray(self): sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))