Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)
Add device and backend API to pjit, but resolve them away in infer_params. This is to merge the jit and pjit frontend APIs.

The semantics of passing `device` or `backend` to `pjit` are the same as doing a `device_put`: no matter which device the argument is on, it is resharded to the device mentioned.

PiperOrigin-RevId: 495437165
commit 7d4ef891af (parent 64b6efc680)
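The change can be exercised as in the tests further down. A minimal sketch, assuming a runtime with at least two devices (the function and device index are illustrative only); note that passing device or backend now emits a DeprecationWarning:

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

def mul(x):
  return x @ x.T

# `device=` on pjit acts like a device_put of the inputs: no matter where the
# argument currently lives, it is resharded to the requested device.
f = pjit(mul, device=jax.devices()[1])
out = f(jnp.arange(8).reshape(4, 2))
assert out.device() == jax.devices()[1]

# `backend=` resolves to that backend's default device in the same way.
g = pjit(lambda x: x + 1, backend='cpu')
assert g(1.0).device().platform == 'cpu'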
@@ -21,12 +21,13 @@ from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
 import itertools as it
 from functools import partial, lru_cache
 import threading
+import warnings

 from jax.experimental import maps
 from jax.experimental.global_device_array import GlobalDeviceArray as GDA
 from jax._src.sharding import (
     NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
-    XLADeviceAssignment)
+    XLADeviceAssignment, SingleDeviceSharding)
 from jax import core
 from jax import linear_util as lu
 from jax import stages
@@ -182,6 +183,8 @@ def pjit(
     static_argnames: Union[str, Iterable[str], None] = None,
     donate_argnums: Union[int, Sequence[int]] = (),
     keep_unused: bool = False,
+    device: Optional[xc.Device] = None,
+    backend: Optional[str] = None,
 ) -> stages.Wrapped:
   """Makes ``fun`` compiled and automatically partitioned across multiple devices.

@@ -287,6 +290,16 @@ def pjit(
       unused by `fun` *may* be dropped from resulting compiled XLA executables.
       Such arguments will not be transferred to the device nor provided to the
       underlying executable. If `True`, unused arguments will not be pruned.
+    device: This argument is deprecated. Please put your arguments on the
+      device you want before passing them to jit.
+      Optional, the Device the jitted function will run on. (Available devices
+      can be retrieved via :py:func:`jax.devices`.) The default is inherited
+      from XLA's DeviceAssignment logic and is usually to use
+      ``jax.devices()[0]``.
+    backend: This argument is deprecated. Please put your arguments on the
+      backend you want before passing them to jit.
+      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
+      ``'tpu'``.

   For more details on buffer donation see the [FAQ](https://jax.readthedocs.io/en/latest/faq.html#buffer-donation).
   Returns:
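The docstring above points users at explicit placement instead of the deprecated arguments. A hedged sketch of that recommended pattern, assuming committed inputs drive computation placement (the device choice is illustrative):

import jax

cpu0 = jax.devices('cpu')[0]

@jax.jit
def f(x):
  return x * 2

# Instead of jit(f, device=cpu0): commit the argument to the target device
# first; the jitted computation then follows the input's device.
x = jax.device_put(1.0, cpu0)
y = f(x)
assert y.device() == cpu0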
@@ -320,6 +333,23 @@ def pjit(
         "set the environment variable JAX_ARRAY=1 , or set the `jax_array` "
         "boolean flag to something true-like.")

+  if backend is not None or device is not None:
+    warnings.warn(
+        'backend and device argument on jit is deprecated. You can use a '
+        '`jax.sharding.Mesh` context manager or device_put the arguments '
+        'before passing them to `jit`. Please see '
+        'https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html '
+        'for more information.', DeprecationWarning)
+    if device is not None and backend is not None:
+      raise ValueError("can't specify both a device and a backend for jit, "
+                       f"got {device=} and {backend=}")
+    if not _is_unspecified(in_axis_resources):
+      raise ValueError('If backend or device is specified on jit, then '
+                       'in_axis_resources should not be specified.')
+    if not _is_unspecified(out_axis_resources):
+      raise ValueError('If backend or device is specified on jit, then '
+                       'out_axis_resources should not be specified.')
+
   if isinstance(in_axis_resources, list):
     # To be a tree prefix of the positional args tuple, in_axes can never be a
     # list: if in_axes is not a leaf, it must be a tuple of trees. However,
@@ -353,6 +383,11 @@ def pjit(
       raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
                          "it's defined at the call site?")

+    if (backend or device) and not pjit_mesh.empty:
+      raise ValueError(
+          "Mesh context manager should not be used with jit when backend or "
+          "device is also specified as an argument to jit.")
+
     f = lu.wrap_init(fun)
     f, dyn_args = argnums_partial_except(f, static_argnums, args,
                                          allow_invalid=True)
@@ -377,10 +412,17 @@ def pjit(
       donated_invars = (False,) * len(args_flat)

     if config.jax_array:
-      in_shardings = tree_map(
-          lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
-      out_shardings = tree_map(
-          lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
+      # If backend or device is set as an arg on jit, then resolve them to
+      # in_shardings and out_shardings as if user passed in in_axis_resources
+      # and out_axis_resources.
+      if backend or device:
+        in_shardings = out_shardings = _create_sharding_with_device_backend(
+            device, backend)
+      else:
+        in_shardings = tree_map(
+            lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
+        out_shardings = tree_map(
+            lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
     else:
       in_shardings = tree_map(
           lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
@@ -508,6 +550,18 @@ def _create_sharding_for_array(mesh, x):
   return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)


+def _create_sharding_with_device_backend(device, backend):
+  if device is not None:
+    assert backend is None
+    out = SingleDeviceSharding(device)
+  elif backend is not None:
+    assert device is None
+    out = SingleDeviceSharding(
+        xb.get_backend(backend).get_default_device_assignment(1)[0])
+  out._device_backend = True
+  return out
+
+
 def flatten_axis_resources(what, tree, shardings, tupled_args):
   try:
     return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
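For intuition, the helper above always produces a single-device sharding: `device=` maps to that device, and `backend=` maps to the backend's default device. A rough user-level sketch, assuming SingleDeviceSharding is re-exported from jax.sharding in this version (the array and device indices are illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

s_device = SingleDeviceSharding(jax.devices()[0])         # device=... case
s_backend = SingleDeviceSharding(jax.devices('cpu')[0])   # backend=... case, approximated

x = jax.device_put(jnp.arange(4), s_device)
assert x.sharding.device_set == {jax.devices()[0]}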
@@ -877,6 +931,14 @@ pjit_p.multiple_results = True


 def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
+  # If True, means that device or backend is set by the user on pjit and it
+  # has the same semantics as device_put i.e. doesn't matter which device the
+  # arg is on, reshard it to the device mentioned. So don't do any of the
+  # checks and just return the pjit_in_shardings directly. `shard_args` will
+  # handle the resharding.
+  if pxla._check_device_backend_on_shardings(pjit_in_shardings):
+    return pjit_in_shardings
+
   committed_arg_shardings = []
   for a in args:
     if hasattr(a, 'sharding'):
@@ -1405,7 +1467,8 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
   for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
     if _is_unspecified(s) or _is_auto(s):
       continue
-    elif hasattr(s, '_original_sharding'):
+    elif hasattr(s, '_original_sharding') and hasattr(
+        s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
       parsed_pspec = parse_flatten_op_sharding(
@@ -1421,7 +1484,8 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
   for aval, s in zip(jaxpr.out_avals, params['out_shardings']):
     if _is_unspecified(s) or _is_auto(s):
       continue
-    elif hasattr(s, '_original_sharding'):
+    elif hasattr(s, '_original_sharding') and hasattr(
+        s._original_sharding, '_parsed_pspec'):
       parsed_pspec = s._original_sharding._parsed_pspec
     else:
       parsed_pspec = parse_flatten_op_sharding(
@@ -3234,7 +3234,8 @@ def _get_normalized_avals_and_shardings(
       in_sharding = i
     else:
       assert isinstance(i, sharding_internal.NamedSharding)
-      aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
+      aval = i.mesh._global_to_local(
+          cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)  # pylint: disable=g-bare-generic
       in_sharding = sharding_internal.NamedSharding(i.mesh.local_mesh, i.spec)
     avals.append(aval)
     shardings.append(in_sharding)
@@ -3467,9 +3468,8 @@ class UnloadedMeshExecutable:
       are_out_shardings_from_xla = (False,) * len(global_out_avals)

     input_avals, input_shardings = (
-        _get_normalized_avals_and_shardings(global_in_avals,
-                                            in_shardings,  # type: ignore  # arg-type
-                                            in_is_global))
+        _get_normalized_avals_and_shardings(
+            global_in_avals, in_shardings, in_is_global))  # type: ignore  # arg-type

     return UnloadedMeshExecutable(
         xla_executable=xla_executable,
@@ -3718,6 +3718,16 @@ def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
   return sharding_internal.NamedSharding(mesh, pspec, parsed_pspec)


+def _check_device_backend_on_shardings(shardings) -> bool:
+  for i in shardings:
+    if _is_unspecified(i) or _is_auto(i):
+      continue
+    if hasattr(i, '_original_sharding') and getattr(
+        i._original_sharding, '_device_backend', False):
+      return True
+  return False
+
+
 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
   from jax.experimental.global_device_array import GlobalDeviceArray
   from jax._src.array import ArrayImpl
@@ -3736,9 +3746,10 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):

     # No need to cache this check since MeshExecutable has a C++ fast path
     # for AOT compiled call.
-    if committed and not are_op_shardings_equal(
-        arg_sharding._to_xla_op_sharding(arg.ndim),
-        xs._to_xla_op_sharding(arg.ndim)):
+    if (not _check_device_backend_on_shardings([xs]) and
+        committed and
+        not are_op_shardings_equal(arg_sharding._to_xla_op_sharding(arg.ndim),
+                                   xs._to_xla_op_sharding(arg.ndim))):
       raise ValueError(
           f"{arg_type} sharding does not match the input sharding. "
           f"Got {arg_type} sharding: {arg_sharding} and xla sharding: {xs} for "
@@ -24,5 +24,6 @@ filterwarnings =
     default:Error reading persistent compilation cache entry for 'jit__lambda_'
     default:Error writing persistent compilation cache entry for 'jit__lambda_'
     ignore:DeviceArray, ShardedDeviceArray, and GlobalDeviceArray have been deprecated.*:DeprecationWarning
+    ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
 doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
 addopts = --doctest-glob="*.rst"
@@ -248,8 +248,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       _check_instance(self, x)
       self.assertEqual(x.device(), device)

+  @parameterized.named_parameters(
+      ('jit', jax.jit),
+      ('pjit', pjit.pjit),
+  )
   @jtu.skip_on_devices("cpu")
-  def test_jit_default_device(self):
+  def test_jit_default_device(self, module):
     if jax.device_count() == 1:
       raise unittest.SkipTest("Test requires multiple devices")

@@ -257,7 +261,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     test_device = jax.devices()[-1]
     self.assertNotEqual(system_default_device, test_device)

-    f = jax.jit(lambda x: x + 1)
+    f = module(lambda x: x + 1)
     self.assertEqual(f(1).device(), system_default_device)

     with jax.default_device(test_device):
@@ -270,14 +274,10 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     with jax.default_device(test_device):
       # Explicit `device` or `backend` argument to jit overrides default_device
       self.assertEqual(
-          jax.jit(f, device=system_default_device)(1).device(),
+          module(f, device=system_default_device)(1).device(),
           system_default_device)
-      out = jax.jit(f, backend="cpu")(1)
-      if config.jax_array:
-        self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
-        self.assertEqual(out._arrays[0].platform(), "cpu")
-      else:
-        self.assertEqual(out.platform(), "cpu")
+      out = module(f, backend="cpu")(1)
+      self.assertEqual(out.device().platform, "cpu")

       # Sticky input device overrides default_device
       sticky = jax.device_put(1, system_default_device)
@@ -715,29 +715,31 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     else:
       self.assertIsInstance(jitted_f(2), device_array.Buffer)

+  @parameterized.named_parameters(
+      ('jit', jax.jit),
+      ('pjit', pjit.pjit)
+  )
   @jtu.skip_on_devices("cpu")
-  def test_explicit_backend(self):
+  def test_explicit_backend(self, module):
     f = lambda x: x + 1
-    jitted_f = jit(f, backend=jtu.device_under_test())
-    jitted_f_cpu = jit(f, backend="cpu")
+    jitted_f = module(f, backend=jtu.device_under_test())
+    jitted_f_cpu = module(f, backend="cpu")

     result = jitted_f(1.)
     result_cpu = jitted_f_cpu(1.)
-    if config.jax_array:
-      buf = result._arrays[0]
-      buf_cpu = result_cpu._arrays[0]
-    else:
-      buf = result.device_buffer
-      buf_cpu = result_cpu.device_buffer
-    self.assertEqual(buf.platform(), jtu.device_under_test())
-    self.assertEqual(buf_cpu.platform(), "cpu")
+    self.assertEqual(result.device().platform, jtu.device_under_test())
+    self.assertEqual(result_cpu.device().platform, "cpu")

+  @parameterized.named_parameters(
+      ('jit', jax.jit),
+      ('pjit', pjit.pjit)
+  )
   @jtu.skip_on_devices("cpu")
-  def test_device_to_device_copy_between_backends(self):
+  def test_device_to_device_copy_between_backends(self, module):
     # b/186624243
     f = lambda x: x + 1
-    jitted_f = jit(f, backend=jtu.device_under_test())
-    jitted_f_cpu = jit(f, backend="cpu")
+    jitted_f = module(f, backend=jtu.device_under_test())
+    jitted_f_cpu = module(f, backend="cpu")

     x = np.arange(30).reshape(1, 10, 3)
     result = jitted_f(x)
@@ -747,16 +749,23 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     self.assertAllClose(result_2, x + 3)
     self.assertAllClose(result_cpu_2, x + 4)

+  @parameterized.named_parameters(
+      ('jit', jax.jit),
+      ('pjit', pjit.pjit)
+  )
   @jtu.skip_on_devices("cpu")
-  def test_mismatched_nested_backends(self):
-    @partial(jit, backend=jtu.device_under_test())
+  def test_mismatched_nested_backends(self, module):
+    @partial(module, backend=jtu.device_under_test())
     def f(x):
-      return jit(lambda x: x + 1, backend="cpu")(x)
+      return module(lambda x: x + 1, backend="cpu")(x)

-    with self.assertRaisesRegex(
-        ValueError,
-        "Outer-jit backend specification .* must match explicit inner-jit "
-        "backend specification cpu."):
+    if module is pjit.pjit:
+      msg = 'Devices of all `Array` inputs and outputs should be the same'
+    else:
+      msg = ("Outer-jit backend specification .* must match explicit inner-jit "
+             "backend specification cpu.")
+
+    with self.assertRaisesRegex(ValueError, msg):
       f(1.)

   def test_omnistaging(self):
@@ -2855,6 +2855,112 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(compiled._executable._kept_var_idx, {5})
     self.assertLen(compiled._executable.in_avals, 1)

+  def test_pjit_with_device_arg(self):
+    def mul(x):
+      return x @ x.T
+
+    def _check(out, expected_device, expected_out):
+      self.assertEqual(out.device(), expected_device)
+      self.assertLen(out.sharding.device_set, 1)
+      self.assertArraysEqual(out, expected_out @ expected_out.T)
+
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+
+    f = pjit(mul, device=jax.devices()[1])
+    x = jnp.arange(8).reshape(4, 2)
+    f_out = f(x)
+    f_out2 = f(f_out)
+    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
+    _check(f_out, jax.devices()[1], x)
+    _check(f_out2, jax.devices()[1], f_out)
+
+    y = jax.device_put(x, jax.sharding.NamedSharding(mesh, P('x', 'y')))
+    out2 = f(y)
+    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
+    _check(out2, jax.devices()[1], y)
+
+    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
+    self.assertEqual(cache_info2.misses, cache_info1.misses)
+
+    h = pjit(mul, device=jax.devices()[4])
+    h_out = h(y)
+    cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
+    _check(h_out, jax.devices()[4], y)
+
+    self.assertEqual(cache_info3.hits, cache_info2.hits)
+    self.assertEqual(cache_info3.misses, cache_info2.misses + 1)
+
+    # AOT test
+    compiled = f.lower(jax.ShapedArray(y.shape, y.dtype)).compile()
+    out3 = compiled(y)
+    _check(out3, jax.devices()[1], y)
+
+  def test_pjit_with_device_arg_input_from_another_pjit(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    inp = np.arange(8).reshape(4, 2)
+
+    y = jax.device_put(inp, jax.sharding.NamedSharding(mesh, P('x', 'y')))
+    out = pjit(lambda x: x * 2)(y)
+
+    expected_device = jax.devices()[2]
+    final_out = pjit(lambda x: x * 3, device=expected_device)(out)
+
+    self.assertEqual(final_out.device(), expected_device)
+    self.assertLen(final_out.sharding.device_set, 1)
+    self.assertArraysEqual(final_out, inp * 6)
+
+  @jtu.skip_on_devices("gpu", "cpu")
+  def test_pjit_with_backend_arg(self):
+    def _check(out, expected_device, expected_out):
+      self.assertEqual(out.device(), expected_device)
+      self.assertLen(out.sharding.device_set, 1)
+      self.assertArraysEqual(out, expected_out)
+
+    x = jnp.arange(8)
+    g = pjit(lambda x: x, backend='tpu')
+    g_out = g(x)
+    _check(g_out, jax.devices()[0], x)
+
+    compiled = g.lower(jax.ShapedArray(x.shape, x.dtype)).compile()
+    out4 = compiled(x)
+    _check(out4, jax.devices()[0], x)
+
+  def test_autodiff_with_device_arg(self):
+    if jax.device_count() <= 1:
+      self.skipTest('Test requires more >1 device.')
+    # Add a constant captured by the nested pjit to make things more complicated
+    h = jnp.arange(4.)
+    f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
+    g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
+    jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
+
+  def test_pjit_device_backend_axis_resources_error(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'If backend or device is specified on jit, then '
+        'in_axis_resources should not be specified.'):
+      pjit(lambda x: x, in_axis_resources=None, backend='cpu')
+
+    with self.assertRaisesRegex(
+        ValueError,
+        'If backend or device is specified on jit, then '
+        'out_axis_resources should not be specified.'):
+      pjit(lambda x: x, out_axis_resources=None, device=jax.devices()[0])
+
+  def test_pjit_device_backend_both_error(self):
+    with self.assertRaisesRegex(
+        ValueError, "can't specify both a device and a backend for jit"):
+      pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
+
+  def test_pjit_mesh_with_device_or_backend_error(self):
+    mesh = jtu.create_global_mesh((1,), ('x',))
+    with mesh:
+      with self.assertRaisesRegex(
+          ValueError,
+          "Mesh context manager should not be used with jit when backend or "
+          "device is also specified as an argument to jit."):
+        pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
+

 class TempSharding(Sharding):
