mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move with_sharding_constraint
out of experimental into jax.lax
namespace.
PiperOrigin-RevId: 494635809
This commit is contained in:
parent
94590e27bc
commit
13c34f9dc5
@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
|
||||
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
|
||||
deprecated and will be removed in 3 months.
|
||||
* `with_sharding_constraint`s new public endpoint is
|
||||
`jax.lax.with_sharding_constraint`.
|
||||
* If using ABSL flags together with `jax.config`, the ABSL flag values are no
|
||||
longer read or written after the JAX configuration options are initially
|
||||
populated from the ABSL flags. This change improves performance of reading
|
||||
|
@ -363,3 +363,5 @@ from jax._src.lax.ann import (
|
||||
)
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
from jax.lax import linalg as linalg
|
||||
|
||||
from jax.experimental.pjit import with_sharding_constraint
|
||||
|
@ -35,6 +35,7 @@ from jax import dtypes
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax import lax
|
||||
from jax.lax import with_sharding_constraint
|
||||
from jax import prng
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.experimental import maps
|
||||
@ -45,8 +46,7 @@ from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax._src import array
|
||||
from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
|
||||
import jax.experimental.pjit as pjit_lib
|
||||
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
|
||||
FROM_GDA, AUTO)
|
||||
from jax.experimental.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib import xla_client as xc, xla_bridge, xla_extension_version
|
||||
@ -454,7 +454,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
def testShardingConstraintPyTree(self):
|
||||
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
|
||||
def f(x):
|
||||
x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
|
||||
x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
|
||||
x = x.copy()
|
||||
x[0]["a"] *= 2
|
||||
return x
|
||||
@ -2432,7 +2432,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
@partial(jax.jit, static_argnums=(0, 1))
|
||||
def sharded_zeros(shape, pspec):
|
||||
out = jnp.zeros(shape, jnp.bfloat16)
|
||||
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
|
||||
out = sharded_zeros((4096, 3072), P('x', 'y'))
|
||||
out_s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -2447,7 +2447,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
@partial(pjit, static_argnums=(0, 1))
|
||||
def sharded_zeros(shape, pspec):
|
||||
out = jnp.zeros(shape, jnp.bfloat16)
|
||||
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
|
||||
out = sharded_zeros((4096, 3072), P('x', 'y'))
|
||||
out_s = NamedSharding(mesh, P('x', 'y'))
|
||||
@ -2461,7 +2461,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def sharded_inp(inp):
|
||||
return pjit_lib.with_sharding_constraint(
|
||||
return jax.lax.with_sharding_constraint(
|
||||
inp, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
|
||||
@ -2477,7 +2477,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
@partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
|
||||
def sharded_zeros(shape, pspec):
|
||||
out = jnp.zeros(shape, jnp.bfloat16)
|
||||
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
|
Loading…
x
Reference in New Issue
Block a user