diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c03da5cf..cc8345433 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.8 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...main). +* Changes + * `jax.experimental.maps.mesh` has been deleted. + Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + for more information. ## jaxlib 0.3.8 (Unreleased) * [GitHub diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 340d16bfd..81d96dd6c 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -188,44 +188,6 @@ def serial_loop(name: ResourceAxisName, length: int): thread_resources.env = old_env -@contextlib.contextmanager -def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]): - """Declare the hardware resources available in the scope of this manager. - - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``axis_resources`` argument of - :py:func:`xmap`. - - If you are compiling in multiple threads, make sure that the - ``with mesh`` context manager is inside the function that the threads will - execute. - - Args: - devices: A NumPy ndarray object containing JAX device objects (as - obtained e.g. from :py:func:`jax.devices`). - axis_names: A sequence of resource axis names to be assigned to the - dimensions of the ``devices`` argument. Its length should match the - rank of ``devices``. - - Example:: - - devices = np.array(jax.devices())[:4].reshape((2, 2)) - with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y' - distributed_out = xmap( - jnp.vdot, - in_axes=({0: 'left', 1: 'right'}), - out_axes=['left', 'right', ...], - axis_resources={'left': 'x', 'right': 'y'})(x, x.T) - """ - warn("`maps.mesh` context manager is deprecated. Please use `maps.Mesh`.", - FutureWarning) - old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) - thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names)) - try: - yield - finally: - thread_resources.env = old_env - _next_resource_id = 0 class _UniqueResourceName: def __init__(self, uid, tag=None):