Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.

The semantics of eager wsc are the same as within a jit, i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing, and it was decided that this is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
This commit is contained in:
Yash Katariya 2023-05-17 11:49:31 -07:00 committed by jax authors
parent 79be3482ca
commit f1c2711292
2 changed files with 33 additions and 4 deletions

View File

@ -32,6 +32,7 @@ from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import api
from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
@ -1841,11 +1842,14 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED,
for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
return tree_unflatten(tree, outs)
def _identity_fn(x): return x
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
# TODO(skye): can we also prevent this from being called in other
# non-pjit contexts? (e.g. pmap, control flow)
raise NotImplementedError(
"with_sharding_constraint() should only be called inside pjit()")
if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim):
return x
# Run a jit here to raise good errors when device assignment don't match.
return api.jit(_identity_fn, out_shardings=sharding)(x)
# Primitive backing with_sharding_constraint; its impl rule lets it run
# eagerly (outside a jit) as well as inside one.
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)

View File

@ -3384,6 +3384,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
self.assertEqual(count[0], 1)
def test_wsc_eager(self):
  """Eager with_sharding_constraint reshards a replicated array to P('x')."""
  mesh = jtu.create_global_mesh((2,), ('x',))
  data = np.arange(8)
  replicated = jax.device_put(data, NamedSharding(mesh, P()))
  result = with_sharding_constraint(replicated, NamedSharding(mesh, P('x')))
  self.assertArraysEqual(result, data)
  self.assertEqual(result.sharding, NamedSharding(mesh, P('x')))
  # Each local shard must hold exactly its slice of the original data.
  for shard in result.addressable_shards:
    self.assertArraysEqual(shard.data, data[shard.index])
def test_wsc_eager_no_resharding(self):
  """If the input already has the target sharding, the same object comes back."""
  mesh = jtu.create_global_mesh((2,), ('x',))
  np_inp = np.arange(8)
  arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
  result = with_sharding_constraint(arr, NamedSharding(mesh, P('x')))
  # Identity check (same as comparing ids): no copy or reshard happened.
  self.assertIs(result, arr)
def test_wsc_eager_different_order_devices(self):
  """Eager wsc errors when the target mesh lists devices in a different order."""
  mesh1 = jtu.create_global_mesh((2,), ('x',))
  flipped = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x')
  arr = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
  with self.assertRaisesRegex(
      ValueError, "Received incompatible devices for jitted computation"):
    with_sharding_constraint(arr, NamedSharding(flipped, P('x')))
class TempSharding(Sharding):