mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Expose get_abstract_mesh
via the jax.sharding
namespace
PiperOrigin-RevId: 736972976
This commit is contained in:
parent
a11d8891ce
commit
aa9480a441
File diff suppressed because it is too large
Load Diff
@ -59,9 +59,8 @@ outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
|
||||
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
|
||||
from jax.experimental.shard import reshard, auto_axes
|
||||
from jax._src.mesh import get_abstract_mesh
|
||||
|
||||
jax.config.update('jax_num_cpu_devices', 8)
|
||||
```
|
||||
|
@ -28,10 +28,11 @@ from jax._src.sharding_impls import (
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
from jax._src.interpreters.pxla import Mesh as Mesh
|
||||
from jax._src.mesh import (
|
||||
Mesh as Mesh,
|
||||
AbstractMesh as AbstractMesh,
|
||||
AxisType as AxisType,
|
||||
get_abstract_mesh as get_abstract_mesh,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user