diff --git a/CHANGELOG.md b/CHANGELOG.md index a4fe3af34..e2426b107 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`. + * The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are + deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the + `spmd_axis_name` argument for expressing SPMD device-parallel computations. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception. diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index afdd11dbb..6e398df0d 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -12,18 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + +from jax._src import deprecations from jax._src.maps import ( AxisName as AxisName, ResourceSet as ResourceSet, SerialLoop as SerialLoop, + _prepare_axes as _prepare_axes, make_xmap_callable as make_xmap_callable, serial_loop as serial_loop, - xmap as xmap, xmap_p as xmap_p, - _prepare_axes as _prepare_axes, + xmap as xmap, ) from jax._src.mesh import ( EMPTY_ENV as EMPTY_ENV, ResourceEnv as ResourceEnv, thread_resources as thread_resources, ) + +# Added March 7, 2024. +_msg = ( + "jax.experimental.maps and jax.experimental.maps.xmap are deprecated and" + " will be removed in a future release. Use jax.experimental.shard_map or" + " jax.vmap with the spmd_axis_name argument for expressing SPMD" + " device-parallel computations. Please file an issue on" + " https://github.com/google/jax/issues if neither" + " jax.experimental.shard_map nor jax.vmap are suitable for your use case." +) + +deprecations.register("jax.experimental.maps", "maps-module") + +if deprecations.is_accelerated("jax.experimental.maps", "maps-module"): + raise ImportError(_msg) +else: + warnings.warn(_msg, DeprecationWarning, stacklevel=2) + +del deprecations, warnings, _msg diff --git a/pyproject.toml b/pyproject.toml index 0a5873d89..fdbcf1555 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ filterwarnings = [ "ignore:Special cases found for .* but none were parsed.*:UserWarning", # end array_api_tests-related warnings "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", + "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER",