Mirror of https://github.com/ROCm/jax.git
Allow None to be passed to in_shardings and out_shardings. The default is still UNSPECIFIED to handle edge cases around the old semantics where None is treated as fully replicated.

The semantics are as follows:

* If the mesh context manager is not provided, None will be treated as UNSPECIFIED for both in_shardings and out_shardings.
* If the mesh context manager is provided, None will be treated as fully replicated, as per the old semantics.

This ensures we don't break existing code that depends on None meaning replicated, while also starting the transition to None meaning UNSPECIFIED for jit and pjit.

PiperOrigin-RevId: 540705660
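For illustration, a minimal sketch of the jit side of these semantics (assumes jax 0.4.13; the exact sharding chosen is up to JAX):

import jax
import jax.numpy as jnp

# Without a mesh context manager, None behaves like UNSPECIFIED:
# JAX (and the XLA GSPMD partitioner) are free to choose the shardings.
out = jax.jit(lambda a: a * 2, in_shardings=None, out_shardings=None)(jnp.arange(8.))
print(out.sharding)  # e.g. SingleDeviceSharding on a one-device runtime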
This commit is contained in:
parent 904b46a2d7
commit 6007698f4e

CHANGELOG.md (19 changed lines)
@@ -7,6 +7,25 @@ Remember to align the itemized text with the first line of an item within a list
 -->

+## jax 0.4.13
+
+* Changes
+  * `jax.jit` now allows `None` to be passed to `in_shardings` and
+    `out_shardings`. The semantics are as follows:
+    * For in_shardings, JAX will mark it as replicated, but this behavior
+      can change in the future.
+    * For out_shardings, we will rely on the XLA GSPMD partitioner to
+      determine the output shardings.
+  * `jax.experimental.pjit.pjit` also allows `None` to be passed to
+    `in_shardings` and `out_shardings`. The semantics are as follows:
+    * If the mesh context manager is *not* provided, JAX has the freedom to
+      choose whatever sharding it wants.
+      * For in_shardings, JAX will mark it as replicated, but this behavior
+        can change in the future.
+      * For out_shardings, we will rely on the XLA GSPMD partitioner to
+        determine the output shardings.
+    * If the mesh context manager is provided, None will imply that the value
+      will be replicated on all devices of the mesh.
+
 * Bug fixes
   * Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
     is named `cudnn89` instead of `cudnn88`.
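To illustrate the pjit entry above, a minimal sketch (assumes a runtime with at least two devices; device counts are placeholders):

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh

# Under a mesh context manager, None keeps the old meaning: fully replicated.
with Mesh(jax.devices(), ('x',)):
    out = pjit(lambda a: a * 2, in_shardings=None, out_shardings=None)(jnp.arange(8.))
print(out.sharding.is_fully_replicated)  # True: replicated on all mesh devices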
@@ -187,6 +187,12 @@ def jit(
       - :py:class:`XLACompatibleSharding`, which will decide how the value
         will be partitioned. With this, using a mesh context manager is not
         required.
+      - :py:obj:`None`, which will give JAX the freedom to choose whatever
+        sharding it wants.
+        For in_shardings, JAX will mark it as replicated, but this behavior
+        can change in the future.
+        For out_shardings, we will rely on the XLA GSPMD partitioner to
+        determine the output shardings.

       The size of every dimension has to be a multiple of the total number of
       resources assigned to it. This is similar to pjit's in_shardings.
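To make the first documented option concrete, a small sketch (the (4, 2) device grid is an assumption; adapt to your runtime):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# An explicit XLACompatibleSharding passed straight to jit; no mesh context
# manager is needed because the sharding carries its own mesh.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
out = jax.jit(lambda a: a * 2, in_shardings=s, out_shardings=s)(
    jnp.arange(16.).reshape(8, 2))
assert out.sharding == s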
@@ -266,27 +266,29 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,

 def _resolve_axis_resources_and_shardings_arg(
     in_shardings, out_shardings, in_axis_resources, out_axis_resources):
-  if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources):
+  if (in_shardings is not None and in_axis_resources is not None and
+      not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources)):
     raise ValueError(
         'Setting both in_shardings and in_axis_resources is not '
         'allowed. in_axis_resources is deprecated. Please use in_shardings.')
-  if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources):
+  if (out_shardings is not None and out_axis_resources is not None and
+      not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources)):
     raise ValueError(
         'Setting both out_shardings and out_axis_resources is not '
         'allowed. out_axis_resources is deprecated. Please use out_shardings.')
-  if (not is_unspecified(in_axis_resources) or
-      not is_unspecified(out_axis_resources)):
+  if ((in_axis_resources is not None and not is_unspecified(in_axis_resources)) or
+      (out_axis_resources is not None and not is_unspecified(out_axis_resources))):
     warnings.warn(
         'in_axis_resources and out_axis_resources are deprecated. Please use '
         'in_shardings and out_shardings as their replacement.',
         DeprecationWarning)

-  if not is_unspecified(in_axis_resources):
+  if in_axis_resources is not None and not is_unspecified(in_axis_resources):
     final_in_shardings = in_axis_resources
   else:
     final_in_shardings = in_shardings

-  if not is_unspecified(out_axis_resources):
+  if out_axis_resources is not None and not is_unspecified(out_axis_resources):
     final_out_shardings = out_axis_resources
   else:
     final_out_shardings = out_shardings
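An illustration of the resolution logic above (a sketch; the exact warning and error text lives in the code):

import warnings
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P

# The deprecated spelling still works but warns at decoration time...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    f = pjit(lambda x: x, in_axis_resources=P('x'))
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# ...while mixing the old and new spellings is rejected outright.
try:
    pjit(lambda x: x, in_shardings=P('x'), in_axis_resources=P('x'))
except ValueError:
    pass  # 'Setting both in_shardings and in_axis_resources is not allowed...'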
@@ -311,10 +313,10 @@ def pre_infer_params(fun, in_shardings, out_shardings,
   if device is not None and backend is not None:
     raise ValueError("can't specify both a device and a backend for jit, "
                      f"got {device=} and {backend=}")
-  if not is_unspecified(in_shardings):
+  if in_shardings is not None and not is_unspecified(in_shardings):
     raise ValueError('If backend or device is specified on jit, then '
                      'in_shardings should not be specified.')
-  if not is_unspecified(out_shardings):
+  if out_shardings is not None and not is_unspecified(out_shardings):
     raise ValueError('If backend or device is specified on jit, then '
                      'out_shardings should not be specified.')
@@ -413,7 +415,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
       donate_argnums, device, backend, keep_unused, inline,
       resource_env, abstracted_axes) = pjit_info_args

-  if kwargs and not is_unspecified(user_in_shardings):
+  if (kwargs and user_in_shardings is not None and
+      not is_unspecified(user_in_shardings)):
     raise ValueError(
         "pjit does not support kwargs when in_shardings is specified.")
@@ -467,14 +470,17 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   in_shardings = tree_map(
       lambda x: _create_sharding_for_array(pjit_mesh, x, 'in_shardings',
                                            jit_name),
-      user_in_shardings)
+      user_in_shardings, is_leaf=lambda x: x is None)
   out_shardings = tree_map(
       lambda x: _create_sharding_for_array(pjit_mesh, x, 'out_shardings',
                                            jit_name),
-      user_out_shardings)
+      user_out_shardings, is_leaf=lambda x: x is None)

   del user_in_shardings, user_out_shardings

+  assert in_shardings is not None or all(i is not None for i in in_shardings)
+  assert out_shardings is not None or all(o is not None for o in out_shardings)
+
   if config.jax_dynamic_shapes:
     in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
     in_avals = tuple(a for a, e in in_type if e)
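The new `is_leaf=lambda x: x is None` arguments matter because `tree_map` treats `None` as an empty subtree by default and would never pass it to `_create_sharding_for_array`; a quick illustration:

import jax

# Default: None is an empty pytree node, so the function is never applied.
print(jax.tree_util.tree_map(lambda x: 'leaf', {'a': None}))  # {'a': None}

# With is_leaf, None is kept as a leaf and can be resolved to a sharding.
print(jax.tree_util.tree_map(lambda x: 'leaf', {'a': None},
                             is_leaf=lambda x: x is None))    # {'a': 'leaf'}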
@@ -661,11 +667,18 @@ def pjit(
       - :py:class:`XLACompatibleSharding`, which will decide how the value
         will be partitioned. With this, using a mesh context manager is not
         required.
+      - :py:obj:`None` is a special case whose semantics are:
+        - If the mesh context manager is *not* provided, JAX has the freedom to
+          choose whatever sharding it wants.
+          For in_shardings, JAX will mark it as replicated, but this behavior
+          can change in the future.
+          For out_shardings, we will rely on the XLA GSPMD partitioner to
+          determine the output shardings.
+        - If the mesh context manager is provided, None will imply that the
+          value will be replicated on all devices of the mesh.
+      - For backwards compatibility, in_shardings still supports ingesting
+        :py:class:`PartitionSpec` and :py:obj:`None`. These 2 options can
+        *only* be used with the mesh context manager.
-      - :py:obj:`None`, in which case the value will be replicated on all devices
-        :py:class:`PartitionSpec`. This option can *only* be used with the
-        mesh context manager.
       - :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
         of the partitioned value. Each element can be a :py:obj:`None`, a mesh
         axis or a tuple of mesh axes, and specifies the set of resources assigned
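A sketch of the backwards-compatible PartitionSpec path described in that docstring (assumes an 8-device runtime; axis names are placeholders):

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# PartitionSpec (and None) are only accepted under the mesh context manager.
with Mesh(jax.devices(), ('x',)):
    out = pjit(lambda a: a * 2,
               in_shardings=P('x'),    # shard dimension 0 along mesh axis 'x'
               out_shardings=P('x'))(jnp.arange(8.))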
@@ -774,14 +787,9 @@ def hashable_pytree(pytree):
       closure=(treedef, vals))


-@lru_cache(maxsize=4096)
-def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
-  if is_unspecified_or_auto(x):
-    return x
-  return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)
-
-
 def _create_sharding_for_array(mesh, x, name, api_name):
+  if x is None and (mesh is None or mesh.empty):
+    return UNSPECIFIED
   if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
     return x
   if mesh is None:

@@ -804,8 +812,9 @@ def _create_sharding_for_array(mesh, x, name, api_name):
       f' site? Alternatively, provide `XLACompatibleSharding`s to {name} and'
       ' then the mesh context manager is not required.')
   # A nice user error is raised in prepare_axis_resources.
-  assert isinstance(x, ParsedPartitionSpec), x
-  return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
+  assert x is None or isinstance(x, ParsedPartitionSpec), x
+  return (pxla.create_mesh_pspec_sharding(mesh, x)
+          if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x))


 def _create_sharding_with_device_backend(device, backend):
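Taken together, the two hunks above implement a small dispatch rule for `None`; a condensed, illustrative model (not JAX's actual code):

def resolve_none_sharding(x, mesh):
    # jit-style: no usable mesh, so None falls back to UNSPECIFIED and
    # JAX/XLA choose the sharding.
    if x is None and (mesh is None or mesh.empty):
        return 'UNSPECIFIED'
    # pjit-with-mesh: None keeps its old meaning, replicate over the mesh.
    if x is None:
        return 'fully replicated on mesh'
    return x  # explicit shardings and AUTO/UNSPECIFIED pass through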
@@ -217,7 +217,8 @@ class NamedSharding(XLACompatibleSharding):
     # representation of Parsed Pspec
     if self._parsed_pspec is None:
       self._parsed_pspec, _, _ = prepare_axis_resources(
-          self.spec, "NamedSharding spec", allow_unconstrained_dims=True)
+          PartitionSpec() if self.spec is None else self.spec,
+          "NamedSharding spec", allow_unconstrained_dims=True)

     _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
@@ -956,7 +957,7 @@ def prepare_axis_resources(axis_resources,

   new_entries = []
   for entry in entries:
-    if is_unspecified_or_auto(entry):
+    if is_unspecified_or_auto(entry) or entry is None:
       new_entries.append(entry)
     elif isinstance(entry, sharding.Sharding):
       if isinstance(entry, PmapSharding):
@@ -958,8 +958,7 @@ class ShardingTest(jtu.JaxTestCase):
     ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
     self.assertEqual(mps.is_fully_replicated, ops_ifr)

-    ps = _op_sharding_to_pos_sharding(mps_op_sharding,
-                                      mps._device_assignment)
+    ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment)
     self.assertEqual(ps.is_fully_replicated,
                      op_shardings.is_op_sharding_replicated(
                          ps._to_xla_hlo_sharding(len(shape))))
@@ -2573,7 +2573,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "pjit does not support kwargs when in_shardings is specified."):
-      pjit(lambda x: x, in_shardings=None)(x=jnp.arange(8.))
+      pjit(lambda x: x,
+           in_shardings=SingleDeviceSharding(jax.devices()[0]))(x=jnp.arange(8.))

   def test_pjit_keep_unused_true(self):
     @partial(pjit, keep_unused=True)
@@ -2693,17 +2694,18 @@ class ArrayPjitTest(jtu.JaxTestCase):
     jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)

   def test_pjit_device_backend_axis_resources_error(self):
+    s = SingleDeviceSharding(jax.devices()[0])
     with self.assertRaisesRegex(
         ValueError,
         'If backend or device is specified on jit, then '
         'in_shardings should not be specified.'):
-      pjit(lambda x: x, in_shardings=None, backend='cpu')
+      pjit(lambda x: x, in_shardings=s, backend='cpu')

     with self.assertRaisesRegex(
         ValueError,
         'If backend or device is specified on jit, then '
         'out_shardings should not be specified.'):
-      pjit(lambda x: x, out_shardings=None, device=jax.devices()[0])
+      pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])

   def test_pjit_device_backend_both_error(self):
     with self.assertRaisesRegex(
@@ -3468,6 +3470,43 @@ class ArrayPjitTest(jtu.JaxTestCase):
         r"Argument.*is not a valid JAX type"):
       jax.jit(lambda x: (x, const))(jnp.arange(8))

+  def test_jit_out_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+    out = jax.jit(lambda x: x * 2, out_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+  def test_jit_in_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+
+    out = jax.jit(lambda x: x * 2, in_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+    out2 = jax.jit(lambda x: x * 2, in_shardings=None)(np_inp)
+    self.assertArraysEqual(out2, np_inp * 2)
+    self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
+
+  def test_jit_both_shardings_none(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = jax.device_put(np_inp, s)
+
+    out = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(inp)
+    self.assertArraysEqual(out, np_inp * 2)
+    self.assertEqual(out.sharding, s)
+
+    out2 = jax.jit(lambda x: x * 2, in_shardings=None, out_shardings=None)(np_inp)
+    self.assertArraysEqual(out2, np_inp * 2)
+    self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0]))
+

 class TempSharding(Sharding):
@@ -3683,11 +3722,8 @@ class PJitErrorTest(jtu.JaxTestCase):
     f(x, x)

   def testEmptyMesh(self):
-    error = (
-        r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
-        r' `None` to in_shardings.*')
-    with self.assertRaisesRegex(RuntimeError, error):
-      pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
+    out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
+    self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))

   def test_pspec_to_wsc_without_mesh(self):
     error = (