mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
This commit is contained in:
parent
603eeb1901
commit
f0ce0d8c6a
@ -19,6 +19,15 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
parameters listed in either donate_argnums or donate_argnames will
|
||||
be donated.
|
||||
|
||||
* Deletions
|
||||
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit since
|
||||
it has been more than 3 months since their deprecation. Please use
|
||||
`in_shardings` and `out_shardings` as the replacement.
|
||||
This is a safe and trivial name replacement. It does not change any of the
|
||||
current pjit semantics and doesn't break any code.
|
||||
You can still pass in `PartitionSpecs` to in_shardings and out_shardings.
|
||||
|
||||
|
||||
* Deprecations
|
||||
* Python 3.8 support has been dropped as per
|
||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||
|
@ -264,37 +264,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
return cpp_pjitted_f
|
||||
|
||||
|
||||
def _resolve_axis_resources_and_shardings_arg(
|
||||
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
|
||||
if (in_shardings is not None and in_axis_resources is not None and
|
||||
not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
|
||||
raise ValueError(
|
||||
'Setting both in_shardings and in_axis_resources is not '
|
||||
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
|
||||
if (out_shardings is not None and out_axis_resources is not None and
|
||||
not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
|
||||
raise ValueError(
|
||||
'Setting both out_shardings and out_axis_resources is not '
|
||||
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
|
||||
if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
|
||||
(out_axis_resources is not None and not is_unspecified(out_axis_resources))):
|
||||
warnings.warn(
|
||||
'in_axis_resources and out_axis_resources are deprecated. Please use '
|
||||
'in_shardings and out_shardings as their replacement.',
|
||||
DeprecationWarning)
|
||||
|
||||
if in_axis_resources is not None and not is_unspecified(in_axis_resources):
|
||||
final_in_shardings = in_axis_resources
|
||||
else:
|
||||
final_in_shardings = in_shardings
|
||||
|
||||
if out_axis_resources is not None and not is_unspecified(out_axis_resources):
|
||||
final_out_shardings = out_axis_resources
|
||||
else:
|
||||
final_out_shardings = out_shardings
|
||||
return final_in_shardings, final_out_shardings
|
||||
|
||||
|
||||
def pre_infer_params(fun, in_shardings, out_shardings,
|
||||
donate_argnums, donate_argnames,
|
||||
static_argnums, static_argnames, device,
|
||||
@ -588,8 +557,6 @@ def pjit(
|
||||
fun: Callable,
|
||||
in_shardings=UNSPECIFIED,
|
||||
out_shardings=UNSPECIFIED,
|
||||
in_axis_resources=UNSPECIFIED,
|
||||
out_axis_resources=UNSPECIFIED,
|
||||
static_argnums: Union[int, Sequence[int], None] = None,
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
donate_argnums: Union[int, Sequence[int], None] = None,
|
||||
@ -698,8 +665,6 @@ def pjit(
|
||||
assignment for function outputs.
|
||||
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
||||
will use GSPMD's sharding propagation to determine how to shard the outputs.
|
||||
in_axis_resources: (Deprecated) Please use in_shardings.
|
||||
out_axis_resources: (Deprecated) Please use out_shardings.
|
||||
static_argnums: An optional int or collection of ints that specify which
|
||||
positional arguments to treat as static (compile-time constant).
|
||||
Operations that only depend on static arguments will be constant-folded in
|
||||
@ -779,9 +744,6 @@ def pjit(
|
||||
... print(f(x)) # doctest: +SKIP
|
||||
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
||||
"""
|
||||
in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg(
|
||||
in_shardings, out_shardings, in_axis_resources, out_axis_resources)
|
||||
|
||||
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
|
||||
static_argnames) = pre_infer_params(
|
||||
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||
|
@ -3010,17 +3010,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
pf1(jnp.arange(8.))
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
def test_set_both_axis_resources_and_shardings(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Setting both in_shardings and in_axis_resources is not allowed"):
|
||||
pjit(lambda x: x, in_shardings=P('x'), in_axis_resources=P('x'))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Setting both out_shardings and out_axis_resources is not allowed"):
|
||||
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
|
||||
|
||||
def test_with_sharding_constraint_spmd_axis_name(self):
|
||||
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
||||
shape = (8, 4, 2, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user