Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.

PiperOrigin-RevId: 509275107
This commit is contained in:
Yash Katariya 2023-02-13 10:53:21 -08:00 committed by jax authors
parent c49af18b9b
commit 2fc64bee13
3 changed files with 63 additions and 17 deletions

View File

@ -23,6 +23,11 @@ Remember to align the itemized text with the first line of an item within a list
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
The merge must be disabled via an environment variable since it affects JAX
at import time so it needs to be disabled before jax is imported.
* The `axis_resources` argument of `with_sharding_constraint` is deprecated.
Please use `shardings` instead. No change is needed if you were passing
`axis_resources` positionally. If you were passing it as a keyword argument,
please rename the keyword to `shardings`. `axis_resources` will be removed
3 months after Feb 13, 2023.
## jaxlib 0.4.4
* Breaking changes

View File

@ -1975,37 +1975,65 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
# -------------------- with_sharding_constraint --------------------
def with_sharding_constraint(x, axis_resources):
def _resolve_wsc_args(axis_resources, shardings):
  """Return whichever of `axis_resources` / `shardings` the caller supplied.

  Exactly one of the two arguments must be specified.

  Args:
    axis_resources: value of the deprecated `axis_resources` argument.
    shardings: value of the `shardings` argument.

  Returns:
    `axis_resources` if it was specified, otherwise `shardings`.

  Raises:
    ValueError: if both arguments are specified, or if neither is.
  """
  has_axis_resources = not _is_unspecified(axis_resources)
  has_shardings = not _is_unspecified(shardings)

  if has_axis_resources and has_shardings:
    raise ValueError(
        'Setting both axis_resources and shardings is not '
        'allowed. axis_resources is deprecated. Please use shardings.')
  if not (has_axis_resources or has_shardings):
    raise ValueError(
        'Not specifying shardings to `with_sharding_constraint` is not allowed. '
        'Please specify the shardings argument with a concrete sharding. Note '
        'that axis_resources is deprecated, so use the shardings argument.')

  # Exactly one of the two is set at this point.
  return axis_resources if has_axis_resources else shardings
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
                             shardings=_UNSPECIFIED):
  """Bind a sharding constraint on every leaf of the pytree `x`.

  Args:
    x: pytree of values to constrain; it is flattened and each leaf is bound
      to `sharding_constraint_p` separately.
    axis_resources: deprecated spelling of `shardings`. Still accepted (it is
      the second positional parameter), but must not be combined with
      `shardings`.
    shardings: the sharding specification(s) to apply, matched against the
      tree structure of `x` via `flatten_axes`.

  Returns:
    A pytree with the same structure as `x` holding the results of the
    per-leaf `sharding_constraint_p.bind` calls.

  Raises:
    ValueError: if both `axis_resources` and `shardings` are given, or if
      neither is (raised by `_resolve_wsc_args`).
  """
  final_shardings = _resolve_wsc_args(axis_resources, shardings)
  x_flat, tree = tree_flatten(x)
  user_shardings, _, _ = _prepare_axis_resources(
      final_shardings, "shardings", allow_unconstrained_dims=True)
  del final_shardings

  user_shardings_flat = tuple(
      flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
  del user_shardings

  resource_env = pxla.thread_resources.env
  mesh = resource_env.physical_mesh

  if config.jax_array:
    shardings_flat = [_create_sharding_for_array(mesh, a)
                      for a in user_shardings_flat]
    # Only NamedSharding entries are queried for unconstrained dims; every
    # other sharding gets an empty placeholder.
    unconstrained_dims = [get_unconstrained_dims(s)
                          if isinstance(s, NamedSharding) else {}
                          for s in shardings_flat]
  else:
    shardings_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a)
                      for a in user_shardings_flat]
    # Calculate unconstrained_dims from NamedSharding because that information
    # is lost when converted to OpSharding. Bind unconstrained_dims to
    # with_sharding_constraint primitive.
    unconstrained_dims = [get_unconstrained_dims(s) for s in shardings_flat]
  del user_shardings_flat

  pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
                           allow_uneven_sharding=True)

  outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
                                     resource_env=resource_env,
                                     unconstrained_dims=ud)
          for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
  return tree_unflatten(tree, outs)
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):

View File

@ -3348,6 +3348,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
"Setting both out_shardings and out_axis_resources is not allowed"):
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
def test_set_none_wsc_axis_resources_and_shardings(self):
  """Omitting both `axis_resources` and `shardings` raises a ValueError."""
  # The error is raised while `with_sharding_constraint` resolves its
  # arguments, so call it directly: wrapping the already-evaluated call in
  # `pjit(...)` was dead code that could never be reached.
  with self.assertRaisesRegex(
      ValueError,
      "Not specifying shardings to `with_sharding_constraint` is not allowed."):
    jax.lax.with_sharding_constraint(jnp.arange(8))
def test_set_both_wsc_axis_resources_and_shardings(self):
  """Passing both `axis_resources` and `shardings` raises a ValueError."""
  # The error is raised while `with_sharding_constraint` resolves its
  # arguments, so call it directly: wrapping the already-evaluated call in
  # `pjit(...)` was dead code that could never be reached.
  with self.assertRaisesRegex(
      ValueError,
      "Setting both axis_resources and shardings is not allowed"):
    jax.lax.with_sharding_constraint(
        jnp.arange(8), axis_resources=P('x'), shardings=P('x'))
class TempSharding(Sharding):
@ -3520,7 +3533,7 @@ class PJitErrorTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 2)])
def testConstraintShardsXMapAxis(self):
spec = P('x')
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
f = xmap(lambda x: with_sharding_constraint(x, spec),
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
x = jnp.arange(4).reshape((2, 2))
error = (r"with_sharding_constraint input has an axis resources specification of " +