mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates: * {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh * {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec * jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding. PiperOrigin-RevId: 510027595
This commit is contained in:
parent
1b2a318fd1
commit
0af9fff5ca
@ -479,8 +479,8 @@ We are using `jax.experimental.pjit.pjit` for parallel execution on multiple dev
|
||||
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
|
||||
|
||||
```python
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
from jax.sharding import Mesh, PartitionSpec
|
||||
from jax.experimental.pjit import pjit
|
||||
|
||||
|
||||
mesh = Mesh(jax.local_devices(), ("x",))
|
||||
@ -777,11 +777,12 @@ import jax.numpy as jnp
|
||||
from build import gpu_ops
|
||||
from jax import core, dtypes
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax.experimental.maps import Mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.interpreters import mlir, xla
|
||||
from jax.interpreters.mlir import ir
|
||||
from jax.lib import xla_client
|
||||
from jax.sharding import Mesh, PartitionSpec
|
||||
from jaxlib.mhlo_helpers import custom_call
|
||||
|
||||
|
||||
|
@ -169,7 +169,7 @@
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"import numpy as np\n",
|
||||
"from jax.experimental.maps import Mesh\n",
|
||||
"from jax.sharding import Mesh\n",
|
||||
"\n",
|
||||
"loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],\n",
|
||||
" axis_resources={'batch': 'x'})\n",
|
||||
@ -790,7 +790,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.maps import Mesh\n",
|
||||
"from jax.sharding import Mesh\n",
|
||||
"\n",
|
||||
"local = local_matmul(x, x) # The local function doesn't require the mesh definition\n",
|
||||
"with Mesh(*mesh_def): # Makes the mesh axis names available as resources\n",
|
||||
|
@ -120,7 +120,7 @@ But on a whim we can decide to parallelize over the batch axis:
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.sharding import Mesh
|
||||
|
||||
loss = xmap(named_loss, in_axes=in_axes, out_axes=[...],
|
||||
axis_resources={'batch': 'x'})
|
||||
@ -536,7 +536,7 @@ To introduce the resources in a scope, use the `with Mesh` context manager:
|
||||
```{code-cell} ipython3
|
||||
:id: kYdoeaSS9m9f
|
||||
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.sharding import Mesh
|
||||
|
||||
local = local_matmul(x, x) # The local function doesn't require the mesh definition
|
||||
with Mesh(*mesh_def): # Makes the mesh axis names available as resources
|
||||
|
@ -213,7 +213,7 @@ class custom_partitioning:
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.sharding import Mesh
|
||||
from jax.numpy.fft import fft
|
||||
import regex as re
|
||||
import numpy as np
|
||||
|
@ -29,7 +29,7 @@ from jax.experimental import global_device_array as gda
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax._src import typing
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.sharding import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
|
Loading…
x
Reference in New Issue
Block a user