Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.

This change updates:

* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding

PiperOrigin-RevId: 506994892
parent 136c11af5f
commit 428189f8fb
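For reference, a minimal sketch of the rename this commit performs, using the new canonical import paths. The mesh shape, axis name, and array here are illustrative, not taken from the diff:

```python
import numpy as np
import jax

# Before this change (deprecated aliases):
#   from jax.experimental.maps import Mesh
#   from jax.experimental import PartitionSpec as P
# After this change (canonical names in jax.sharding):
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ('x',))
spec = P('x')                          # shard the leading axis over mesh axis 'x'
sharding = NamedSharding(mesh, spec)   # same object as before, new import path

x = jax.device_put(np.arange(8), sharding)
```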
@@ -30,7 +30,6 @@ from jax.interpreters import pxla
 from jax._src import array
 from jax._src import sharding
 from jax.experimental import pjit as pjit_lib
-from jax.experimental import maps
 from jax.experimental import multihost_utils
 import jax.numpy as jnp
 import numpy as np

@@ -62,7 +61,7 @@ def create_mesh(shape, axis_names, state):
     return None
   devices = sorted(jax.devices(), key=lambda d: d.id)
   mesh_devices = np.array(devices[:size]).reshape(shape)
-  global_mesh = maps.Mesh(mesh_devices, axis_names)
+  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
   return global_mesh
 
 

@@ -692,7 +691,7 @@ def bench_repeated_static_slicing(state):
   jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)])
 
 def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
-  spec = pjit_lib.PartitionSpec('x')
+  spec = jax.sharding.PartitionSpec('x')
   mesh = create_mesh((num_devices,), ('x',), state)
   if mesh is None:
     return

@@ -226,7 +226,7 @@ f = pjit(
     in_axis_resources=PartitionSpec('x', None),
     out_axis_resources=(None, PartitionSpec('x', None)))
 
-with maps.Mesh(mesh.devices, mesh.axis_names):
+with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
   err, data = f(input_data)
 err.throw()
 # ValueError: divided by zero at <...>:4 (f)

@@ -92,7 +92,7 @@ import jax.numpy as jnp
 x = jnp.arange(8)
 
 # Let's say there are 8 devices in jax.devices()
-mesh = maps.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
+mesh = jax.sharding.Mesh(jax.devices().reshape(4, 2), ('x', 'y'))
 sharding = jax.sharding.NamedSharding(mesh, P('x'))
 
 sharded_x = jax.device_put(x, sharding)

@@ -544,8 +544,8 @@ def make_array_from_callback(
 
   Example:
 
-    >>> from jax.experimental.maps import Mesh
-    >>> from jax.experimental import PartitionSpec as P
+    >>> from jax.sharding import Mesh
+    >>> from jax.sharding import PartitionSpec as P
     >>> import numpy as np
     ...
     >>> input_shape = (8, 8)

@@ -444,7 +444,7 @@ def pjit(
   propagation of the input partitioning specified in ``in_axis_resources`` and
   the output partitioning specified in ``out_axis_resources``. The resources
   specified in those two arguments must refer to mesh axes, as defined by
-  the :py:func:`jax.experimental.maps.Mesh` context manager. Note that the mesh
+  the :py:func:`jax.sharding.Mesh` context manager. Note that the mesh
   definition at :func:`~pjit` application time is ignored, and the returned function
   will use the mesh definition available at each call site.
 

@@ -273,13 +273,13 @@ class NamedSharding(XLACompatibleSharding):
   ``Mesh`` and ``PartitionSpec``.
 
   Args:
-    mesh: A ``jax.experimental.maps.Mesh`` object.
-    spec: A ``jax.experimental.PartitionSpec`` object.
+    mesh: A ``jax.sharding.Mesh`` object.
+    spec: A ``jax.sharding.PartitionSpec`` object.
 
   Example:
 
-    >>> from jax.experimental.maps import Mesh
-    >>> from jax.experimental import PartitionSpec as P
+    >>> from jax.sharding import Mesh
+    >>> from jax.sharding import PartitionSpec as P
     >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
     >>> spec = P('x', 'y')
     >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)

@@ -23,7 +23,7 @@ from jax._src import config as jax_config
 from jax.config import config
 from jax._src import array
 from jax._src.sharding import NamedSharding, OpShardingSharding
-from jax.experimental import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
 from jax.experimental.gda_serialization import serialization
 import numpy as np
 import tensorstore as ts

@@ -207,8 +207,8 @@ class GlobalDeviceArray:
 
   Example:
 
-    >>> from jax.experimental.maps import Mesh
-    >>> from jax.experimental import PartitionSpec as P
+    >>> from jax.sharding import Mesh
+    >>> from jax.sharding import PartitionSpec as P
     >>> import numpy as np
     ...
    >>> assert jax.device_count() == 8

@@ -490,8 +490,8 @@ class GlobalDeviceArray:
 
   Example:
 
-    >>> from jax.experimental.maps import Mesh
-    >>> from jax.experimental import PartitionSpec as P
+    >>> from jax.sharding import Mesh
+    >>> from jax.sharding import PartitionSpec as P
     >>> import numpy as np
     ...
     >>> global_input_shape = (8, 8)

@@ -209,7 +209,7 @@ to return the entire array, which will then be sent in a single infeed to the
 same device that issued the outfeed. This device is then responsible for
 sending the required shards to the other devices::
 
-  with maps.Mesh(jax.local_devices()[:2], ["d"]):
+  with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
     pjit.pjit(power3, in_axis_resources=(P("d"),),
               out_axis_resources=(P("d"),))(np.array([3., 4.]))
 

@@ -219,7 +219,7 @@ def create_device_mesh(
     mesh_shape: Sequence[int],
     devices: Optional[Sequence[Any]] = None, *,
     contiguous_submeshes: bool = False) -> np.ndarray:
-  """Creates a performant device mesh for jax.experimental.maps.Mesh.
+  """Creates a performant device mesh for jax.sharding.Mesh.
 
   Args:
     mesh_shape: shape of logical mesh, ordered by increasing network-intensity

@@ -234,7 +234,7 @@ def create_device_mesh(
 
   Returns:
     A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
-    into jax.experimental.maps.Mesh with good collective performance.
+    into jax.sharding.Mesh with good collective performance.
   """
   if devices is None:
     devices = jax.devices()

@@ -293,7 +293,7 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
 
   Returns:
     A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
-    that can be fed into jax.experimental.maps.Mesh for hybrid parallelism.
+    that can be fed into jax.sharding.Mesh for hybrid parallelism.
   """
   if devices is None:
     devices = jax.devices()

@@ -26,7 +26,6 @@ from jax._src import array
 from jax._src import sharding
 from jax.tree_util import PyTreeDef
 from jax.interpreters import pxla, xla
-from jax.experimental import maps
 from jax.experimental import pjit as pjit_lib
 from jax.experimental.pjit import pjit, FROM_GDA
 from jax.interpreters.pxla import PartitionSpec as P

@@ -102,7 +101,7 @@ def _handle_array_process_allgather(inp, tiled):
   # All inputs here will be fully addressable.
   devices = np.array(jax.devices()).reshape(jax.process_count(),
                                             jax.local_device_count())
-  global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
+  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
   pspec = P('processes')
   s = jax.sharding.NamedSharding(global_mesh, pspec)
 

@@ -158,7 +157,7 @@ def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
   # Shape of local_mesh will always be (1, local_device_count())
   devices = np.array(jax.devices()).reshape(jax.process_count(),
                                             jax.local_device_count())
-  global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
+  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
   in_axis_resources = P('processes')
   if inp.ndim == 0 or not tiled:
     inp = np.expand_dims(inp, axis=0)

@@ -2343,9 +2343,9 @@ class Mesh(ContextDecorator):
 
   Example:
 
-    >>> from jax.experimental.maps import Mesh
     >>> from jax.experimental.pjit import pjit
-    >>> from jax.experimental import PartitionSpec as P
+    >>> from jax.sharding import Mesh
+    >>> from jax.sharding import PartitionSpec as P
     >>> import numpy as np
     ...
     >>> inp = np.arange(16).reshape((8, 2))

@@ -60,7 +60,6 @@ from jax.interpreters import partial_eval as pe
 from jax.interpreters.pxla import PartitionSpec as P
 from jax._src import array, sharding
 from jax.experimental import pjit
-from jax.experimental import maps
 from jax._src import config as jax_config
 from jax._src import custom_derivatives
 from jax._src import device_array

@@ -1503,7 +1502,7 @@ class APITest(jtu.JaxTestCase):
 
   @jax_config.jax_array(True)
   def test_device_put_sharding(self):
-    mesh = maps.Mesh(jax.devices(), ('x',))
+    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
     s = sharding.NamedSharding(mesh, P('x'))
     x = jnp.arange(len(jax.devices()))
 

@@ -1529,7 +1528,8 @@ class APITest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices")
 
-    mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+    mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)),
+                             ("x", "y"))
     s1 = sharding.NamedSharding(mesh, P("x"))
     s2 = sharding.NamedSharding(mesh, P("y"))
     s3 = sharding.NamedSharding(mesh, P("x", "y"))

@@ -1552,7 +1552,7 @@ class APITest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices")
 
-    mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+    mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
     s1 = sharding.NamedSharding(mesh, P("x"))
     s2 = sharding.NamedSharding(mesh, P("y"))
 

@@ -1574,7 +1574,7 @@ class APITest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices")
 
-    mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+    mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
     s1 = sharding.NamedSharding(mesh, P("x"))
     s2 = sharding.NamedSharding(mesh, P("y"))
 

@@ -1596,7 +1596,7 @@ class APITest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices")
 
-    mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
+    mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y"))
     s1 = sharding.NamedSharding(mesh, P("x"))
     s2 = sharding.NamedSharding(mesh, P("y"))
 

@@ -30,13 +30,12 @@ from jax._src.lib import xla_extension_version
 from jax._src.util import prod, safe_zip
 from jax.interpreters import pxla
 from jax.experimental.pjit import pjit
-from jax.experimental import PartitionSpec as P
 from jax.experimental.serialize_executable import (
     compile_and_serialize, load_compiled)
+from jax.sharding import PartitionSpec as P
 from jax._src import sharding
 from jax._src import array
 from jax._src import prng
-from jax.experimental import maps
 
 from jax.config import config
 config.parse_flags_with_absl()

@@ -363,7 +362,7 @@ class JaxArrayTest(jtu.JaxTestCase):
     if jax.device_count() < 3:
       self.skipTest('Requires more than 3 devices')
     shape = (8, 2)
-    mesh = maps.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
+    mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
     # sharding device ids = {1, 2}
     s = sharding.NamedSharding(mesh, P('x'))
     inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)

@@ -506,7 +505,7 @@ class JaxArrayTest(jtu.JaxTestCase):
       self.skipTest('Test requires >= 2 devices.')
 
     single_dev = jax.devices()[1:2]
-    mesh = maps.Mesh(np.array(single_dev), ('x'))
+    mesh = jax.sharding.Mesh(np.array(single_dev), ('x'))
     input_shape = (8, 2)
     arr, input_data = create_array(
         input_shape, sharding.NamedSharding(mesh, P('x')))

@@ -984,7 +983,7 @@ class RngShardingTest(jtu.JaxTestCase):
     def fun(x):
       return x * x
 
-    with maps.Mesh(np.array(jax.devices()), ('data',)):
+    with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
       lowered = pjit(
           fun,
           in_axis_resources=P('data'),

@@ -26,7 +26,6 @@ from jax._src.lib import xla_extension
 from jax.config import config
 from jax.experimental import checkify
 from jax.experimental import pjit
-from jax.experimental import maps
 from jax._src.sharding import NamedSharding
 from jax._src import array
 from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError

@@ -479,13 +478,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
       # binary func
      return x / y
 
-    mesh = maps.Mesh(np.array(jax.devices()), ["dev"])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ["dev"])
     if config.jax_array:
-      ps = NamedSharding(mesh, pjit.PartitionSpec("dev"))
+      ps = NamedSharding(mesh, jax.sharding.PartitionSpec("dev"))
       inp = np.arange(8)
       x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])
     else:
-      ps = pjit.PartitionSpec("dev")
+      ps = jax.sharding.PartitionSpec("dev")
       x = jnp.arange(8)
 
     f = pjit.pjit(f, in_axis_resources=ps, out_axis_resources=ps)

@@ -23,7 +23,7 @@ from unittest import mock, SkipTest
 import warnings
 
 from absl.testing import absltest
-from jax.experimental import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
 from jax.experimental.compilation_cache import compilation_cache as cc
 from jax.experimental.maps import xmap
 from jax.experimental.pjit import pjit

@@ -147,7 +147,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise SkipTest("test requires >=2 devices")
 
-    p = pjit.PartitionSpec('x')
+    p = jax.sharding.PartitionSpec('x')
     f = pjit.pjit(lambda x: 0. / x,
                   in_axis_resources=p,
                   out_axis_resources=p)

@@ -175,7 +175,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
     if jax.device_count() < 2:
       raise SkipTest("test requires >=2 devices")
 
-    p = pjit.PartitionSpec('x')
+    p = jax.sharding.PartitionSpec('x')
     f = pjit.pjit(lambda x: 0. / x,
                   in_axis_resources=p,
                   out_axis_resources=p,

@@ -21,7 +21,6 @@ from typing import IO, Sequence, Tuple
 from absl.testing import absltest
 import jax
 from jax.config import config
-from jax.experimental import maps
 from jax.experimental import pjit
 from jax._src import debugger
 from jax._src import test_util as jtu

@@ -330,9 +329,9 @@ class CliDebuggerTest(jtu.JaxTestCase):
     def g(x):
       y = f(x)
       return jnp.exp(y)
-    g = pjit.pjit(g, in_axis_resources=pjit.PartitionSpec("dev"),
-                  out_axis_resources=pjit.PartitionSpec("dev"))
-    with maps.Mesh(np.array(jax.devices()), ["dev"]):
+    g = pjit.pjit(g, in_axis_resources=jax.sharding.PartitionSpec("dev"),
+                  out_axis_resources=jax.sharding.PartitionSpec("dev"))
+    with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
       arr = (1 + np.arange(8)).astype(np.int32)
       expected = _format_multiline(r"""
       Entering jdb:

@@ -790,13 +790,13 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
     def f(x):
       debug_print("{}", x, ordered=False)
       return x
-    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
     if config.jax_array:
-      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-      out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+      spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+      out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
     else:
-      spec = pjit.PartitionSpec('dev')
-      out_spec = pjit.PartitionSpec()
+      spec = jax.sharding.PartitionSpec('dev')
+      out_spec = jax.sharding.PartitionSpec()
     f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
     with mesh:
       with jtu.capture_stdout() as output:

@@ -809,7 +809,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
       debug_print("{}", y, ordered=False)
       return y
     f2 = pjit.pjit(f2, in_axis_resources=spec, out_axis_resources=out_spec)
-    with maps.Mesh(np.array(jax.devices()), ['dev']):
+    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
       with jtu.capture_stdout() as output:
         f2(np.arange(8, dtype=jnp.int32))
         jax.effects_barrier()

@@ -847,11 +847,11 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
         return (i + 1, x)
       return lax.while_loop(cond, body, (0, x))[1]
 
-    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
     if config.jax_array:
-      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
+      spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
     else:
-      spec = pjit.PartitionSpec('dev')
+      spec = jax.sharding.PartitionSpec('dev')
     f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
     with mesh:
       with jtu.capture_stdout() as output:

@@ -876,13 +876,13 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
       out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
       debug_print("Out: {}", out)
       return out
-    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
     if config.jax_array:
-      in_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-      out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+      in_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+      out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
     else:
-      in_spec = pjit.PartitionSpec('dev')
-      out_spec = pjit.PartitionSpec()
+      in_spec = jax.sharding.PartitionSpec('dev')
+      out_spec = jax.sharding.PartitionSpec()
     f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)
     with mesh:
       with jtu.capture_stdout() as output:

@@ -900,7 +900,7 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
       debug_print("{}", x, ordered=False)
     f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                   axis_resources={'a': 'dev'})
-    with maps.Mesh(np.array(jax.devices()), ['dev']):
+    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
       with jtu.capture_stdout() as output:
         f(np.arange(40))
         jax.effects_barrier()

@@ -991,8 +991,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
     return np.array(devices).reshape(shape)
 
   def test_trivial_sharding(self):
-    mesh = maps.Mesh(self._create_devices(1), ['x'])
-    pspec = pjit.PartitionSpec('x')
+    mesh = jax.sharding.Mesh(self._create_devices(1), ['x'])
+    pspec = jax.sharding.PartitionSpec('x')
     sd = sharding.NamedSharding(mesh, pspec)
     shape = (5,)
     with jtu.capture_stdout() as output:

@@ -1004,8 +1004,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
     """))
 
   def test_trivial_sharding_with_scale(self):
-    mesh = maps.Mesh(self._create_devices(1), ['x'])
-    pspec = pjit.PartitionSpec('x')
+    mesh = jax.sharding.Mesh(self._create_devices(1), ['x'])
+    pspec = jax.sharding.PartitionSpec('x')
     sd = sharding.NamedSharding(mesh, pspec)
     shape = (5,)
     with jtu.capture_stdout() as output:

@@ -1017,8 +1017,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
     """))
 
   def test_full_sharding(self):
-    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
-    pspec = pjit.PartitionSpec('x', 'y')
+    mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+    pspec = jax.sharding.PartitionSpec('x', 'y')
     sd = sharding.NamedSharding(mesh, pspec)
     shape = (8, 8)
     with jtu.capture_stdout() as output:

@@ -1046,9 +1046,9 @@ class VisualizeShardingTest(jtu.JaxTestCase):
 
   def test_sharding_with_replication(self):
     shape = (8, 8)
-    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+    mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
 
-    pspec = pjit.PartitionSpec('x', None)
+    pspec = jax.sharding.PartitionSpec('x', None)
     sd = sharding.NamedSharding(mesh, pspec)
     with jtu.capture_stdout() as output:
       debugging.visualize_sharding(shape, sd)

@@ -1073,8 +1073,8 @@ class VisualizeShardingTest(jtu.JaxTestCase):
     """)
     self.assertEqual(output(), expected)
 
-    mesh = maps.Mesh(self._create_devices((4, 2)), ['x', 'y'])
-    pspec = pjit.PartitionSpec(None, 'y')
+    mesh = jax.sharding.Mesh(self._create_devices((4, 2)), ['x', 'y'])
+    pspec = jax.sharding.PartitionSpec(None, 'y')
     sd = sharding.NamedSharding(mesh, pspec)
     with jtu.capture_stdout() as output:
       debugging.visualize_sharding(shape, sd)

@@ -1095,9 +1095,9 @@ class VisualizeShardingTest(jtu.JaxTestCase):
 
   def test_visualize_wide_array(self):
     shape = (128, 10000)
-    mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])
+    mesh = jax.sharding.Mesh(self._create_devices((8, 4)), ['x', 'y'])
 
-    pspec = pjit.PartitionSpec('x', None)
+    pspec = jax.sharding.PartitionSpec('x', None)
     sd = sharding.NamedSharding(mesh, pspec)
     with jtu.capture_stdout() as output:
       debugging.visualize_sharding(shape, sd)

@@ -1184,13 +1184,13 @@ class InspectShardingTest(jtu.JaxTestCase):
       debugging.inspect_array_sharding(x, callback=_cb)
       return jnp.square(x)
 
-    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
     if config.jax_array:
-      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-      out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+      spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+      out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
     else:
-      spec = pjit.PartitionSpec('dev')
-      out_spec = pjit.PartitionSpec()
+      spec = jax.sharding.PartitionSpec('dev')
+      out_spec = jax.sharding.PartitionSpec()
     f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
     with mesh:
       f(np.arange(8, dtype=jnp.int32))

@@ -23,8 +23,8 @@ from jax import core
 from jax._src import test_util as jtu
 from jax._src.util import prod, safe_zip
 
-from jax.experimental import PartitionSpec as P
-from jax.experimental.maps import Mesh
+from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh
 import jax.experimental.global_device_array as gda_lib
 from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices
 

@@ -31,8 +31,7 @@ from jax import core
 from jax.config import config
 from jax import dtypes
 from jax.experimental import host_callback as hcb
-from jax.experimental import PartitionSpec as P
-from jax.experimental import maps
+from jax.sharding import PartitionSpec as P
 from jax.experimental import pjit
 from jax import lax
 from jax import numpy as jnp

@@ -1720,7 +1719,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
                          in_axis_resources=(P("d"),),
                          out_axis_resources=P("d"))
 
-    with maps.Mesh(devices, ["d"]):
+    with jax.sharding.Mesh(devices, ["d"]):
       # Print the internal IR
       helper_log_ir(
          f"{self._testMethodName}.pjit",

@@ -2338,7 +2337,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
 
     pjit_fun = pjit.pjit(
         fun, in_axis_resources=(P("d"),), out_axis_resources=P("d"))
-    with maps.Mesh(devices, ["d"]):
+    with jax.sharding.Mesh(devices, ["d"]):
       # Print the internal IR
       helper_log_ir(
          f"{self._testMethodName}.pjit",

@@ -272,11 +272,11 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
       effect_p.bind(effect='foo')
       effect_p.bind(effect='bar')
       return x
-    mesh = maps.Mesh(np.array(jax.devices()), ['x'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['x'])
     if config.jax_array:
-      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('x'))
+      spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
     else:
-      spec = pjit.PartitionSpec('x')
+      spec = jax.sharding.PartitionSpec('x')
     f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
     with mesh:
       jaxpr = jax.make_jaxpr(f)(np.arange(jax.local_device_count()))

@@ -33,7 +33,6 @@ import jax.numpy as jnp
 from jax._src import test_util as jtu
 from jax._src import util
 from jax.experimental import global_device_array
-from jax.experimental import maps
 from jax.experimental import pjit
 
 try:

@@ -256,7 +255,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     # 13 15
     assert [d.id for d in device_mesh.flat
            ] == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]
-    return maps.Mesh(device_mesh, ("x", "y"))
+    return jax.sharding.Mesh(device_mesh, ("x", "y"))
 
   def setUp(self):
     super().setUp()

@@ -347,7 +346,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     gda3 = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes3, cb)
 
-    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+    with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
 
       @functools.partial(
          pjit.pjit,

@@ -394,7 +393,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     # The process-gpu mapping is random: @sudhakarsingh27 to figure out why so
     # and the data is:
     # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
-    global_mesh = maps.Mesh(mesh_devices, ("x",))
+    global_mesh = jax.sharding.Mesh(mesh_devices, ("x",))
     global_input_shape = (16,)
     mesh_axes = experimental.PartitionSpec("x")
     global_input_data = np.arange(

@@ -426,7 +425,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
         15: ((slice(15, 16),), 0),
     }
 
-    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+    with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
       f = pjit.pjit(lambda x: x,
                     in_axis_resources=pjit.FROM_GDA,
                     out_axis_resources=mesh_axes)

@@ -37,7 +37,6 @@ from jax import lax
 from jax.lax import with_sharding_constraint
 from jax import prng
 from jax.sharding import PartitionSpec as P
-from jax.experimental import maps
 from jax.experimental.maps import xmap
 from jax.experimental import global_device_array
 from jax.experimental import multihost_utils

@@ -284,7 +283,7 @@ class PJitTest(jtu.BufferDonationTestCase):
       raise unittest.SkipTest(f"Test requires {size} global devices.")
     mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
 
-    @maps.Mesh(mesh_devices, ('x', 'y'))
+    @jax.sharding.Mesh(mesh_devices, ('x', 'y'))
     def dec():
       return pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)(x)
     out = dec()

@@ -529,13 +528,13 @@ class PJitTest(jtu.BufferDonationTestCase):
     if devices.size < 4:
       raise unittest.SkipTest("Test requires 4 devices")
     devices = devices.reshape((2, 2))
-    with maps.Mesh(devices, ('x', 'y')):
+    with jax.sharding.Mesh(devices, ('x', 'y')):
       should_be_tracing = True
       pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
       should_be_tracing = False
      pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
     # Re-create the mesh to make sure that has no influence on caching
-    with maps.Mesh(devices, ('x', 'y')):
+    with jax.sharding.Mesh(devices, ('x', 'y')):
       should_be_tracing = False
       pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
 

@@ -764,7 +763,7 @@ class PJitTest(jtu.BufferDonationTestCase):
         d.transfer_to_infeed(z[3 * didx:3 * didx + 3, :])
         d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5],))
 
-    with maps.Mesh(devices, ['d']):
+    with jax.sharding.Mesh(devices, ['d']):
       logging.info('Making pjit call')
       res = pjit(
           f_for_pjit, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(

@@ -788,7 +787,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
 
     def dispatch():
-      with maps.Mesh(devices, ['d']):
+      with jax.sharding.Mesh(devices, ['d']):
         logging.info('Making pjit call')
         pjit(f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(x)
     execution = threading.Thread(target=dispatch)

@@ -1417,7 +1416,7 @@ class GDAPjitTest(jtu.JaxTestCase):
     # pickling in_axis_resources and sending to other processes). Make sure this
     # this doesn't cause an error to avoid user confusion.
     from_gda_dup = pjit_lib._FromGdaSingleton()
-    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+    with jax.sharding.Mesh(global_mesh.devices, global_mesh.axis_names):
       pjit(lambda x: x, in_axis_resources=from_gda_dup, out_axis_resources=None)(
           input_gda)
 

@@ -3457,7 +3456,7 @@ class PJitErrorTest(jtu.JaxTestCase):
   def testNestedDifferentResources(self):
     @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
     def f(x):
-      with maps.Mesh(np.array([jax.local_devices()[0]]), ('x')):
+      with jax.sharding.Mesh(np.array([jax.local_devices()[0]]), ('x')):
         @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
         def h(x):
           return x

@@ -3801,12 +3800,12 @@ class UtilTest(jtu.JaxTestCase):
                P(('x',), ('y',)))
 
   def test_mesh_with_list_devices(self):
-    mesh = maps.Mesh(jax.devices(), ('x',))
+    mesh = jax.sharding.Mesh(jax.devices(), ('x',))
     self.assertIsInstance(mesh.devices, np.ndarray)
     self.assertEqual(mesh.size, jax.device_count())
 
   def test_mesh_with_string_axis_names(self):
-    mesh = maps.Mesh(jax.devices(), 'dp')
+    mesh = jax.sharding.Mesh(jax.devices(), 'dp')
     self.assertTupleEqual(mesh.axis_names, ('dp',))
 

@@ -33,10 +33,10 @@ from jax.config import config
 from jax.experimental import maps
 from jax.experimental import pjit
 from jax.interpreters import mlir
-from jax.experimental.maps import Mesh
 from jax.experimental.maps import xmap
 from jax.experimental import io_callback
 import jax.numpy as jnp
+from jax.sharding import Mesh
 import numpy as np
 
 

@@ -683,7 +683,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
     try:
       mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
 
-      spec = pjit.PartitionSpec('x')
+      spec = jax.sharding.PartitionSpec('x')
 
       def f(x):
         axis_resources = {v: v for v in mesh.axis_names}

@@ -850,7 +850,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})
-    with maps.Mesh(np.array(jax.devices()), ['dev']):
+    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
       out = f(np.arange(40.))
       np.testing.assert_allclose(out, jnp.arange(1., 41.))
 

@@ -866,7 +866,7 @@ class PurePythonCallbackTest(jtu.JaxTestCase):
 
     f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
                   axis_resources={'a': 'dev'})
-    with maps.Mesh(np.array(jax.devices()), ['dev']):
+    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
       out = f(np.arange(40.))
       np.testing.assert_allclose(out, jnp.arange(1., 41.))
 

@@ -1020,13 +1020,13 @@ class IOPythonCallbackTest(jtu.JaxTestCase):
       io_callback(_cb, None, x)
       return x
 
-    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
+    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
     if config.jax_array:
-      spec = sharding.NamedSharding(mesh, pjit.PartitionSpec('dev'))
-      out_spec = sharding.NamedSharding(mesh, pjit.PartitionSpec())
+      spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
+      out_spec = sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
     else:
-      spec = pjit.PartitionSpec('dev')
-      out_spec = pjit.PartitionSpec()
+      spec = jax.sharding.PartitionSpec('dev')
+      out_spec = jax.sharding.PartitionSpec()
     f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
     with mesh:
       f(jnp.arange(mesh.size))

@@ -38,7 +38,7 @@ from jax.experimental import global_device_array
 from jax._src import array
 from jax._src.sharding import NamedSharding
 from jax.experimental.pjit import pjit, with_sharding_constraint
-from jax.experimental.pjit import PartitionSpec as P
+from jax.sharding import PartitionSpec as P
 from jax.experimental.maps import xmap, serial_loop, SerialLoop
 from jax.errors import JAXTypeError
 from jax._src import config as jax_config

@@ -273,7 +273,7 @@ class XMapTest(XMapTestCase):
     def f(a, b):
       return a * 2, b * 4
     devices = np.array(local_devices[:4]).reshape((2, 2))
-    with maps.Mesh(devices, ('x', 'y')):
+    with jax.sharding.Mesh(devices, ('x', 'y')):
      fm = xmap(f,
                 in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
                 out_axes=({0: 'a', 1: 'b'}, ['c', ...]),

@@ -382,14 +382,14 @@ class XMapTest(XMapTestCase):
     if devices.size < 2:
       raise SkipTest("Test requires 2 devices")
     x = np.arange(8).reshape((2, 2, 2))
-    with maps.Mesh(devices, ('x',)):
+    with jax.sharding.Mesh(devices, ('x',)):
       python_should_be_executing = True
       xmap(f, in_axes=['a', ...], out_axes=['a', ...],
            axis_resources={'a': 'x'})(x)
       python_should_be_executing = False
       xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
-    with maps.Mesh(devices, ('x',)):
+    with jax.sharding.Mesh(devices, ('x',)):
       python_should_be_executing = False
       xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)

@@ -1795,7 +1795,7 @@ class XMapErrorTest(jtu.JaxTestCase):
   def testNestedDifferentResources(self):
     @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
     def f(x):
-      with maps.Mesh(np.empty((), dtype=np.object_), ()):
+      with jax.sharding.Mesh(np.empty((), dtype=np.object_), ()):
        @partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
         def h(x):
           return x
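Many of the hunks above swap `maps.Mesh` for `jax.sharding.Mesh` where it is used as a context manager around a `pjit` call. A minimal sketch of that pattern under the new names (the function, axis name, and array here are illustrative, `in_axis_resources`/`out_axis_resources` reflect the pjit API of this era, and the example assumes the device count divides the array's leading dimension):

```python
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

def f(x):
  return x * 2  # illustrative computation

devices = np.array(jax.devices())
# Mesh is a context manager: pjit calls inside the block pick up this mesh.
with Mesh(devices, ('x',)):
  out = pjit(f, in_axis_resources=P('x'), out_axis_resources=P('x'))(
      np.arange(8, dtype=np.float32))
```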