Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.

PiperOrigin-RevId: 548673905
This commit is contained in:
Yash Katariya 2023-07-17 06:35:13 -07:00 committed by jax authors
parent 603eeb1901
commit f0ce0d8c6a
3 changed files with 9 additions and 49 deletions

View File

@ -19,6 +19,15 @@ Remember to align the itemized text with the first line of an item within a list
parameters listed in either donate_argnums or donate_argnames will
be donated.
* Deletions
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit since
it has been more than 3 months since their deprecation. Please use
`in_shardings` and `out_shardings` as the replacement.
This is a safe and trivial name replacement. It does not change any of the
current pjit semantics and doesn't break any code.
You can still pass in `PartitionSpecs` to in_shardings and out_shardings.
* Deprecations
* Python 3.8 support has been dropped as per
https://jax.readthedocs.io/en/latest/deprecation.html

View File

@ -264,37 +264,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
return cpp_pjitted_f
def _resolve_axis_resources_and_shardings_arg(
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
if (in_shardings is not None and in_axis_resources is not None and
not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
raise ValueError(
'Setting both in_shardings and in_axis_resources is not '
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
if (out_shardings is not None and out_axis_resources is not None and
not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
raise ValueError(
'Setting both out_shardings and out_axis_resources is not '
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
(out_axis_resources is not None and not is_unspecified(out_axis_resources))):
warnings.warn(
'in_axis_resources and out_axis_resources are deprecated. Please use '
'in_shardings and out_shardings as their replacement.',
DeprecationWarning)
if in_axis_resources is not None and not is_unspecified(in_axis_resources):
final_in_shardings = in_axis_resources
else:
final_in_shardings = in_shardings
if out_axis_resources is not None and not is_unspecified(out_axis_resources):
final_out_shardings = out_axis_resources
else:
final_out_shardings = out_shardings
return final_in_shardings, final_out_shardings
def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames,
static_argnums, static_argnames, device,
@ -588,8 +557,6 @@ def pjit(
fun: Callable,
in_shardings=UNSPECIFIED,
out_shardings=UNSPECIFIED,
in_axis_resources=UNSPECIFIED,
out_axis_resources=UNSPECIFIED,
static_argnums: Union[int, Sequence[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int], None] = None,
@ -698,8 +665,6 @@ def pjit(
assignment for function outputs.
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
will use GSPMD's sharding propagation to determine how to shard the outputs.
in_axis_resources: (Deprecated) Please use in_shardings.
out_axis_resources: (Deprecated) Please use out_shardings.
static_argnums: An optional int or collection of ints that specify which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded in
@ -779,9 +744,6 @@ def pjit(
... print(f(x)) # doctest: +SKIP
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
"""
in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg(
in_shardings, out_shardings, in_axis_resources, out_axis_resources)
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pre_infer_params(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,

View File

@ -3010,17 +3010,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
pf1(jnp.arange(8.))
self.assertEqual(count[0], 1)
def test_set_both_axis_resources_and_shardings(self):
with self.assertRaisesRegex(
ValueError,
"Setting both in_shardings and in_axis_resources is not allowed"):
pjit(lambda x: x, in_shardings=P('x'), in_axis_resources=P('x'))
with self.assertRaisesRegex(
ValueError,
"Setting both out_shardings and out_axis_resources is not allowed"):
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
def test_with_sharding_constraint_spmd_axis_name(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
shape = (8, 4, 2, 2)