mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change the `axis_resources` argument of `with_sharding_constraint` to `shardings` to match `pjit` and `jit`.
PiperOrigin-RevId: 509275107
This commit is contained in:
parent
c49af18b9b
commit
2fc64bee13
@ -23,6 +23,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
|
||||
The merge must be disabled via an environment variable since it affects JAX
|
||||
at import time so it needs to be disabled before jax is imported.
|
||||
* `axis_resources` argument of `with_sharding_constraint` is deprecated.
|
||||
Please use `shardings` instead. There is no change needed if you were using
|
||||
`axis_resources` as a positional arg. If you were using it as a kwarg, then please
|
||||
use `shardings` instead. `axis_resources` will be removed after 3 months
|
||||
from Feb 13, 2023.
|
||||
|
||||
## jaxlib 0.4.4
|
||||
* Breaking changes
|
||||
|
@ -1975,37 +1975,65 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
|
||||
# -------------------- with_sharding_constraint --------------------
|
||||
|
||||
def with_sharding_constraint(x, axis_resources):
|
||||
def _resolve_wsc_args(axis_resources, shardings):
  """Return whichever of `axis_resources` / `shardings` was specified.

  Exactly one of the two arguments must be specified, as judged by
  `_is_unspecified`. `axis_resources` is the deprecated spelling of
  `shardings`.

  Args:
    axis_resources: Deprecated argument value, or an unspecified marker.
    shardings: Replacement argument value, or an unspecified marker.

  Returns:
    `axis_resources` if it was specified, otherwise `shardings`.

  Raises:
    ValueError: If both arguments are specified, or if neither is.
  """
  has_axis_resources = not _is_unspecified(axis_resources)
  has_shardings = not _is_unspecified(shardings)

  if has_axis_resources and has_shardings:
    raise ValueError(
        'Setting both axis_resources and shardings is not '
        'allowed. axis_resources is deprecated. Please use shardings.')
  if not (has_axis_resources or has_shardings):
    raise ValueError(
        'Not specifying shardings to `with_sharding_constraint` is not allowed. '
        'Please specify the shardings argument with a concrete sharding. Note '
        'that axis_resources is deprecated, so use the shardings argument.')

  # Exactly one of the two is specified at this point.
  return axis_resources if has_axis_resources else shardings
|
||||
|
||||
|
||||
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
                             shardings=_UNSPECIFIED):
  """Bind a sharding constraint onto every leaf of the pytree `x`.

  Args:
    x: Pytree of values. It is flattened with `tree_flatten` and the result is
      rebuilt with the same tree structure.
    axis_resources: Deprecated spelling of `shardings`. Only one of
      `axis_resources` and `shardings` may be specified.
    shardings: The sharding(s) to constrain `x` to. They are matched against
      the tree structure of `x` via `flatten_axes`.

  Returns:
    A pytree with the same structure as `x` whose leaves are the outputs of
    `sharding_constraint_p.bind` for the corresponding input leaves.

  Raises:
    ValueError: If both or neither of `axis_resources` and `shardings` are
      specified (raised by `_resolve_wsc_args`).
  """
  final_shardings = _resolve_wsc_args(axis_resources, shardings)
  x_flat, tree = tree_flatten(x)
  user_shardings, _, _ = _prepare_axis_resources(
      final_shardings, "shardings", allow_unconstrained_dims=True)
  # The intermediate names are deleted as soon as they have been consumed,
  # presumably so that later code cannot pick up an unprocessed value.
  del final_shardings

  user_shardings_flat = tuple(
      flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
  del user_shardings

  resource_env = pxla.thread_resources.env
  mesh = resource_env.physical_mesh

  if config.jax_array:
    shardings_flat = [_create_sharding_for_array(mesh, a)
                      for a in user_shardings_flat]
    # Only NamedSharding entries are queried for unconstrained dims; every
    # other sharding gets an empty placeholder.
    unconstrained_dims = [get_unconstrained_dims(s)
                          if isinstance(s, NamedSharding) else {}
                          for s in shardings_flat]
  else:
    shardings_flat = [pxla.create_mesh_pspec_sharding(mesh, a.user_spec, a)
                      for a in user_shardings_flat]
    # Calculate unconstrained_dims from NamedSharding because that information
    # is lost when converted to OpSharding. Bind unconstrained_dims to
    # with_sharding_constraint primitive.
    unconstrained_dims = [get_unconstrained_dims(s) for s in shardings_flat]

  del user_shardings_flat

  pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
                           allow_uneven_sharding=True)

  # One bind per leaf; safe_zip pairs each leaf with its sharding and its
  # unconstrained dims.
  outs = [sharding_constraint_p.bind(xf, sharding=to_op_sharding_sharding(i, xf.ndim),
                                     resource_env=resource_env,
                                     unconstrained_dims=ud)
          for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
  return tree_unflatten(tree, outs)
|
||||
|
||||
def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
|
||||
|
@ -3348,6 +3348,19 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
"Setting both out_shardings and out_axis_resources is not allowed"):
|
||||
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
|
||||
|
||||
def test_set_none_wsc_axis_resources_and_shardings(self):
  """`with_sharding_constraint` without any sharding argument must raise."""
  with self.assertRaisesRegex(
      ValueError,
      "Not specifying shardings to `with_sharding_constraint` is not allowed."):
    # The error is raised eagerly by the call itself, so no pjit wrapper is
    # needed around it.
    jax.lax.with_sharding_constraint(jnp.arange(8))
|
||||
|
||||
def test_set_both_wsc_axis_resources_and_shardings(self):
  """Passing both `axis_resources` and `shardings` must raise."""
  with self.assertRaisesRegex(
      ValueError,
      "Setting both axis_resources and shardings is not allowed"):
    # The error is raised eagerly by the call itself, so no pjit wrapper is
    # needed around it.
    jax.lax.with_sharding_constraint(
        jnp.arange(8), axis_resources=P('x'), shardings=P('x'))
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
@ -3520,7 +3533,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testConstraintShardsXMapAxis(self):
|
||||
spec = P('x')
|
||||
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
|
||||
f = xmap(lambda x: with_sharding_constraint(x, spec),
|
||||
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
|
||||
x = jnp.arange(4).reshape((2, 2))
|
||||
error = (r"with_sharding_constraint input has an axis resources specification of " +
|
||||
|
Loading…
x
Reference in New Issue
Block a user