alias jax.sharding.NamedSharding -> jax.NamedSharding

This commit is contained in:
Matthew Johnson 2024-04-11 16:23:59 -07:00
parent 4f9cb47f24
commit 8588d4b747
2 changed files with 5 additions and 4 deletions

View File

@ -125,6 +125,7 @@ from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation
from jax._src.sharding_impls import NamedSharding as NamedSharding
# Force import, allowing jax.interpreters.* to be used after import jax.
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla

View File

@ -1706,7 +1706,7 @@ class APITest(jtu.JaxTestCase):
def test_device_put_sharding(self):
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
s = jax.sharding.NamedSharding(mesh, P('x'))
s = jax.NamedSharding(mesh, P('x'))
x = jnp.arange(len(jax.devices()))
y = jax.device_put(x, s)
@ -1732,9 +1732,9 @@ class APITest(jtu.JaxTestCase):
mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)),
("x", "y"))
s1 = jax.sharding.NamedSharding(mesh, P("x"))
s2 = jax.sharding.NamedSharding(mesh, P("y"))
s3 = jax.sharding.NamedSharding(mesh, P("x", "y"))
s1 = jax.NamedSharding(mesh, P("x"))
s2 = jax.NamedSharding(mesh, P("y"))
s3 = jax.NamedSharding(mesh, P("x", "y"))
x = jnp.arange(2)
y = jnp.arange(2) + 10