Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.

This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
Peter Hawkins 2023-02-03 14:28:07 -08:00 committed by jax authors
parent 136c11af5f
commit 428189f8fb
26 changed files with 110 additions and 118 deletions
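
For readers migrating their own code, a minimal sketch of the renamed APIs in their new home (illustrative only, not part of the diff below; assumes a host with at least 8 devices):

    import jax
    import jax.numpy as jnp
    import numpy as np
    # Deprecated spellings removed by this change:
    #   jax.experimental.maps.Mesh, jax.experimental.PartitionSpec,
    #   jax.experimental.maps.NamedSharding
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = np.array(jax.devices()[:8]).reshape(4, 2)
    mesh = Mesh(devices, ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', None))
    x = jax.device_put(jnp.arange(16.0).reshape(8, 2), sharding)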

View File

@@ -30,7 +30,6 @@ from jax.interpreters import pxla
from jax._src import array
from jax._src import sharding
from jax.experimental import pjit as pjit_lib
-from jax.experimental import maps
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
@@ -62,7 +61,7 @@ def create_mesh(shape, axis_names, state):
return None
devices = sorted(jax.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(shape)
-global_mesh = maps.Mesh(mesh_devices, axis_names)
+global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
return global_mesh
@@ -692,7 +691,7 @@ def bench_repeated_static_slicing(state):
jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
-spec = pjit_lib.PartitionSpec('x')
+spec = jax.sharding.PartitionSpec('x')
mesh = create_mesh((num_devices,), ('x',), state)
if mesh is None:
return

View File

@@ -226,7 +226,7 @@ f = pjit(
in_axis_resources=PartitionSpec('x', None),
out_axis_resources=(None, PartitionSpec('x', None)))
-with maps.Mesh(mesh.devices, mesh.axis_names):
+with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)

View File

@@ -92,7 +92,7 @@ import jax.numpy as jnp
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
-mesh = maps.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
+mesh = jax.sharding.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)

View File

@@ -544,8 +544,8 @@ def make_array_from_callback(
Example:
->>> from jax.experimental.maps import Mesh
->>> from jax.experimental import PartitionSpec as P
+>>> from jax.sharding import Mesh
+>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)

View File

@@ -444,7 +444,7 @@ def pjit(
propagation of the input partitioning specified in ``in_axis_resources`` and
the output partitioning specified in ``out_axis_resources``. The resources
specified in those two arguments must refer to mesh axes, as defined by
-the :py:func:`jax.experimental.maps.Mesh` context manager. Note that the mesh
+the :py:func:`jax.sharding.Mesh` context manager. Note that the mesh
definition at :func:`~pjit` application time is ignored, and the returned function
will use the mesh definition available at each call site.
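
A hypothetical sketch of the call-site semantics this docstring describes (not part of the diff; the in_axis_resources/out_axis_resources keywords follow the pjit signature of this era):

    import jax
    import numpy as np
    from jax.experimental.pjit import pjit
    from jax.sharding import Mesh, PartitionSpec as P

    f = pjit(lambda x: x * 2, in_axis_resources=P('x'), out_axis_resources=P('x'))
    # The mesh in scope at the call site, not at pjit application time, is used:
    with Mesh(np.array(jax.devices()), ('x',)):
      y = f(np.arange(jax.device_count(), dtype=np.float32))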

View File

@@ -273,13 +273,13 @@ class NamedSharding(XLACompatibleSharding):
``Mesh`` and ``PartitionSpec``.
Args:
-mesh: A ``jax.experimental.maps.Mesh`` object.
-spec: A ``jax.experimental.PartitionSpec`` object.
+mesh: A ``jax.sharding.Mesh`` object.
+spec: A ``jax.sharding.PartitionSpec`` object.
Example:
->>> from jax.experimental.maps import Mesh
->>> from jax.experimental import PartitionSpec as P
+>>> from jax.sharding import Mesh
+>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)

View File

@@ -23,7 +23,7 @@ from jax._src import config as jax_config
from jax.config import config
from jax._src import array
from jax._src.sharding import NamedSharding, OpShardingSharding
-from jax.experimental import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
from jax.experimental.gda_serialization import serialization
import numpy as np
import tensorstore as ts

View File

@@ -207,8 +207,8 @@ class GlobalDeviceArray:
Example:
->>> from jax.experimental.maps import Mesh
->>> from jax.experimental import PartitionSpec as P
+>>> from jax.sharding import Mesh
+>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> assert jax.device_count() == 8
@@ -490,8 +490,8 @@ class GlobalDeviceArray:
Example:
->>> from jax.experimental.maps import Mesh
->>> from jax.experimental import PartitionSpec as P
+>>> from jax.sharding import Mesh
+>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> global_input_shape = (8, 8)

View File

@@ -209,7 +209,7 @@ to return the entire array, which will then be sent in a single infeed to the
same device that issued the outfeed. This device is then responsible for
sending the required shards to the other devices::
-with maps.Mesh(jax.local_devices()[:2], ["d"]):
+with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
pjit.pjit(power3, in_axis_resources=(P("d"),),
out_axis_resources=(P("d"),))(np.array([3., 4.]))

View File

@@ -219,7 +219,7 @@ def create_device_mesh(
mesh_shape: Sequence[int],
devices: Optional[Sequence[Any]] = None, *,
contiguous_submeshes: bool = False) -> np.ndarray:
"""Creates a performant device mesh for jax.experimental.maps.Mesh.
"""Creates a performant device mesh for jax.sharding.Mesh.
Args:
mesh_shape: shape of logical mesh, ordered by increasing network-intensity
@@ -234,7 +234,7 @@ def create_device_mesh(
Returns:
A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
-into jax.experimental.maps.Mesh with good collective performance.
+into jax.sharding.Mesh with good collective performance.
"""
if devices is None:
devices = jax.devices()
@@ -293,7 +293,7 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
Returns:
A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
-that can be fed into jax.experimental.maps.Mesh for hybrid parallelism.
+that can be fed into jax.sharding.Mesh for hybrid parallelism.
"""
if devices is None:
devices = jax.devices()
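
A hypothetical usage sketch for these docstrings (not part of the diff): the returned np.ndarray of devices feeds directly into jax.sharding.Mesh.

    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    # Assumes 8 devices; the axis names ('data', 'model') are illustrative.
    devices = mesh_utils.create_device_mesh((4, 2))
    mesh = Mesh(devices, ('data', 'model'))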

View File

@@ -26,7 +26,6 @@ from jax._src import array
from jax._src import sharding
from jax.tree_util import PyTreeDef
from jax.interpreters import pxla, xla
-from jax.experimental import maps
from jax.experimental import pjit as pjit_lib
from jax.experimental.pjit import pjit, FROM_GDA
from jax.interpreters.pxla import PartitionSpec as P
@@ -102,7 +101,7 @@ def _handle_array_process_allgather(inp, tiled):
# All inputs here will be fully addressable.
devices = np.array(jax.devices()).reshape(jax.process_count(),
jax.local_device_count())
-global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
+global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
pspec = P('processes')
s = jax.sharding.NamedSharding(global_mesh, pspec)
@@ -158,7 +157,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
# Shape of local_mesh will always be (1, local_device_count())
devices = np.array(jax.devices()).reshape(jax.process_count(),
jax.local_device_count())
-global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
+global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
in_axis_resources = P('processes')
if inp.ndim == 0 or not tiled:
inp = np.expand_dims(inp, axis=0)

View File

@@ -2343,9 +2343,9 @@ class Mesh(ContextDecorator):
Example:
->>> from jax.experimental.maps import Mesh
>>> from jax.experimental.pjit import pjit
->>> from jax.experimental import PartitionSpec as P
+>>> from jax.sharding import Mesh
+>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))

View File

@@ -60,7 +60,6 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec as P
from jax._src import array, sharding
from jax.experimental import pjit
-from jax.experimental import maps
from jax._src import config as jax_config
from jax._src import custom_derivatives
from jax._src import device_array
@@ -1503,7 +1502,7 @@ class APITest(jtu.JaxTestCase):
@jax_config.jax_array(True)
def test_device_put_sharding(self):
-mesh = maps.Mesh(jax.devices(), ('x',))
+mesh = jax.sharding.Mesh(jax.devices(), ('x',))
s = sharding.NamedSharding(mesh, P('x'))
x = jnp.arange(len(jax.devices()))
@@ -1529,7 +1528,8 @@ class APITest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices")
-mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)),
+("x", "y"))
s1 = sharding.NamedSharding(mesh, P("x"))
s2 = sharding.NamedSharding(mesh, P("y"))
s3 = sharding.NamedSharding(mesh, P("x", "y"))
@@ -1552,7 +1552,7 @@ class APITest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices")
-mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
s1 = sharding.NamedSharding(mesh, P("x"))
s2 = sharding.NamedSharding(mesh, P("y"))
@@ -1574,7 +1574,7 @@ class APITest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices")
-mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
s1 = sharding.NamedSharding(mesh, P("x"))
s2 = sharding.NamedSharding(mesh, P("y"))
@@ -1596,7 +1596,7 @@ class APITest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices")
-mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
s1 = sharding.NamedSharding(mesh, P("x"))
s2 = sharding.NamedSharding(mesh, P("y"))

View File

@@ -30,13 +30,12 @@ from jax._src.lib import xla_extension_version
from jax._src.util import prod, safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
-from jax.experimental import PartitionSpec as P
from jax.experimental.serialize_executable import (
compile_and_serialize, load_compiled)
+from jax.sharding import PartitionSpec as P
from jax._src import sharding
from jax._src import array
from jax._src import prng
-from jax.experimental import maps
from jax.config import config
config.parse_flags_with_absl()
@@ -363,7 +362,7 @@ class JaxArrayTest(jtu.JaxTestCase):
if jax.device_count() < 3:
self.skipTest('Requires more than 3 devices')
shape = (8, 2)
-mesh = maps.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
+mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
# sharding device ids = {1, 2}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
@@ -506,7 +505,7 @@ class JaxArrayTest(jtu.JaxTestCase):
self.skipTest('Test requires >= 2 devices.')
single_dev = jax.devices()[1:2]
-mesh = maps.Mesh(np.array(single_dev), ('x'))
+mesh = jax.sharding.Mesh(np.array(single_dev), ('x'))
input_shape = (8, 2)
arr, input_data = create_array(
input_shape, sharding.NamedSharding(mesh, P('x')))
@@ -984,7 +983,7 @@ class RngShardingTest(jtu.JaxTestCase):
def fun(x):
return x * x
-with maps.Mesh(np.array(jax.devices()), ('data',)):
+with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
lowered = pjit(
fun,
in_axis_resources=P('data'),

View File

@@ -26,7 +26,6 @@ from jax._src.lib import xla_extension
from jax.config import config
from jax.experimental import checkify
from jax.experimental import pjit
-from jax.experimental import maps
from jax._src.sharding import NamedSharding
from jax._src import array
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
@@ -479,13 +478,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# binary func
return x / y
mesh = maps.Mesh(np.array(jax.devices()), ["dev"])
mesh = jax.sharding.Mesh(np.array(jax.devices()), ["dev"])
if config.jax_array:
-ps = NamedSharding(mesh, pjit.PartitionSpec("dev"))
+ps = NamedSharding(mesh, jax.sharding.PartitionSpec("dev"))
inp = np.arange(8)
x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])
else:
-ps = pjit.PartitionSpec("dev")
+ps = jax.sharding.PartitionSpec("dev")
x = jnp.arange(8)
f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps)

View File

@@ -23,7 +23,7 @@ from unittest import mock, SkipTest
import warnings
from absl.testing import absltest
-from jax.experimental import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit

View File

@@ -147,7 +147,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")
-p = pjit.PartitionSpec('x')
+p = jax.sharding.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x,
in_axis_resources=p,
out_axis_resources=p)
@@ -175,7 +175,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")
-p = pjit.PartitionSpec('x')
+p = jax.sharding.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x,
in_axis_resources=p,
out_axis_resources=p,

View File

@@ -21,7 +21,6 @@ from typing import IO, Sequence, Tuple
from absl.testing import absltest
import jax
from jax.config import config
-from jax.experimental import maps
from jax.experimental import pjit
from jax._src import debugger
from jax._src import test_util as jtu
@@ -330,9 +329,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
def g(x):
y = f(x)
return jnp.exp(y)
-g = pjit.pjit(g, in_axis_resources=pjit.PartitionSpec("dev"),
-out_axis_resources=pjit.PartitionSpec("dev"))
-with maps.Mesh(np.array(jax.devices()), ["dev"]):
+g = pjit.pjit(g, in_axis_resources=jax.sharding.PartitionSpec("dev"),
+out_axis_resources=jax.sharding.PartitionSpec("dev"))
+with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
arr = (1 + np.arange(8)).astype(np.int32)
expected = _format_multiline(r"""
Entering jdb:

View File

@@ -790,13 +790,13 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
def f(x):
debug_print("{}", x, ordered=False)
return x
-mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
-spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
else:
-spec = pjit.PartitionSpec('dev')
-out_spec = pjit.PartitionSpec()
+spec = jax.sharding.PartitionSpec('dev')
+out_spec = jax.sharding.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
with jtu.capture_stdout() as output:
@@ -809,7 +809,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
debug_print("{}", y, ordered=False)
return y
f2 = pjit.pjit(f2, in_axis_resources=spec, out_axis_resources=out_spec)
-with maps.Mesh(np.array(jax.devices()), ['dev']):
+with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
with jtu.capture_stdout() as output:
f2(np.arange(8, dtype=jnp.int32))
jax.effects_barrier()
@@ -847,11 +847,11 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
return (i + 1, x)
return lax.while_loop(cond, body, (0, x))[1]
-mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
-spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
+spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
else:
-spec = pjit.PartitionSpec('dev')
+spec = jax.sharding.PartitionSpec('dev')
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
with mesh:
with jtu.capture_stdout() as output:
@@ -876,13 +876,13 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
debug_print("Out: {}", out)
return out
-mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
-in_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+in_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
else:
-in_spec = pjit.PartitionSpec('dev')
-out_spec = pjit.PartitionSpec()
+in_spec = jax.sharding.PartitionSpec('dev')
+out_spec = jax.sharding.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)
with mesh:
with jtu.capture_stdout() as output:
@@ -900,7 +900,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
debug_print("{}", x, ordered=False)
f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
axis_resources={'a': 'dev'})
-with maps.Mesh(np.array(jax.devices()), ['dev']):
+with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
with jtu.capture_stdout() as output:
f(np.arange(40))
jax.effects_barrier()
@@ -991,8 +991,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
return np.array(devices).reshape(shape)
def test_trivial_sharding(self):
-mesh = maps.Mesh(self._create_devices(1), ['x'])
-pspec = pjit.PartitionSpec('x')
+mesh = jax.sharding.Mesh(self._create_devices(1), ['x'])
+pspec = jax.sharding.PartitionSpec('x')
sd = sharding.NamedSharding(mesh, pspec)
shape = (5,)
with jtu.capture_stdout() as output:
@@ -1004,8 +1004,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
"""))
def test_trivial_sharding_with_scale(self):
-mesh = maps.Mesh(self._create_devices(1), ['x'])
-pspec = pjit.PartitionSpec('x')
+mesh = jax.sharding.Mesh(self._create_devices(1), ['x'])
+pspec = jax.sharding.PartitionSpec('x')
sd = sharding.NamedSharding(mesh, pspec)
shape = (5,)
with jtu.capture_stdout() as output:
@@ -1017,8 +1017,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
"""))
def test_full_sharding(self):
-mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
-pspec = pjit.PartitionSpec('x', 'y')
+mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+pspec = jax.sharding.PartitionSpec('x', 'y')
sd = sharding.NamedSharding(mesh, pspec)
shape = (8, 8)
with jtu.capture_stdout() as output:
@@ -1046,9 +1046,9 @@ class VisualizeShardingTest(jtu.JaxTestCase):
def test_sharding_with_replication(self):
shape = (8, 8)
-mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
-pspec = pjit.PartitionSpec('x', None)
+pspec = jax.sharding.PartitionSpec('x', None)
sd = sharding.NamedSharding(mesh, pspec)
with jtu.capture_stdout() as output:
debugging.visualize_sharding(shape, sd)
@@ -1073,8 +1073,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
""")
self.assertEqual(output(), expected)
-mesh = maps.Mesh(self._create_devices((4, 2)), ['x', 'y'])
-pspec = pjit.PartitionSpec(None, 'y')
+mesh = jax.sharding.Mesh(self._create_devices((4, 2)), ['x', 'y'])
+pspec = jax.sharding.PartitionSpec(None, 'y')
sd = sharding.NamedSharding(mesh, pspec)
with jtu.capture_stdout() as output:
debugging.visualize_sharding(shape, sd)
@@ -1095,9 +1095,9 @@ class VisualizeShardingTest(jtu.JaxTestCase):
def test_visualize_wide_array(self):
shape = (128, 10000)
-mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
-pspec = pjit.PartitionSpec('x', None)
+pspec = jax.sharding.PartitionSpec('x', None)
sd = sharding.NamedSharding(mesh, pspec)
with jtu.capture_stdout() as output:
debugging.visualize_sharding(shape, sd)
@@ -1184,13 +1184,13 @@ class InspectShardingTest(jtu.JaxTestCase):
debugging.inspect_array_sharding(x, callback=_cb)
return jnp.square(x)
-mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
-spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
else:
-spec = pjit.PartitionSpec('dev')
-out_spec = pjit.PartitionSpec()
+spec = jax.sharding.PartitionSpec('dev')
+out_spec = jax.sharding.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
f(np.arange(8, dtype=jnp.int32))

View File

@@ -23,8 +23,8 @@ from jax import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
-from jax.experimental import PartitionSpec as P
-from jax.experimental.maps import Mesh
+from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh
import jax.experimental.global_device_array as gda_lib
from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices

View File

@@ -31,8 +31,7 @@ from jax import core
from jax.config import config
from jax import dtypes
from jax.experimental import host_callback as hcb
-from jax.experimental import PartitionSpec as P
-from jax.experimental import maps
+from jax.sharding import PartitionSpec as P
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
@@ -1720,7 +1719,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
in_axis_resources=(P("d"),),
out_axis_resources=P("d"))
with maps.Mesh(devices, ["d"]):
with jax.sharding.Mesh(devices, ["d"]):
# Print the internal IR
helper_log_ir(
f"{self._testMethodName}.pjit",
@@ -2338,7 +2337,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
pjit_fun = pjit.pjit(
fun, in_axis_resources=(P("d"),), out_axis_resources=P("d"))
with maps.Mesh(devices, ["d"]):
with jax.sharding.Mesh(devices, ["d"]):
# Print the internal IR
helper_log_ir(
f"{self._testMethodName}.pjit",

View File

@@ -272,11 +272,11 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
-mesh = maps.Mesh(np.array(jax.devices()), ['x'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x'])
if config.jax_array:
-spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('x'))
+spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
else:
-spec = pjit.PartitionSpec('x')
+spec = jax.sharding.PartitionSpec('x')
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
with mesh:
jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))

View File

@@ -33,7 +33,6 @@ import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import global_device_array
-from jax.experimental import maps
from jax.experimental import pjit
try:
@@ -256,7 +255,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
# 13 15
assert [d.id for d in device_mesh.flat
] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
-return maps.Mesh(device_mesh, ("x", "y"))
+return jax.sharding.Mesh(device_mesh, ("x", "y"))
def setUp(self):
super().setUp()
@@ -347,7 +346,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
gda3 = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes3, cb)
-with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
@functools.partial(
pjit.pjit,
@@ -394,7 +393,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
# The process-gpu mapping is random: @sudhakarsingh27 to figure out why so
# and the data is:
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
-global_mesh = maps.Mesh(mesh_devices, ("x",))
+global_mesh = jax.sharding.Mesh(mesh_devices, ("x",))
global_input_shape = (16,)
mesh_axes = experimental.PartitionSpec("x")
global_input_data = np.arange(
@@ -426,7 +425,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
15: ((slice(15, 16),), 0),
}
-with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
f = pjit.pjit(lambda x: x,
in_axis_resources=pjit.FROM_GDA,
out_axis_resources=mesh_axes)

View File

@@ -37,7 +37,6 @@ from jax import lax
from jax.lax import with_sharding_constraint
from jax import prng
from jax.sharding import PartitionSpec as P
-from jax.experimental import maps
from jax.experimental.maps import xmap
from jax.experimental import global_device_array
from jax.experimental import multihost_utils
@@ -284,7 +283,7 @@ class PJitTest(jtu.BufferDonationTestCase):
raise unittest.SkipTest(f"Test requires {size} global devices.")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
-@maps.Mesh(mesh_devices, ('x', 'y'))
+@jax.sharding.Mesh(mesh_devices, ('x', 'y'))
def dec():
return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x)
out = dec()
@@ -529,13 +528,13 @@ class PJitTest(jtu.BufferDonationTestCase):
if devices.size < 4:
raise unittest.SkipTest("Test requires 4 devices")
devices = devices.reshape((2, 2))
-with maps.Mesh(devices, ('x', 'y')):
+with jax.sharding.Mesh(devices, ('x', 'y')):
should_be_tracing = True
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
should_be_tracing = False
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
# Re-create the mesh to make sure that has no influence on caching
-with maps.Mesh(devices, ('x', 'y')):
+with jax.sharding.Mesh(devices, ('x', 'y')):
should_be_tracing = False
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
@@ -764,7 +763,7 @@ class PJitTest(jtu.BufferDonationTestCase):
d.transfer_to_infeed(z[3 * didx:3 * didx + 3, :])
d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],))
-with maps.Mesh(devices, ['d']):
+with jax.sharding.Mesh(devices, ['d']):
logging.info('Making pjit call')
res = pjit(
f_for_pjit, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(
@@ -788,7 +787,7 @@ class PJitTest(jtu.BufferDonationTestCase):
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
def dispatch():
-with maps.Mesh(devices, ['d']):
+with jax.sharding.Mesh(devices, ['d']):
logging.info('Making pjit call')
pjit(f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(x)
execution = threading.Thread(target=dispatch)
@@ -1417,7 +1416,7 @@ class GDAPjitTest(jtu.JaxTestCase):
# pickling in_axis_resources and sending to other processes). Make sure this
# this doesn't cause an error to avoid user confusion.
from_gda_dup = pjit_lib._FromGdaSingleton()
-with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
pjit(lambda x: x, in_axis_resources=from_gda_dup, out_axis_resources=None)(
input_gda)
@@ -3457,7 +3456,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def testNestedDifferentResources(self):
@partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
def f(x):
-with maps.Mesh(np.array([jax.local_devices()[0]]), ('x')):
+with jax.sharding.Mesh(np.array([jax.local_devices()[0]]), ('x')):
@partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
def h(x):
return x
@@ -3801,12 +3800,12 @@ class UtilTest(jtu.JaxTestCase):
P(('x',), ('y',)))
def test_mesh_with_list_devices(self):
-mesh = maps.Mesh(jax.devices(), ('x',))
+mesh = jax.sharding.Mesh(jax.devices(), ('x',))
self.assertIsInstance(mesh.devices, np.ndarray)
self.assertEqual(mesh.size, jax.device_count())
def test_mesh_with_string_axis_names(self):
-mesh = maps.Mesh(jax.devices(), 'dp')
+mesh = jax.sharding.Mesh(jax.devices(), 'dp')
self.assertTupleEqual(mesh.axis_names, ('dp',))

View File

@@ -33,10 +33,10 @@ from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
-from jax.experimental.maps import Mesh
from jax.experimental.maps import xmap
from jax.experimental import io_callback
import jax.numpy as jnp
+from jax.sharding import Mesh
import numpy as np
@@ -683,7 +683,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
try:
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
-spec = pjit.PartitionSpec('x')
+spec = jax.sharding.PartitionSpec('x')
def f(x):
axis_resources = {v: v for v in mesh.axis_names}
@@ -850,7 +850,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
axis_resources={'a': 'dev'})
-with maps.Mesh(np.array(jax.devices()), ['dev']):
+with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
out = f(np.arange(40.))
np.testing.assert_allclose(out, jnp.arange(1., 41.))
@@ -866,7 +866,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
axis_resources={'a': 'dev'})
-with maps.Mesh(np.array(jax.devices()), ['dev']):
+with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
out = f(np.arange(40.))
np.testing.assert_allclose(out, jnp.arange(1., 41.))
@@ -1020,13 +1020,13 @@ class IOPythonCallbackTest(jtu.JaxTestCase):
io_callback(_cb, None, x)
return x
-mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
-spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
else:
-spec = pjit.PartitionSpec('dev')
-out_spec = pjit.PartitionSpec()
+spec = jax.sharding.PartitionSpec('dev')
+out_spec = jax.sharding.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
f(jnp.arange(mesh.size))

View File

@@ -38,7 +38,7 @@ from jax.experimental import global_device_array
from jax._src import array
from jax._src.sharding import NamedSharding
from jax.experimental.pjit import pjit, with_sharding_constraint
-from jax.experimental.pjit import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.errors import JAXTypeError
from jax._src import config as jax_config
@@ -273,7 +273,7 @@ class XMapTest(XMapTestCase):
def f(a, b):
return a * 2, b * 4
devices = np.array(local_devices[:4]).reshape((2, 2))
-with maps.Mesh(devices, ('x', 'y')):
+with jax.sharding.Mesh(devices, ('x', 'y')):
fm = xmap(f,
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
@@ -382,14 +382,14 @@ class XMapTest(XMapTestCase):
if devices.size < 2:
raise SkipTest("Test requires 2 devices")
x = np.arange(8).reshape((2, 2, 2))
-with maps.Mesh(devices, ('x',)):
+with jax.sharding.Mesh(devices, ('x',)):
python_should_be_executing = True
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
python_should_be_executing = False
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
-with maps.Mesh(devices, ('x',)):
+with jax.sharding.Mesh(devices, ('x',)):
python_should_be_executing = False
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})(x)
@@ -1795,7 +1795,7 @@ class XMapErrorTest(jtu.JaxTestCase):
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
-with maps.Mesh(np.empty((), dtype=np.object_), ()):
+with jax.sharding.Mesh(np.empty((), dtype=np.object_), ()):
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
def h(x):
return x