mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc are the same as within a jit, i.e. it will reshard to the given sharding only if the devices are the same and in the same order. Eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing; it was decided that this is fine for now. jax.device_put should be the API to use for that. PiperOrigin-RevId: 532858670
This commit is contained in:
parent
79be3482ca
commit
f1c2711292
@ -32,6 +32,7 @@ from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import api
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
@ -1841,11 +1842,14 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED,
|
||||
for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
|
||||
return tree_unflatten(tree, outs)
|
||||
|
||||
def _identity_fn(x): return x
|
||||
|
||||
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
|
||||
# TODO(skye): can we also prevent this from being called in other
|
||||
# non-pjit contexts? (e.g. pmap, control flow)
|
||||
raise NotImplementedError(
|
||||
"with_sharding_constraint() should only be called inside pjit()")
|
||||
if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
|
||||
return x
|
||||
# Run a jit here to raise good errors when device assignments don't match.
|
||||
return api.jit(_identity_fn, out_shardings=sharding)(x)
|
||||
|
||||
|
||||
sharding_constraint_p = core.Primitive("sharding_constraint")
|
||||
sharding_constraint_p.def_impl(_sharding_constraint_impl)
|
||||
|
@ -3384,6 +3384,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_wsc_eager(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8)
|
||||
inp = jax.device_put(np_inp, NamedSharding(mesh, P()))
|
||||
out = with_sharding_constraint(inp, NamedSharding(mesh, P('x')))
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
|
||||
for s in out.addressable_shards:
|
||||
self.assertArraysEqual(s.data, np_inp[s.index])
|
||||
|
||||
def test_wsc_eager_no_resharding(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8)
|
||||
inp = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
out = with_sharding_constraint(inp, NamedSharding(mesh, P('x')))
|
||||
self.assertEqual(id(out), id(inp))
|
||||
|
||||
def test_wsc_eager_different_order_devices(self):
|
||||
mesh1 = jtu.create_global_mesh((2,), ('x',))
|
||||
mesh2 = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x')
|
||||
inp = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Received incompatible devices for jitted computation"):
|
||||
with_sharding_constraint(inp, NamedSharding(mesh2, P('x')))
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user