Remove experimental_cpp_pmap flag since it is always on

PiperOrigin-RevId: 522631405
This commit is contained in:
Yash Katariya 2023-04-07 10:41:42 -07:00 committed by jax authors
parent 06569e0889
commit 738dd719bd
5 changed files with 6 additions and 89 deletions

View File

@ -9,7 +9,8 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.9
* Changes
* The flags experimental_cpp_jit, and experimental_cpp_pjit have been removed.
* The flags experimental_cpp_jit, experimental_cpp_pjit and
experimental_cpp_pmap have been removed.
They are now always on.
* Deprecations

View File

@ -56,7 +56,7 @@ from jax._src.api_util import (
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple,
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@ -1552,12 +1552,7 @@ def pmap(
" removed from JAX. Please migrate to pjit and remove global_arg_shapes"
" from pmap.")
if FLAGS.experimental_cpp_pmap:
func = _cpp_pmap
else:
func = _python_pmap
return func(
return _cpp_pmap(
fun,
axis_name,
in_axes=in_axes,
@ -1687,35 +1682,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
global_axis_size=global_axis_size,
is_explicit_global_axis_size=is_explicit_global_axis_size)
def _get_f_mapped(
*,
fun: Callable,
axis_name: Optional[AxisName],
in_axes=0,
out_axes=0,
static_broadcasted_tuple: Tuple[int, ...],
devices: Optional[Sequence[xc.Device]], # noqa: F811
backend: Optional[str],
axis_size: Optional[int],
donate_tuple: Tuple[int, ...],
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
out = pxla.xla_pmap(
p.flat_fun, *p.flat_args, backend=backend, axis_name=axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__, donated_invars=p.donated_invars,
is_explicit_global_axis_size=p.is_explicit_global_axis_size)
return p.out_tree, out
return pmap_f
def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
donate_argnums, in_axes, out_axes):
@ -1738,47 +1704,6 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
return axis_name, static_broadcasted_tuple, donate_tuple
def _python_pmap(
fun: Callable,
axis_name: Optional[AxisName] = None,
*,
in_axes=0,
out_axes=0,
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices: Optional[Sequence[xc.Device]] = None, # noqa: F811
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
out_axes)
@wraps(fun)
@api_boundary
def pmap_f(*args, **kwargs):
f_pmapped_ = _get_f_mapped(
fun=fun,
axis_name=axis_name,
in_axes=in_axes,
out_axes=out_axes,
static_broadcasted_tuple=static_broadcasted_tuple,
devices=devices,
backend=backend,
axis_size=axis_size,
donate_tuple=donate_tuple)
out_tree, out_flat = f_pmapped_(*args, **kwargs)
return tree_unflatten(out_tree(), out_flat)
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, donate_tuple)
return cast(stages.Wrapped, pmap_f)
class _PmapFastpathData(NamedTuple):
version: int # For forward and backward compatibility
xla_executable: xc.LoadedExecutable

View File

@ -34,18 +34,9 @@ from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable)
from jax._src.config import flags, bool_env
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
FLAGS = flags.FLAGS
flags.DEFINE_bool(
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
map = safe_map
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:

View File

@ -96,7 +96,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
f(1)
def testPmap(self):
pmap_funcs = [api._python_pmap, api._cpp_pmap]
pmap_funcs = [api._cpp_pmap]
for pmap in pmap_funcs:
f = pmap(lambda x: 0. / x)

View File

@ -138,7 +138,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@property
def pmap(self):
return src_api._python_pmap
return src_api.pmap
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))