mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
alias jax.sharding.NamedSharding -> jax.NamedSharding
This commit is contained in:
parent
4f9cb47f24
commit
8588d4b747
@ -125,6 +125,7 @@ from jax._src.api import value_and_grad as value_and_grad
|
||||
from jax._src.api import vjp as vjp
|
||||
from jax._src.api import vmap as vmap
|
||||
from jax._src.api import xla_computation as xla_computation
|
||||
from jax._src.sharding_impls import NamedSharding as NamedSharding
|
||||
|
||||
# Force import, allowing jax.interpreters.* to be used after import jax.
|
||||
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
|
||||
|
@ -1706,7 +1706,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_device_put_sharding(self):
|
||||
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
s = jax.NamedSharding(mesh, P('x'))
|
||||
x = jnp.arange(len(jax.devices()))
|
||||
|
||||
y = jax.device_put(x, s)
|
||||
@ -1732,9 +1732,9 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]).reshape((2, 1)),
|
||||
("x", "y"))
|
||||
s1 = jax.sharding.NamedSharding(mesh, P("x"))
|
||||
s2 = jax.sharding.NamedSharding(mesh, P("y"))
|
||||
s3 = jax.sharding.NamedSharding(mesh, P("x", "y"))
|
||||
s1 = jax.NamedSharding(mesh, P("x"))
|
||||
s2 = jax.NamedSharding(mesh, P("y"))
|
||||
s3 = jax.NamedSharding(mesh, P("x", "y"))
|
||||
|
||||
x = jnp.arange(2)
|
||||
y = jnp.arange(2) + 10
|
||||
|
Loading…
x
Reference in New Issue
Block a user