Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.

This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding

PiperOrigin-RevId: 510027595
Peter Hawkins, 2023-02-15 21:02:22 -08:00 (committed by jax authors)
parent 1b2a318fd1
commit 0af9fff5ca
5 changed files with 11 additions and 10 deletions
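For illustration only (not part of this commit), here is a minimal sketch of the canonical `jax.sharding` API that the renames above point to; the mesh axis name `"x"` and the example array are assumptions, not code from the diff:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1D mesh over the locally visible devices; the axis name "x" is arbitrary.
mesh = Mesh(jax.local_devices(), ("x",))

# Shard an array along the mesh axis. The leading dimension must be
# divisible by the number of devices in the mesh.
sharding = NamedSharding(mesh, PartitionSpec("x"))
x = jax.device_put(jnp.arange(8.0), sharding)
```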


@@ -479,8 +479,8 @@ We are using `jax.experimental.pjit.pjit` for parallel execution on multiple dev
 Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
 ```python
-from jax.experimental.maps import Mesh
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.sharding import Mesh, PartitionSpec
+from jax.experimental.pjit import pjit
 mesh = Mesh(jax.local_devices(), ("x",))
@@ -777,11 +777,12 @@ import jax.numpy as jnp
 from build import gpu_ops
 from jax import core, dtypes
 from jax.abstract_arrays import ShapedArray
-from jax.experimental.maps import Mesh, xmap
-from jax.experimental.pjit import PartitionSpec, pjit
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import pjit
 from jax.interpreters import mlir, xla
 from jax.interpreters.mlir import ir
 from jax.lib import xla_client
+from jax.sharding import Mesh, PartitionSpec
 from jaxlib.mhlo_helpers import custom_call


@@ -169,7 +169,7 @@
 "source": [
 "import jax\n",
 "import numpy as np\n",
-"from jax.experimental.maps import Mesh\n",
+"from jax.sharding import Mesh\n",
 "\n",
 "loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
 "            axis_resources={'batch': 'x'})\n",
@@ -790,7 +790,7 @@
 },
 "outputs": [],
 "source": [
-"from jax.experimental.maps import Mesh\n",
+"from jax.sharding import Mesh\n",
 "\n",
 "local = local_matmul(x, x)  # The local function doesn't require the mesh definition\n",
 "with Mesh(*mesh_def):  # Makes the mesh axis names available as resources\n",


@@ -120,7 +120,7 @@ But on a whim we can decide to parallelize over the batch axis:
 import jax
 import numpy as np
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
             axis_resources={'batch': 'x'})
@@ -536,7 +536,7 @@ To introduce the resources in a scope, use the `with Mesh` context manager:
 ```{code-cell} ipython3
 :id: kYdoeaSS9m9f
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 local = local_matmul(x, x)  # The local function doesn't require the mesh definition
 with Mesh(*mesh_def):  # Makes the mesh axis names available as resources
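The two tutorial files above pair `xmap` with a `Mesh` context manager. As a hedged sketch of that pattern using the new import (assuming a JAX version of this era in which `jax.experimental.maps.xmap` still exists; `mesh_def` and the doubling function are illustrative stand-ins for the tutorial's own definitions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import xmap
from jax.sharding import Mesh  # new canonical location, per this commit

# Map the function's named "batch" axis onto the mesh resource axis "x".
f = xmap(lambda x: x * 2,
         in_axes=['batch', ...], out_axes=['batch', ...],
         axis_resources={'batch': 'x'})

mesh_def = (np.array(jax.local_devices()), ('x',))
with Mesh(*mesh_def):  # makes the mesh axis name "x" available as a resource
    # The batch size (4) must be divisible by the number of devices on "x".
    y = f(jnp.arange(8.0).reshape(4, 2))
```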


@@ -213,7 +213,7 @@ class custom_partitioning:
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 from jax.numpy.fft import fft
 import regex as re
 import numpy as np


@@ -29,7 +29,7 @@ from jax.experimental import global_device_array as gda
 from jax._src import array
 from jax._src import sharding
 from jax._src import typing
-from jax.experimental.maps import Mesh
+from jax.sharding import Mesh
 import jax.numpy as jnp
 import numpy as np
 import tensorstore as ts
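Taken together, the hunks above only move import paths; runtime behavior is unchanged. For completeness, a hedged end-to-end sketch of the `pjit` pattern migrated in the first file (the `in_axis_resources`/`out_axis_resources` keywords of this era and the trivial computation are assumptions for illustration, not code from this commit):

```python
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec  # new canonical locations

mesh = Mesh(jax.local_devices(), ("x",))

# Shard input and output along the mesh axis "x"; the computation is a
# placeholder. The input's leading dimension must be divisible by the
# number of devices on "x".
f = pjit(lambda a: a * 2,
         in_axis_resources=PartitionSpec("x"),
         out_axis_resources=PartitionSpec("x"))

with mesh:  # pjit needs an active mesh to resolve the axis name "x"
    y = f(jnp.arange(8.0))
```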