Expose get_abstract_mesh via the jax.sharding namespace

PiperOrigin-RevId: 736972976
This commit is contained in:
Yash Katariya 2025-03-14 13:44:32 -07:00 committed by jax authors
parent a11d8891ce
commit aa9480a441
3 changed files with 700 additions and 701 deletions

File diff suppressed because it is too large Load Diff

View File

@ -59,9 +59,8 @@ outputId: a64bcbcb-27f8-4c57-8931-8091c9bb8ebf
import jax
import numpy as np
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType, set_mesh
from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh
from jax.experimental.shard import reshard, auto_axes
from jax._src.mesh import get_abstract_mesh
jax.config.update('jax_num_cpu_devices', 8)
```

View File

@ -28,10 +28,11 @@ from jax._src.sharding_impls import (
from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
from jax._src.interpreters.pxla import Mesh as Mesh
from jax._src.mesh import (
Mesh as Mesh,
AbstractMesh as AbstractMesh,
AxisType as AxisType,
get_abstract_mesh as get_abstract_mesh,
)
_deprecations = {