mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
This commit is contained in:
parent
603eeb1901
commit
f0ce0d8c6a
@ -19,6 +19,15 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
parameters listed in either donate_argnums or donate_argnames will
|
parameters listed in either donate_argnums or donate_argnames will
|
||||||
be donated.
|
be donated.
|
||||||
|
|
||||||
|
* Deletions
|
||||||
|
* `in_axis_resources` and `out_axis_resources` have been deleted from pjit since
|
||||||
|
it has been more than 3 months since their deprecation. Please use
|
||||||
|
`in_shardings` and `out_shardings` as the replacement.
|
||||||
|
This is a safe and trivial name replacement. It does not change any of the
|
||||||
|
current pjit semantics and doesn't break any code.
|
||||||
|
You can still pass in `PartitionSpecs` to in_shardings and out_shardings.
|
||||||
|
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* Python 3.8 support has been dropped as per
|
* Python 3.8 support has been dropped as per
|
||||||
https://jax.readthedocs.io/en/latest/deprecation.html
|
https://jax.readthedocs.io/en/latest/deprecation.html
|
||||||
|
@ -264,37 +264,6 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
|||||||
return cpp_pjitted_f
|
return cpp_pjitted_f
|
||||||
|
|
||||||
|
|
||||||
def _resolve_axis_resources_and_shardings_arg(
|
|
||||||
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
|
|
||||||
if (in_shardings is not None and in_axis_resources is not None and
|
|
||||||
not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
|
|
||||||
raise ValueError(
|
|
||||||
'Setting both in_shardings and in_axis_resources is not '
|
|
||||||
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
|
|
||||||
if (out_shardings is not None and out_axis_resources is not None and
|
|
||||||
not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
|
|
||||||
raise ValueError(
|
|
||||||
'Setting both out_shardings and out_axis_resources is not '
|
|
||||||
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
|
|
||||||
if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
|
|
||||||
(out_axis_resources is not None and not is_unspecified(out_axis_resources))):
|
|
||||||
warnings.warn(
|
|
||||||
'in_axis_resources and out_axis_resources are deprecated. Please use '
|
|
||||||
'in_shardings and out_shardings as their replacement.',
|
|
||||||
DeprecationWarning)
|
|
||||||
|
|
||||||
if in_axis_resources is not None and not is_unspecified(in_axis_resources):
|
|
||||||
final_in_shardings = in_axis_resources
|
|
||||||
else:
|
|
||||||
final_in_shardings = in_shardings
|
|
||||||
|
|
||||||
if out_axis_resources is not None and not is_unspecified(out_axis_resources):
|
|
||||||
final_out_shardings = out_axis_resources
|
|
||||||
else:
|
|
||||||
final_out_shardings = out_shardings
|
|
||||||
return final_in_shardings, final_out_shardings
|
|
||||||
|
|
||||||
|
|
||||||
def pre_infer_params(fun, in_shardings, out_shardings,
|
def pre_infer_params(fun, in_shardings, out_shardings,
|
||||||
donate_argnums, donate_argnames,
|
donate_argnums, donate_argnames,
|
||||||
static_argnums, static_argnames, device,
|
static_argnums, static_argnames, device,
|
||||||
@ -588,8 +557,6 @@ def pjit(
|
|||||||
fun: Callable,
|
fun: Callable,
|
||||||
in_shardings=UNSPECIFIED,
|
in_shardings=UNSPECIFIED,
|
||||||
out_shardings=UNSPECIFIED,
|
out_shardings=UNSPECIFIED,
|
||||||
in_axis_resources=UNSPECIFIED,
|
|
||||||
out_axis_resources=UNSPECIFIED,
|
|
||||||
static_argnums: Union[int, Sequence[int], None] = None,
|
static_argnums: Union[int, Sequence[int], None] = None,
|
||||||
static_argnames: Union[str, Iterable[str], None] = None,
|
static_argnames: Union[str, Iterable[str], None] = None,
|
||||||
donate_argnums: Union[int, Sequence[int], None] = None,
|
donate_argnums: Union[int, Sequence[int], None] = None,
|
||||||
@ -698,8 +665,6 @@ def pjit(
|
|||||||
assignment for function outputs.
|
assignment for function outputs.
|
||||||
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
|
||||||
will use GSPMD's sharding propagation to determine how to shard the outputs.
|
will use GSPMD's sharding propagation to determine how to shard the outputs.
|
||||||
in_axis_resources: (Deprecated) Please use in_shardings.
|
|
||||||
out_axis_resources: (Deprecated) Please use out_shardings.
|
|
||||||
static_argnums: An optional int or collection of ints that specify which
|
static_argnums: An optional int or collection of ints that specify which
|
||||||
positional arguments to treat as static (compile-time constant).
|
positional arguments to treat as static (compile-time constant).
|
||||||
Operations that only depend on static arguments will be constant-folded in
|
Operations that only depend on static arguments will be constant-folded in
|
||||||
@ -779,9 +744,6 @@ def pjit(
|
|||||||
... print(f(x)) # doctest: +SKIP
|
... print(f(x)) # doctest: +SKIP
|
||||||
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
|
||||||
"""
|
"""
|
||||||
in_shardings, out_shardings = _resolve_axis_resources_and_shardings_arg(
|
|
||||||
in_shardings, out_shardings, in_axis_resources, out_axis_resources)
|
|
||||||
|
|
||||||
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
|
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
|
||||||
static_argnames) = pre_infer_params(
|
static_argnames) = pre_infer_params(
|
||||||
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
|
||||||
|
@ -3010,17 +3010,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
|||||||
pf1(jnp.arange(8.))
|
pf1(jnp.arange(8.))
|
||||||
self.assertEqual(count[0], 1)
|
self.assertEqual(count[0], 1)
|
||||||
|
|
||||||
def test_set_both_axis_resources_and_shardings(self):
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError,
|
|
||||||
"Setting both in_shardings and in_axis_resources is not allowed"):
|
|
||||||
pjit(lambda x: x, in_shardings=P('x'), in_axis_resources=P('x'))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError,
|
|
||||||
"Setting both out_shardings and out_axis_resources is not allowed"):
|
|
||||||
pjit(lambda x: x, out_shardings=P('x'), out_axis_resources=P('x'))
|
|
||||||
|
|
||||||
def test_with_sharding_constraint_spmd_axis_name(self):
|
def test_with_sharding_constraint_spmd_axis_name(self):
|
||||||
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
||||||
shape = (8, 4, 2, 2)
|
shape = (8, 4, 2, 2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user