Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.

PiperOrigin-RevId: 509275107
This commit is contained in:
Yash Katariya 2023-02-13 10:53:21 -08:00 committed by jax authors
parent c49af18b9b
commit 2fc64bee13
3 changed files with 63 additions and 17 deletions

View File

@ -23,6 +23,11 @@ Remember to align the itemized text with the first line of an item within a list
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
The merge must be disabled via an environment variable since it affects JAX The merge must be disabled via an environment variable since it affects JAX
at import time so it needs to be disabled before jax is imported. at import time so it needs to be disabled before jax is imported.
* `axis_resources` argument of `with_sharding_constraint` is deprecated.
Please use `shardings` instead. There is no change needed if you were using
`axis_resources` as an arg. If you were using it as a kwarg, then please
use `shardings` instead. `axis_resources` will be removed after 3 months
from Feb 13, 2023.
## jaxlib 0.4.4 ## jaxlib 0.4.4
* Breaking changes * Breaking changes

View File

@ -1975,37 +1975,65 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
# -------------------- with_sharding_constraint -------------------- # -------------------- with_sharding_constraint --------------------
def with_sharding_constraint(x, axis_resources): def _resolve_wsc_args(axis_resources, shardings):
if not _is_unspecified(axis_resources) and not _is_unspecified(shardings):
raise ValueError(
'Setting both axis_resources and shardings is not '
'allowed. axis_resources is deprecated. Please use shardings.')
if _is_unspecified(axis_resources) and _is_unspecified(shardings):
raise ValueError(
'Not specifying shardings to `with_sharding_constraint` is not allowed. '
'Please specify the shardings argument with a concrete sharding. Note '
'that axis_resources is deprecated, so use the shardings argument.')
if not _is_unspecified(axis_resources):
final_shardings = axis_resources
else:
final_shardings = shardings
return final_shardings
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
shardings=_UNSPECIFIED):
final_shardings = _resolve_wsc_args(axis_resources, shardings)
x_flat, tree = tree_flatten(x) x_flat, tree = tree_flatten(x)
axis_resources, _, _ = _prepare_axis_resources( user_shardings, _, _ = _prepare_axis_resources(
axis_resources, "axis_resources", allow_unconstrained_dims=True) final_shardings, "shardings", allow_unconstrained_dims=True)
axis_resources_flat = tuple( del final_shardings
flatten_axes("with_sharding_constraint sharding", tree, axis_resources))
user_shardings_flat = tuple(
flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
del user_shardings
resource_env = pxla.thread_resources.env resource_env = pxla.thread_resources.env
mesh = resource_env.physical_mesh mesh = resource_env.physical_mesh
if config.jax_array: if config.jax_array:
sharding_flat = [_create_sharding_for_array(mesh, a) shardings_flat = [_create_sharding_for_array(mesh, a)
for a in axis_resources_flat] for a in user_shardings_flat]
unconstrained_dims = [ unconstrained_dims = [get_unconstrained_dims(s)
get_unconstrained_dims(s) if isinstance(s, NamedSharding) else {} if isinstance(s, NamedSharding) else {}
for s in sharding_flat for s in shardings_flat]
]
else: else:
sharding_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a) shardings_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a)
for a in axis_resources_flat] for a in user_shardings_flat]
# Calculate unconstrained_dims from NamedSharding because that information # Calculate unconstrained_dims from NamedSharding because that information
# is lost when converted to OpSharding. Bind unconstrained_dims to # is lost when converted to OpSharding. Bind unconstrained_dims to
# with_sharding_constraint primitive. # with_sharding_constraint primitive.
unconstrained_dims = [get_unconstrained_dims(s) for s in sharding_flat] unconstrained_dims = [get_unconstrained_dims(s) for s in shardings_flat]
pjit_check_aval_sharding(sharding_flat, x_flat, "with_sharding_constraint arguments", del user_shardings_flat
pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
allow_uneven_sharding=True) allow_uneven_sharding=True)
outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim), outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
resource_env=resource_env, resource_env=resource_env,
unconstrained_dims=ud) unconstrained_dims=ud)
for xf, i, ud in safe_zip(x_flat, sharding_flat, unconstrained_dims)] for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
return tree_unflatten(tree, outs) return tree_unflatten(tree, outs)
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):

View File

@ -3348,6 +3348,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
"Setting both out_shardings and out_axis_resources is not allowed"): "Setting both out_shardings and out_axis_resources is not allowed"):
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x')) pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
def test_set_none_wsc_axis_resources_and_shardings(self):
with self.assertRaisesRegex(
ValueError,
"Not specifying shardings to `with_sharding_constraint` is not allowed."):
pjit(jax.lax.with_sharding_constraint(jnp.arange(8)))
def test_set_both_wsc_axis_resources_and_shardings(self):
with self.assertRaisesRegex(
ValueError,
"Setting both axis_resources and shardings is not allowed"):
pjit(jax.lax.with_sharding_constraint(
jnp.arange(8), axis_resources=P('x'), shardings=P('x')))
class TempSharding(Sharding): class TempSharding(Sharding):
@ -3520,7 +3533,7 @@ class PJitErrorTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 2)]) @jtu.with_mesh([('x', 2)])
def testConstraintShardsXMapAxis(self): def testConstraintShardsXMapAxis(self):
spec = P('x') spec = P('x')
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec), f = xmap(lambda x: with_sharding_constraint(x, spec),
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'}) in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
x = jnp.arange(4).reshape((2, 2)) x = jnp.arange(4).reshape((2, 2))
error = (r"with_sharding_constraint input has an axis resources specification of " + error = (r"with_sharding_constraint input has an axis resources specification of " +