Move with_sharding_constraint out of experimental into jax.lax namespace.

PiperOrigin-RevId: 494635809
This commit is contained in:
Yash Katariya 2022-12-11 22:54:39 -08:00 committed by jax authors
parent 94590e27bc
commit 13c34f9dc5
3 changed files with 11 additions and 7 deletions

View File

@@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
   are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
   `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
   deprecated and will be removed in 3 months.
+* `with_sharding_constraint`s new public endpoint is
+  `jax.lax.with_sharding_constraint`.
 * If using ABSL flags together with `jax.config`, the ABSL flag values are no
   longer read or written after the JAX configuration options are initially
   populated from the ABSL flags. This change improves performance of reading

View File

@@ -363,3 +363,5 @@ from jax._src.lax.ann import (
 )
 from jax._src.ad_util import stop_gradient_p as stop_gradient_p
 from jax.lax import linalg as linalg
+from jax.experimental.pjit import with_sharding_constraint

View File

@@ -35,6 +35,7 @@ from jax import dtypes
 from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
+from jax.lax import with_sharding_constraint
 from jax import prng
 from jax.sharding import PartitionSpec as P
 from jax.experimental import maps
@@ -45,8 +46,7 @@ from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
 from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
 import jax.experimental.pjit as pjit_lib
-from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
-                                   FROM_GDA, AUTO)
+from jax.experimental.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
 from jax.interpreters import pxla
 from jax.interpreters import mlir
 from jax._src.lib import xla_client as xc, xla_bridge, xla_extension_version
@@ -454,7 +454,7 @@ class PJitTest(jtu.BufferDonationTestCase):
   def testShardingConstraintPyTree(self):
     @partial(pjit, in_axis_resources=None, out_axis_resources=None)
     def f(x):
-      x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
+      x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
       x = x.copy()
       x[0]["a"] *= 2
       return x
@@ -2432,7 +2432,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     @partial(jax.jit, static_argnums=(0, 1))
     def sharded_zeros(shape, pspec):
       out = jnp.zeros(shape, jnp.bfloat16)
-      return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

     out = sharded_zeros((4096, 3072), P('x', 'y'))
     out_s = NamedSharding(mesh, P('x', 'y'))
@@ -2447,7 +2447,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     @partial(pjit, static_argnums=(0, 1))
     def sharded_zeros(shape, pspec):
       out = jnp.zeros(shape, jnp.bfloat16)
-      return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

     out = sharded_zeros((4096, 3072), P('x', 'y'))
     out_s = NamedSharding(mesh, P('x', 'y'))
@@ -2461,7 +2461,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     @jax.jit
     def sharded_inp(inp):
-      return pjit_lib.with_sharding_constraint(
+      return jax.lax.with_sharding_constraint(
           inp, NamedSharding(mesh, P('x', 'y')))

     committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
@@ -2477,7 +2477,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
     def sharded_zeros(shape, pspec):
       out = jnp.zeros(shape, jnp.bfloat16)
-      return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+      return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

     with self.assertRaisesRegex(
         ValueError,