Remove experimental_cpp_pmap flag since it is always on

PiperOrigin-RevId: 522631405

parent 06569e0889
commit 738dd719bd
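
What this means for users, as a hedged sketch (not part of the commit): jax.pmap now always dispatches through the C++ fast path, so there is no flag to set and the JAX_CPP_PMAP environment variable is no longer consulted. Assuming jax 0.4.9 or newer:

    # Minimal pmap usage; no opt-in flag is needed or accepted anymore.
    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    # Map a function over the leading axis, one shard per local device.
    out = jax.pmap(lambda x: x * 2.0)(jnp.arange(float(n)))
    print(out)  # one doubled value per device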
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,8 @@ Remember to align the itemized text with the first line of an item within a list
 ## jax 0.4.9

 * Changes
-  * The flags experimental_cpp_jit, and experimental_cpp_pjit have been removed.
+  * The flags experimental_cpp_jit, experimental_cpp_pjit and
+    experimental_cpp_pmap have been removed.
     They are now always on.

 * Deprecations
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -56,7 +56,7 @@ from jax._src.api_util import (
     argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     shaped_abstractify, _ensure_str_tuple,
-    check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, FLAGS)
+    check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1552,12 +1552,7 @@ def pmap(
         " removed from JAX. Please migrate to pjit and remove global_arg_shapes"
         " from pmap.")

-  if FLAGS.experimental_cpp_pmap:
-    func = _cpp_pmap
-  else:
-    func = _python_pmap
-
-  return func(
+  return _cpp_pmap(
       fun,
       axis_name,
       in_axes=in_axes,
@@ -1687,35 +1682,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
       global_axis_size=global_axis_size,
       is_explicit_global_axis_size=is_explicit_global_axis_size)

-def _get_f_mapped(
-    *,
-    fun: Callable,
-    axis_name: Optional[AxisName],
-    in_axes=0,
-    out_axes=0,
-    static_broadcasted_tuple: Tuple[int, ...],
-    devices: Optional[Sequence[xc.Device]],  # noqa: F811
-    backend: Optional[str],
-    axis_size: Optional[int],
-    donate_tuple: Tuple[int, ...],
-):
-  def pmap_f(*args, **kwargs):
-    p = _prepare_pmap(
-        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-        devices, backend, axis_size, args, kwargs)
-    for arg in p.flat_args:
-      dispatch.check_arg(arg)
-    out = pxla.xla_pmap(
-        p.flat_fun, *p.flat_args, backend=backend, axis_name=axis_name,
-        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
-        devices=p.devices,
-        in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
-        name=p.flat_fun.__name__, donated_invars=p.donated_invars,
-        is_explicit_global_axis_size=p.is_explicit_global_axis_size)
-    return p.out_tree, out
-
-  return pmap_f
-

 def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
                       donate_argnums, in_axes, out_axes):
@@ -1738,47 +1704,6 @@ def _shared_code_pmap(fun, axis_name, static_broadcasted_argnums,
   return axis_name, static_broadcasted_tuple, donate_tuple


-def _python_pmap(
-    fun: Callable,
-    axis_name: Optional[AxisName] = None,
-    *,
-    in_axes=0,
-    out_axes=0,
-    static_broadcasted_argnums: Union[int, Iterable[int]] = (),
-    devices: Optional[Sequence[xc.Device]] = None,  # noqa: F811
-    backend: Optional[str] = None,
-    axis_size: Optional[int] = None,
-    donate_argnums: Union[int, Iterable[int]] = (),
-) -> stages.Wrapped:
-  """The Python only implementation."""
-  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
-      fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
-      out_axes)
-
-  @wraps(fun)
-  @api_boundary
-  def pmap_f(*args, **kwargs):
-    f_pmapped_ = _get_f_mapped(
-        fun=fun,
-        axis_name=axis_name,
-        in_axes=in_axes,
-        out_axes=out_axes,
-        static_broadcasted_tuple=static_broadcasted_tuple,
-        devices=devices,
-        backend=backend,
-        axis_size=axis_size,
-        donate_tuple=donate_tuple)
-
-    out_tree, out_flat = f_pmapped_(*args, **kwargs)
-    return tree_unflatten(out_tree(), out_flat)
-
-  pmap_f.lower = _pmap_lower(
-      fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-      backend, axis_size, donate_tuple)
-
-  return cast(stages.Wrapped, pmap_f)
-
-
 class _PmapFastpathData(NamedTuple):
   version: int  # For forward and backward compatibility
   xla_executable: xc.LoadedExecutable
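
A note on the deletions above (an editor's sketch, not from the commit): _cpp_pmap was already the default path and, like the removed _python_pmap, attaches a lower method via _pmap_lower, so ahead-of-time lowering through the public API should be unaffected. A hedged usage sketch:

    # AOT lowering through the now-unconditional C++ pmap path.
    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    f = jax.pmap(lambda x: x + 1.0)
    lowered = f.lower(jnp.ones((n, 2)))  # trace and lower without executing
    compiled = lowered.compile()         # compile the lowered computation
    print(compiled(jnp.ones((n, 2))))    # run the compiled executable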
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -34,18 +34,9 @@ from jax._src import linear_util as lu
 from jax._src.linear_util import TracingDebugInfo
 from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
                            Unhashable)
-from jax._src.config import flags, bool_env
 from jax._src import traceback_util
 traceback_util.register_exclusion(__file__)

-FLAGS = flags.FLAGS
-
-flags.DEFINE_bool(
-    "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
-    "A flag enabling the C++ jax.pmap fast path. Until the default "
-    "is switched to True, the feature is not supported and possibly broken "
-    "(e.g. it may use unreleased code from jaxlib.")
-
 map = safe_map

 def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
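
The block deleted above was the last consumer of the JAX_CPP_PMAP environment variable. For context, bool_env reads an environment variable into a boolean; a simplified sketch of the helper (an illustration, not the jax._src.config source):

    # Simplified stand-in for jax._src.config.bool_env, shown only to
    # explain the deleted flag definition; details may differ.
    import os

    def bool_env(varname: str, default: bool) -> bool:
      val = os.getenv(varname, str(default)).lower()
      if val in ("y", "yes", "t", "true", "on", "1"):
        return True
      if val in ("n", "no", "f", "false", "off", "0"):
        return False
      raise ValueError(f"invalid truth value {val!r} for {varname}")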
--- a/tests/debug_nans_test.py
+++ b/tests/debug_nans_test.py
@@ -96,7 +96,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
       f(1)

   def testPmap(self):
-    pmap_funcs = [api._python_pmap, api._cpp_pmap]
+    pmap_funcs = [api._cpp_pmap]

     for pmap in pmap_funcs:
       f = pmap(lambda x: 0. / x)
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -138,7 +138,7 @@ class PythonPmapTest(jtu.JaxTestCase):

   @property
   def pmap(self):
-    return src_api._python_pmap
+    return src_api.pmap

   def testDeviceBufferToArray(self):
     sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
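
With _python_pmap gone, both test files now exercise only the C++ path through the public entry points. A minimal standalone check in the same spirit (a sketch, not part of the test suite):

    # Verify that public jax.pmap round-trips values on every local device.
    import jax
    import jax.numpy as jnp

    def check_pmap_roundtrip():
      n = jax.device_count()
      x = jnp.ones((n, 2))
      out = jax.pmap(lambda v: v * 2.0)(x)
      assert out.shape == (n, 2)
      assert (out == 2.0).all()

    check_pmap_roundtrip()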