mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Expose get_abstract_mesh
via the jax.sharding
namespace
PiperOrigin-RevId: 736972976
This commit is contained in:
parent
a11d8891ce
commit
aa9480a441
@ -62,9 +62,8 @@
|
||||
"import jax\n",
|
||||
"import numpy as np\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from jax.sharding import PartitionSpec as P, AxisType, set_mesh\n",
|
||||
"from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n",
|
||||
"from jax.experimental.shard import reshard, auto_axes\n",
|
||||
"from jax._src.mesh import get_abstract_mesh\n",
|
||||
"\n",
|
||||
"jax.config.update('jax_num_cpu_devices', 8)"
|
||||
]
|
||||
|
@ -59,9 +59,8 @@ outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
|
||||
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
|
||||
from jax.experimental.shard import reshard, auto_axes
|
||||
from jax._src.mesh import get_abstract_mesh
|
||||
|
||||
jax.config.update('jax_num_cpu_devices', 8)
|
||||
```
|
||||
|
@ -28,10 +28,11 @@ from jax._src.sharding_impls import (
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
from jax._src.interpreters.pxla import Mesh as Mesh
|
||||
from jax._src.mesh import (
|
||||
Mesh as Mesh,
|
||||
AbstractMesh as AbstractMesh,
|
||||
AxisType as AxisType,
|
||||
get_abstract_mesh as get_abstract_mesh,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user