diff --git a/CHANGELOG.md b/CHANGELOG.md index 78e893dd5..ea901c9b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months. + * `with_sharding_constraint`s new public endpoint is + `jax.lax.with_sharding_constraint`. * If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 084932263..5b42e644d 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -363,3 +363,5 @@ from jax._src.lax.ann import ( ) from jax._src.ad_util import stop_gradient_p as stop_gradient_p from jax.lax import linalg as linalg + +from jax.experimental.pjit import with_sharding_constraint diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 040936246..3f34b7ce1 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -35,6 +35,7 @@ from jax import dtypes from jax import stages from jax.errors import JAXTypeError from jax import lax +from jax.lax import with_sharding_constraint from jax import prng from jax.sharding import PartitionSpec as P from jax.experimental import maps @@ -45,8 +46,7 @@ from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding import jax.experimental.pjit as pjit_lib -from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint, - FROM_GDA, AUTO) +from jax.experimental.pjit import (pjit, pjit_p, FROM_GDA, AUTO) from jax.interpreters import pxla from jax.interpreters import mlir from jax._src.lib import xla_client as xc, xla_bridge, xla_extension_version @@ -454,7 +454,7 @@ class PJitTest(jtu.BufferDonationTestCase): def testShardingConstraintPyTree(self): @partial(pjit, in_axis_resources=None, out_axis_resources=None) def f(x): - x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')]) + x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')]) x = x.copy() x[0]["a"] *= 2 return x @@ -2432,7 +2432,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @partial(jax.jit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): out = jnp.zeros(shape, jnp.bfloat16) - return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec)) + return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) out = sharded_zeros((4096, 3072), P('x', 'y')) out_s = NamedSharding(mesh, P('x', 'y')) @@ -2447,7 +2447,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @partial(pjit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): out = jnp.zeros(shape, jnp.bfloat16) - return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec)) + return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) out = sharded_zeros((4096, 3072), P('x', 'y')) out_s = NamedSharding(mesh, P('x', 'y')) @@ -2461,7 +2461,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax.jit def sharded_inp(inp): - return pjit_lib.with_sharding_constraint( + return jax.lax.with_sharding_constraint( inp, NamedSharding(mesh, P('x', 'y'))) committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0]) @@ -2477,7 +2477,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0]) def sharded_zeros(shape, pspec): out = jnp.zeros(shape, jnp.bfloat16) - return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec)) + return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) with self.assertRaisesRegex( ValueError,