Add device and backend API to pjit but resolve them away in infer_params. This is to merge the jit and pjit frontend APIs.

The semantics of specifying `device` or `backend` on `pjit` are the same as doing a `device_put`: no matter which device the argument is on, it is resharded to the specified device.
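To illustrate, a minimal usage sketch of that behavior, assuming `jax.Array` mode is enabled and a runtime with at least two devices (the lambda and the device indices are only illustrative):

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

# The argument starts out committed to device 0.
x = jax.device_put(jnp.arange(8), jax.devices()[0])

# `device=` on pjit now behaves like an implicit device_put on the inputs:
# x is resharded to device 1 before the computation runs there.
f = pjit(lambda a: a * 2, device=jax.devices()[1])
out = f(x)
assert out.device() == jax.devices()[1]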

PiperOrigin-RevId: 495437165
Yash Katariya 2022-12-14 15:41:19 -08:00 committed by jax authors
parent 64b6efc680
commit 7d4ef891af
5 changed files with 235 additions and 44 deletions
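For orientation before the per-file hunks: the new arguments are mutually exclusive with each other, with `in_axis_resources`/`out_axis_resources`, and with an ambient mesh. Below is a hedged sketch of the resulting behavior (error messages paraphrased, and the `expect_value_error` helper exists only for this sketch; the exact checks and their tests appear in the hunks that follow):

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

def expect_value_error(thunk):
    # Helper for this sketch only: run the thunk and report the expected ValueError.
    try:
        thunk()
    except ValueError as e:
        print('ValueError:', e)

# Rejected when the pjit-wrapped function is constructed:
expect_value_error(lambda: pjit(lambda x: x, device=jax.devices()[0], backend='cpu'))
expect_value_error(lambda: pjit(lambda x: x, in_axis_resources=None, backend='cpu'))

# Rejected at call time when a `jax.sharding.Mesh` context manager is also active:
with jax.sharding.Mesh(jax.devices()[:1], ('x',)):
    expect_value_error(lambda: pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8)))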


@@ -21,12 +21,13 @@ from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
import itertools as it
from functools import partial, lru_cache
import threading
import warnings
from jax.experimental import maps
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
from jax._src.sharding import (
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
XLADeviceAssignment)
XLADeviceAssignment, SingleDeviceSharding)
from jax import core
from jax import linear_util as lu
from jax import stages
@@ -182,6 +183,8 @@ def pjit(
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
keep_unused: bool = False,
device: Optional[xc.Device] = None,
backend: Optional[str] = None,
) -> stages.Wrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
@@ -287,6 +290,16 @@
unused by `fun` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
device: This argument is deprecated. Please put your arguments on the
device you want before passing them to jit.
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited
from XLA's DeviceAssignment logic and is usually to use
``jax.devices()[0]``.
backend: This argument is deprecated. Please put your arguments on the
backend you want before passing them to jit.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
Returns:
@@ -320,6 +333,23 @@
"set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
"boolean flag to something true-like.")
if backend is not None or device is not None:
warnings.warn(
'backend and device argument on jit is deprecated. You can use a '
'`jax.sharding.Mesh` context manager or device_put the arguments '
'before passing them to `jit`. Please see '
'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
'for more information.', DeprecationWarning)
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if not _is_unspecified(in_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'in_axis_resources should not be specified.')
if not _is_unspecified(out_axis_resources):
raise ValueError('If backend or device is specified on jit, then '
'out_axis_resources should not be specified.')
if isinstance(in_axis_resources, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
@@ -353,6 +383,11 @@
raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
"it's defined at the call site?")
if (backend or device) and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args,
allow_invalid=True)
@@ -377,10 +412,17 @@
donated_invars = (False,) * len(args_flat)
if config.jax_array:
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
# If backend or device is set as an arg on jit, then resolve them to
# in_shardings and out_shardings as if user passed in in_axis_resources
# and out_axis_resources.
if backend or device:
in_shardings = out_shardings = _create_sharding_with_device_backend(
device, backend)
else:
in_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
out_shardings = tree_map(
lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
else:
in_shardings = tree_map(
lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
@@ -508,6 +550,18 @@ def _create_sharding_for_array(mesh, x):
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
def _create_sharding_with_device_backend(device, backend):
if device is not None:
assert backend is None
out = SingleDeviceSharding(device)
elif backend is not None:
assert device is None
out = SingleDeviceSharding(
xb.get_backend(backend).get_default_device_assignment(1)[0])
out._device_backend = True
return out
def flatten_axis_resources(what, tree, shardings, tupled_args):
try:
return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
@@ -877,6 +931,14 @@ pjit_p.multiple_results = True
def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
# If True, means that device or backend is set by the user on pjit and it
# has the same semantics as device_put i.e. doesn't matter which device the
# arg is on, reshard it to the device mentioned. So don't do any of the
# checks and just return the pjit_in_shardings directly. `shard_args` will
# handle the resharding.
if pxla._check_device_backend_on_shardings(pjit_in_shardings):
return pjit_in_shardings
committed_arg_shardings = []
for a in args:
if hasattr(a, 'sharding'):
@@ -1405,7 +1467,8 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
if _is_unspecified(s) or _is_auto(s):
continue
elif hasattr(s, '_original_sharding'):
elif hasattr(s, '_original_sharding') and hasattr(
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(
@@ -1421,7 +1484,8 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
for aval, s in zip(jaxpr.out_avals, params['out_shardings']):
if _is_unspecified(s) or _is_auto(s):
continue
elif hasattr(s, '_original_sharding'):
elif hasattr(s, '_original_sharding') and hasattr(
s._original_sharding, '_parsed_pspec'):
parsed_pspec = s._original_sharding._parsed_pspec
else:
parsed_pspec = parse_flatten_op_sharding(


@@ -3234,7 +3234,8 @@ def _get_normalized_avals_and_shardings(
in_sharding = i
else:
assert isinstance(i, sharding_internal.NamedSharding)
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
aval = i.mesh._global_to_local(
cast(ArrayMapping, _get_array_mapping(i.spec)), gaval) # pylint: disable=g-bare-generic
in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
avals.append(aval)
shardings.append(in_sharding)
@@ -3467,9 +3468,8 @@ class UnloadedMeshExecutable:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
input_avals, input_shardings = (
_get_normalized_avals_and_shardings(global_in_avals,
in_shardings, # type: ignore # arg-type
in_is_global))
_get_normalized_avals_and_shardings(
global_in_avals, in_shardings, in_is_global)) # type: ignore # arg-type
return UnloadedMeshExecutable(
xla_executable=xla_executable,
@@ -3718,6 +3718,16 @@ def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)
def _check_device_backend_on_shardings(shardings) -> bool:
for i in shardings:
if _is_unspecified(i) or _is_auto(i):
continue
if hasattr(i, '_original_sharding') and getattr(
i._original_sharding, '_device_backend', False):
return True
return False
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray
from jax._src.array import ArrayImpl
@@ -3736,9 +3746,10 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim)):
if (not _check_device_backend_on_shardings([xs]) and
committed and
not are_op_shardings_equal(arg_sharding._to_xla_op_sharding(arg.ndim),
xs._to_xla_op_sharding(arg.ndim))):
raise ValueError(
f"{arg_type} sharding does not match the input sharding. "
f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "


@@ -24,5 +24,6 @@ filterwarnings =
default:Error reading persistent compilation cache entry for 'jit__lambda_'
default:Error writing persistent compilation cache entry for 'jit__lambda_'
ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"


@@ -248,8 +248,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
_check_instance(self, x)
self.assertEqual(x.device(), device)
@parameterized.named_parameters(
('jit', jax.jit),
('pjit', pjit.pjit),
)
@jtu.skip_on_devices("cpu")
def test_jit_default_device(self):
def test_jit_default_device(self, module):
if jax.device_count() == 1:
raise unittest.SkipTest("Test requires multiple devices")
@@ -257,7 +261,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
test_device = jax.devices()[-1]
self.assertNotEqual(system_default_device, test_device)
f = jax.jit(lambda x: x + 1)
f = module(lambda x: x + 1)
self.assertEqual(f(1).device(), system_default_device)
with jax.default_device(test_device):
@@ -270,14 +274,10 @@ class CPPJitTest(jtu.BufferDonationTestCase):
with jax.default_device(test_device):
# Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual(
jax.jit(f, device=system_default_device)(1).device(),
module(f, device=system_default_device)(1).device(),
system_default_device)
out = jax.jit(f, backend="cpu")(1)
if config.jax_array:
self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
self.assertEqual(out._arrays[0].platform(), "cpu")
else:
self.assertEqual(out.platform(), "cpu")
out = module(f, backend="cpu")(1)
self.assertEqual(out.device().platform, "cpu")
# Sticky input device overrides default_device
sticky = jax.device_put(1, system_default_device)
@@ -715,29 +715,31 @@ class CPPJitTest(jtu.BufferDonationTestCase):
else:
self.assertIsInstance(jitted_f(2), device_array.Buffer)
@parameterized.named_parameters(
('jit', jax.jit),
('pjit', pjit.pjit)
)
@jtu.skip_on_devices("cpu")
def test_explicit_backend(self):
def test_explicit_backend(self, module):
f = lambda x: x + 1
jitted_f = jit(f, backend=jtu.device_under_test())
jitted_f_cpu = jit(f, backend="cpu")
jitted_f = module(f, backend=jtu.device_under_test())
jitted_f_cpu = module(f, backend="cpu")
result = jitted_f(1.)
result_cpu = jitted_f_cpu(1.)
if config.jax_array:
buf = result._arrays[0]
buf_cpu = result_cpu._arrays[0]
else:
buf = result.device_buffer
buf_cpu = result_cpu.device_buffer
self.assertEqual(buf.platform(), jtu.device_under_test())
self.assertEqual(buf_cpu.platform(), "cpu")
self.assertEqual(result.device().platform, jtu.device_under_test())
self.assertEqual(result_cpu.device().platform, "cpu")
@parameterized.named_parameters(
('jit', jax.jit),
('pjit', pjit.pjit)
)
@jtu.skip_on_devices("cpu")
def test_device_to_device_copy_between_backends(self):
def test_device_to_device_copy_between_backends(self, module):
# b/186624243
f = lambda x: x + 1
jitted_f = jit(f, backend=jtu.device_under_test())
jitted_f_cpu = jit(f, backend="cpu")
jitted_f = module(f, backend=jtu.device_under_test())
jitted_f_cpu = module(f, backend="cpu")
x = np.arange(30).reshape(1, 10, 3)
result = jitted_f(x)
@@ -747,16 +749,23 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(result_2, x + 3)
self.assertAllClose(result_cpu_2, x + 4)
@parameterized.named_parameters(
('jit', jax.jit),
('pjit', pjit.pjit)
)
@jtu.skip_on_devices("cpu")
def test_mismatched_nested_backends(self):
@partial(jit, backend=jtu.device_under_test())
def test_mismatched_nested_backends(self, module):
@partial(module, backend=jtu.device_under_test())
def f(x):
return jit(lambda x: x + 1, backend="cpu")(x)
return module(lambda x: x + 1, backend="cpu")(x)
with self.assertRaisesRegex(
ValueError,
"Outer-jit backend specification .* must match explicit inner-jit "
"backend specification cpu."):
if module is pjit.pjit:
msg = 'Devices of all `Array` inputs and outputs should be the same'
else:
msg = ("Outer-jit backend specification .* must match explicit inner-jit "
"backend specification cpu.")
with self.assertRaisesRegex(ValueError, msg):
f(1.)
def test_omnistaging(self):


@@ -2855,6 +2855,112 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(compiled._executable._kept_var_idx, {5})
self.assertLen(compiled._executable.in_avals, 1)
def test_pjit_with_device_arg(self):
def mul(x):
return x @ x.T
def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device)
self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out @ expected_out.T)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
f = pjit(mul, device=jax.devices()[1])
x = jnp.arange(8).reshape(4, 2)
f_out = f(x)
f_out2 = f(f_out)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
_check(f_out, jax.devices()[1], x)
_check(f_out2, jax.devices()[1], f_out)
y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
out2 = f(y)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
_check(out2, jax.devices()[1], y)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
h = pjit(mul, device=jax.devices()[4])
h_out = h(y)
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
_check(h_out, jax.devices()[4], y)
self.assertEqual(cache_info3.hits, cache_info2.hits)
self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
# AOT test
compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
out3 = compiled(y)
_check(out3, jax.devices()[1], y)
def test_pjit_with_device_arg_input_from_another_pjit(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
inp = np.arange(8).reshape(4, 2)
y = jax.device_put(inp, jax.sharding.NamedSharding(mesh, P('x', 'y')))
out = pjit(lambda x: x * 2)(y)
expected_device = jax.devices()[2]
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
self.assertEqual(final_out.device(), expected_device)
self.assertLen(final_out.sharding.device_set, 1)
self.assertArraysEqual(final_out, inp * 6)
@jtu.skip_on_devices("gpu", "cpu")
def test_pjit_with_backend_arg(self):
def _check(out, expected_device, expected_out):
self.assertEqual(out.device(), expected_device)
self.assertLen(out.sharding.device_set, 1)
self.assertArraysEqual(out, expected_out)
x = jnp.arange(8)
g = pjit(lambda x: x, backend='tpu')
g_out = g(x)
_check(g_out, jax.devices()[0], x)
compiled = g.lower(jax.ShapedArray(x.shape, x.dtype)).compile()
out4 = compiled(x)
_check(out4, jax.devices()[0], x)
def test_autodiff_with_device_arg(self):
if jax.device_count() <= 1:
self.skipTest('Test requires more than 1 device.')
# Add a constant captured by the nested pjit to make things more complicated
h = jnp.arange(4.)
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
def test_pjit_device_backend_axis_resources_error(self):
with self.assertRaisesRegex(
ValueError,
'If backend or device is specified on jit, then '
'in_axis_resources should not be specified.'):
pjit(lambda x: x, in_axis_resources=None, backend='cpu')
with self.assertRaisesRegex(
ValueError,
'If backend or device is specified on jit, then '
'out_axis_resources should not be specified.'):
pjit(lambda x: x, out_axis_resources=None, device=jax.devices()[0])
def test_pjit_device_backend_both_error(self):
with self.assertRaisesRegex(
ValueError, "can't specify both a device and a backend for jit"):
pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
def test_pjit_mesh_with_device_or_backend_error(self):
mesh = jtu.create_global_mesh((1,), ('x',))
with mesh:
with self.assertRaisesRegex(
ValueError,
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit."):
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
class TempSharding(Sharding):