Delete the mesh context manager. The replacement for it is Mesh.

PiperOrigin-RevId: 442619711
This commit is contained in:
Yash Katariya 2022-04-18 13:37:16 -07:00 committed by jax authors
parent 4ccce5e25e
commit 8a61414a88
2 changed files with 4 additions and 38 deletions

View File

@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.8 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
* Changes
* `jax.experimental.maps.mesh` has been deleted.
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.
## jaxlib 0.3.8 (Unreleased)
* [GitHub

View File

@ -188,44 +188,6 @@ def serial_loop(name: ResourceAxisName, length: int):
thread_resources.env = old_env
@contextlib.contextmanager
def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
"""Declare the hardware resources available in the scope of this manager.
In particular, all ``axis_names`` become valid resource names inside the
managed block and can be used e.g. in the ``axis_resources`` argument of
:py:func:`xmap`.
If you are compiling in multiple threads, make sure that the
``with mesh`` context manager is inside the function that the threads will
execute.
Args:
devices: A NumPy ndarray object containing JAX device objects (as
obtained e.g. from :py:func:`jax.devices`).
axis_names: A sequence of resource axis names to be assigned to the
dimensions of the ``devices`` argument. Its length should match the
rank of ``devices``.
Example::
devices = np.array(jax.devices())[:4].reshape((2, 2))
with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
distributed_out = xmap(
jnp.vdot,
in_axes=({0: 'left', 1: 'right'}),
out_axes=['left', 'right', ...],
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
"""
warn("`maps.mesh` context manager is deprecated. Please use `maps.Mesh`.",
FutureWarning)
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names))
try:
yield
finally:
thread_resources.env = old_env
_next_resource_id = 0
class _UniqueResourceName:
def __init__(self, uid, tag=None):