Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.

The semantics are as follows:

* If the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings.

* If the mesh context manager is provided, None will be treated as fully replicated, per the old semantics.

This ensures we don't break existing code that depends on None meaning replicated, while also starting the transition to None meaning UNSPECIFIED for jit and pjit.
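
A minimal sketch of the two behaviors, assuming a host with at least one JAX device (the mesh construction below is illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
from jax.sharding import Mesh

x = jnp.arange(8.)

# No mesh context manager: None acts like UNSPECIFIED, so JAX is free to
# choose the shardings (outputs are left to the XLA GSPMD partitioner).
y = jax.jit(lambda a: a * 2, in_shardings=None, out_shardings=None)(x)

# With a mesh context manager: None keeps the old meaning of "fully
# replicated across the mesh".
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('i',))
with mesh:
  z = pjit(lambda a: a * 2, in_shardings=None, out_shardings=None)(x)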

PiperOrigin-RevId: 540705660
Yash Katariya, 2023-06-15 15:21:36 -07:00, committed by jax authors
parent 904b46a2d7
commit 6007698f4e
6 changed files with 106 additions and 36 deletions


@@ -7,6 +7,25 @@ Remember to align the itemized text with the first line of an item within a list
 -->
 ## jax 0.4.13
+* Changes
+  * `jax.jit` now allows `None` to be passed to `in_shardings` and
+    `out_shardings`. The semantics are as follows:
+    * For in_shardings, JAX will mark it as replicated, but this behavior
+      can change in the future.
+    * For out_shardings, we will rely on the XLA GSPMD partitioner to
+      determine the output shardings.
+  * `jax.experimental.pjit.pjit` also allows `None` to be passed to
+    `in_shardings` and `out_shardings`. The semantics are as follows:
+    * If the mesh context manager is *not* provided, JAX has the freedom to
+      choose whatever sharding it wants.
+      * For in_shardings, JAX will mark it as replicated, but this behavior
+        can change in the future.
+      * For out_shardings, we will rely on the XLA GSPMD partitioner to
+        determine the output shardings.
+    * If the mesh context manager is provided, `None` will imply that the
+      value will be replicated on all devices of the mesh.
 * Bug fixes
   * Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
     is named `cudnn89` instead of `cudnn88`.
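
A short sketch of the `jax.jit` entry above, mirroring the new test_jit_out_shardings_none test later in this commit (the 4x2 mesh is illustrative and assumes 8 devices):

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp = jax.device_put(np.arange(16).reshape(8, 2), s)

# out_shardings=None: the XLA GSPMD partitioner chooses the output
# sharding; here it propagates the input's NamedSharding.
out = jax.jit(lambda x: x * 2, out_shardings=None)(inp)
assert out.sharding == s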


@@ -187,6 +187,12 @@ def jit(
       - :py:class:`XLACompatibleSharding`, which will decide how the value
         will be partitioned. With this, using a mesh context manager is not
         required.
+      - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
+        it wants.
+        For in_shardings, JAX will mark it as replicated but this behavior
+        can change in the future.
+        For out_shardings, we will rely on the XLA GSPMD partitioner to
+        determine the output shardings.
 
       The size of every dimension has to be a multiple of the total number of
       resources assigned to it. This is similar to pjit's in_shardings.
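
The in_shardings=None case from the docstring above, sketched along the lines of the new test_jit_in_shardings_none test in this commit (the mesh shape is illustrative and assumes 8 devices):

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import (Mesh, NamedSharding, PartitionSpec as P,
                          SingleDeviceSharding)

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)

# A committed, already-sharded input keeps its sharding, and the output
# sharding follows it.
out = jax.jit(lambda x: x * 2, in_shardings=None)(jax.device_put(np_inp, s))
assert out.sharding == s

# An uncommitted NumPy input is placed on the default device.
out2 = jax.jit(lambda x: x * 2, in_shardings=None)(np_inp)
assert out2.sharding == SingleDeviceSharding(jax.devices()[0])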


@@ -266,27 +266,29 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
 def _resolve_axis_resources_and_shardings_arg(
     in_shardings, out_shardings, in_axis_resources, out_axis_resources):
-  if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources):
+  if (in_shardings is not None and in_axis_resources is not None and
+      not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
     raise ValueError(
         'Setting both in_shardings and in_axis_resources is not '
         'allowed. in_axis_resources is deprecated. Please use in_shardings.')
-  if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources):
+  if (out_shardings is not None and out_axis_resources is not None and
+      not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
     raise ValueError(
         'Setting both out_shardings and out_axis_resources is not '
         'allowed. out_axis_resources is deprecated. Please use out_shardings.')
-  if (not is_unspecified(in_axis_resources) or
-      not is_unspecified(out_axis_resources)):
+  if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
+      (out_axis_resources is not None and not is_unspecified(out_axis_resources))):
     warnings.warn(
         'in_axis_resources and out_axis_resources are deprecated. Please use '
         'in_shardings and out_shardings as their replacement.',
         DeprecationWarning)
-  if not is_unspecified(in_axis_resources):
+  if in_axis_resources is not None and not is_unspecified(in_axis_resources):
     final_in_shardings = in_axis_resources
   else:
     final_in_shardings = in_shardings
-  if not is_unspecified(out_axis_resources):
+  if out_axis_resources is not None and not is_unspecified(out_axis_resources):
     final_out_shardings = out_axis_resources
   else:
     final_out_shardings = out_shardings
@@ -311,10 +313,10 @@ def pre_infer_params(fun, in_shardings, out_shardings,
   if device is not None and backend is not None:
     raise ValueError("can't specify both a device and a backend for jit, "
                      f"got {device=} and {backend=}")
-  if not is_unspecified(in_shardings):
+  if in_shardings is not None and not is_unspecified(in_shardings):
     raise ValueError('If backend or device is specified on jit, then '
                      'in_shardings should not be specified.')
-  if not is_unspecified(out_shardings):
+  if out_shardings is not None and not is_unspecified(out_shardings):
     raise ValueError('If backend or device is specified on jit, then '
                      'out_shardings should not be specified.')
@@ -413,7 +415,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
      donate_argnums, device, backend, keep_unused, inline,
      resource_env, abstracted_axes) = pjit_info_args
 
-  if kwargs and not is_unspecified(user_in_shardings):
+  if (kwargs and user_in_shardings is not None and
+      not is_unspecified(user_in_shardings)):
     raise ValueError(
         "pjit does not support kwargs when in_shardings is specified.")
@@ -467,14 +470,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   in_shardings = tree_map(
       lambda x: _create_sharding_for_array(pjit_mesh, x, 'in_shardings',
                                            jit_name),
-      user_in_shardings)
+      user_in_shardings, is_leaf=lambda x: x is None)
   out_shardings = tree_map(
       lambda x: _create_sharding_for_array(pjit_mesh, x, 'out_shardings',
                                            jit_name),
-      user_out_shardings)
+      user_out_shardings, is_leaf=lambda x: x is None)
   del user_in_shardings, user_out_shardings
 
+  assert in_shardings is not None or all(i is not None for i in in_shardings)
+  assert out_shardings is not None or all(o is not None for o in out_shardings)
+
   if config.jax_dynamic_shapes:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
@@ -661,11 +667,18 @@ def pjit(
       - :py:class:`XLACompatibleSharding`, which will decide how the value
         will be partitioned. With this, using a mesh context manager is not
        required.
+      - :py:obj:`None` is a special case whose semantics are:
+        - if the mesh context manager is *not* provided, JAX has the freedom to
+          choose whatever sharding it wants.
+          For in_shardings, JAX will mark it as replicated but this behavior
+          can change in the future.
+          For out_shardings, we will rely on the XLA GSPMD partitioner to
+          determine the output shardings.
+        - If the mesh context manager is provided, None will imply that the
+          value will be replicated on all devices of the mesh.
       - For backwards compatibility, in_shardings still supports ingesting
-        :py:class:`PartitionSpec` and :py:obj:`None`. These 2 options can
-        *only* be used with the mesh context manager.
-      - :py:obj:`None`, in which case the value will be replicated on all devices
+        :py:class:`PartitionSpec`. This option can *only* be used with the
+        mesh context manager.
       - :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
         of the partitioned value. Each element can be a :py:obj:`None`, a mesh
         axis or a tuple of mesh axes, and specifies the set of resources assigned
@@ -774,14 +787,9 @@ def hashable_pytree(pytree):
                       closure=(treedef, vals))
 
-@lru_cache(maxsize=4096)
-def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
-  if is_unspecified_or_auto(x):
-    return x
-  return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)
-
 def _create_sharding_for_array(mesh, x, name, api_name):
+  if x is None and (mesh is None or mesh.empty):
+    return UNSPECIFIED
   if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
     return x
   if mesh is None:
@@ -804,8 +812,9 @@ def _create_sharding_for_array(mesh, x, name, api_name):
         f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
         ' then the mesh context manager is not required.')
   # A nice user error is raised in prepare_axis_resources.
-  assert isinstance(x, ParsedPartitionSpec), x
-  return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
+  assert x is None or isinstance(x, ParsedPartitionSpec), x
+  return (pxla.create_mesh_pspec_sharding(mesh, x)
+          if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))
 
 
 def _create_sharding_with_device_backend(device, backend):
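
Condensed, the None handling that _create_sharding_for_array now performs looks roughly like the sketch below. This is a standalone approximation, not JAX's actual code: UNSPECIFIED stands in for pjit's internal sentinel, and the replicated case is modeled with an empty PartitionSpec instead of the internal create_mesh_pspec_sharding call.

from jax.sharding import NamedSharding, PartitionSpec

UNSPECIFIED = object()  # stand-in for pjit's internal sentinel

def resolve_none_sharding(mesh, x):
  # No mesh context manager (or an empty mesh): None degrades to
  # UNSPECIFIED and JAX/GSPMD is free to choose the sharding.
  if x is None and (mesh is None or mesh.empty):
    return UNSPECIFIED
  # Under a mesh context manager: None keeps the old meaning, fully
  # replicated, i.e. a NamedSharding with an empty PartitionSpec.
  if x is None:
    return NamedSharding(mesh, PartitionSpec())
  return x  # concrete shardings pass through unchanged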


@@ -217,7 +217,8 @@ class NamedSharding(XLACompatibleSharding):
     # representation of Parsed Pspec
     if self._parsed_pspec is None:
       self._parsed_pspec, _, _ = prepare_axis_resources(
-          self.spec, "NamedSharding spec", allow_unconstrained_dims=True)
+          PartitionSpec() if self.spec is None else self.spec,
+          "NamedSharding spec", allow_unconstrained_dims=True)
     _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
@@ -956,7 +957,7 @@ def prepare_axis_resources(axis_resources,
   new_entries = []
   for entry in entries:
-    if is_unspecified_or_auto(entry):
+    if is_unspecified_or_auto(entry) or entry is None:
       new_entries.append(entry)
     elif isinstance(entry, sharding.Sharding):
       if isinstance(entry, PmapSharding):
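
Per the NamedSharding hunk above, a missing spec now parses as an empty PartitionSpec, which means fully replicated. A quick illustration (the mesh shape assumes at least one device):

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('x',))
# An empty PartitionSpec partitions nothing, so every device holds a
# full copy of the value.
s = NamedSharding(mesh, PartitionSpec())
assert s.is_fully_replicated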


@@ -958,8 +958,7 @@ class ShardingTest(jtu.JaxTestCase):
     ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
     self.assertEqual(mps.is_fully_replicated, ops_ifr)
 
-    ps = _op_sharding_to_pos_sharding(mps_op_sharding,
-                                      mps._device_assignment)
+    ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment)
     self.assertEqual(ps.is_fully_replicated,
                      op_shardings.is_op_sharding_replicated(
                          ps._to_xla_hlo_sharding(len(shape))))


@@ -2573,7 +2573,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "pjit does not support kwargs when in_shardings is specified."):
-      pjit(lambda x: x, in_shardings=None)(x=jnp.arange(8.))
+      pjit(lambda x: x,
+           in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))
 
   def test_pjit_keep_unused_true(self):
     @partial(pjit, keep_unused=True)
@@ -2693,17 +2694,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
     jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
 
   def test_pjit_device_backend_axis_resources_error(self):
+    s = SingleDeviceSharding(jax.devices()[0])
     with self.assertRaisesRegex(
         ValueError,
         'If backend or device is specified on jit, then '
         'in_shardings should not be specified.'):
-      pjit(lambda x: x, in_shardings=None, backend='cpu')
+      pjit(lambda x: x, in_shardings=s, backend='cpu')
     with self.assertRaisesRegex(
         ValueError,
         'If backend or device is specified on jit, then '
         'out_shardings should not be specified.'):
-      pjit(lambda x: x, out_shardings=None, device=jax.devices()[0])
+      pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
 
   def test_pjit_device_backend_both_error(self):
     with self.assertRaisesRegex(
@@ -3468,6 +3470,43 @@ class ArrayPjitTest(jtu.JaxTestCase):
                                 r"Argument.*is not a valid JAX type"):
       jax.jit(lambda x: (x, const))(jnp.arange(8))
 
+  def test_jit_out_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+    out = jax.jit(lambda x: x * 2, out_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+  def test_jit_in_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+
+    out = jax.jit(lambda x: x * 2, in_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+    out2 = jax.jit(lambda x: x * 2, in_shardings=None)(np_inp)
+    self.assertArraysEqual(out2, np_inp * 2)
+    self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
+
+  def test_jit_both_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+
+    out = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+    out2 = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(np_inp)
+    self.assertArraysEqual(out2, np_inp * 2)
+    self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
+
+
 class TempSharding(Sharding):
@@ -3683,11 +3722,8 @@ class PJitErrorTest(jtu.JaxTestCase):
       f(x, x)
 
   def testEmptyMesh(self):
-    error = (
-        r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
-        r' `None` to in_shardings.*')
-    with self.assertRaisesRegex(RuntimeError, error):
-      pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
+    out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
+    self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))
 
   def test_pspec_to_wsc_without_mesh(self):
     error = (