mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Delete the mesh
context manager. The replacement for it is Mesh
.
PiperOrigin-RevId: 442619711
This commit is contained in:
parent
4ccce5e25e
commit
8a61414a88
@ -12,6 +12,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.3.8 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
|
||||
* Changes
|
||||
* `jax.experimental.maps.mesh` has been deleted.
|
||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||
for more information.
|
||||
|
||||
## jaxlib 0.3.8 (Unreleased)
|
||||
* [GitHub
|
||||
|
@ -188,44 +188,6 @@ def serial_loop(name: ResourceAxisName, length: int):
|
||||
thread_resources.env = old_env
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
|
||||
In particular, all ``axis_names`` become valid resource names inside the
|
||||
managed block and can be used e.g. in the ``axis_resources`` argument of
|
||||
:py:func:`xmap`.
|
||||
|
||||
If you are compiling in multiple threads, make sure that the
|
||||
``with mesh`` context manager is inside the function that the threads will
|
||||
execute.
|
||||
|
||||
Args:
|
||||
devices: A NumPy ndarray object containing JAX device objects (as
|
||||
obtained e.g. from :py:func:`jax.devices`).
|
||||
axis_names: A sequence of resource axis names to be assigned to the
|
||||
dimensions of the ``devices`` argument. Its length should match the
|
||||
rank of ``devices``.
|
||||
|
||||
Example::
|
||||
|
||||
devices = np.array(jax.devices())[:4].reshape((2, 2))
|
||||
with mesh(devices, ('x', 'y')): # declare a 2D mesh with axes 'x' and 'y'
|
||||
distributed_out = xmap(
|
||||
jnp.vdot,
|
||||
in_axes=({0: 'left', 1: 'right'}),
|
||||
out_axes=['left', 'right', ...],
|
||||
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
|
||||
"""
|
||||
warn("`maps.mesh` context manager is deprecated. Please use `maps.Mesh`.",
|
||||
FutureWarning)
|
||||
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
|
||||
thread_resources.env = old_env.with_mesh(Mesh(np.asarray(devices, dtype=object), axis_names))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_resources.env = old_env
|
||||
|
||||
_next_resource_id = 0
|
||||
class _UniqueResourceName:
|
||||
def __init__(self, uid, tag=None):
|
||||
|
Loading…
x
Reference in New Issue
Block a user