Move with_sharding_constraint out of experimental into jax.lax namespace.

PiperOrigin-RevId: 494635809
This commit is contained in:
Yash Katariya 2022-12-11 22:54:39 -08:00 committed by jax authors
parent 94590e27bc
commit 13c34f9dc5
3 changed files with 11 additions and 7 deletions

View File

@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months.
* `with_sharding_constraint`s new public endpoint is
`jax.lax.with_sharding_constraint`.
* If using ABSL flags together with `jax.config`, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading

View File

@ -363,3 +363,5 @@ from jax._src.lax.ann import (
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from jax.lax import linalg as linalg
from jax.experimental.pjit import with_sharding_constraint

View File

@ -35,6 +35,7 @@ from jax import dtypes
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
from jax.lax import with_sharding_constraint
from jax import prng
from jax.sharding import PartitionSpec as P
from jax.experimental import maps
@ -45,8 +46,7 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
import jax.experimental.pjit as pjit_lib
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
FROM_GDA, AUTO)
from jax.experimental.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
from jax.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib import xla_client as xc, xla_bridge, xla_extension_version
@ -454,7 +454,7 @@ class PJitTest(jtu.BufferDonationTestCase):
def testShardingConstraintPyTree(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(x):
x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
x = x.copy()
x[0]["a"] *= 2
return x
@ -2432,7 +2432,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@partial(jax.jit, static_argnums=(0, 1))
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
out = sharded_zeros((4096, 3072), P('x', 'y'))
out_s = NamedSharding(mesh, P('x', 'y'))
@ -2447,7 +2447,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@partial(pjit, static_argnums=(0, 1))
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
out = sharded_zeros((4096, 3072), P('x', 'y'))
out_s = NamedSharding(mesh, P('x', 'y'))
@ -2461,7 +2461,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax.jit
def sharded_inp(inp):
return pjit_lib.with_sharding_constraint(
return jax.lax.with_sharding_constraint(
inp, NamedSharding(mesh, P('x', 'y')))
committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
@ -2477,7 +2477,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
with self.assertRaisesRegex(
ValueError,