Implement pjit fast path in cpp for jax.Array inputs

PiperOrigin-RevId: 475988677
Kuangyuan Chen, 2022-09-21 20:17:38 -07:00 (committed by jax authors)
parent 52476d1ab5
commit 405a2310ce
5 changed files with 206 additions and 15 deletions
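For orientation, the sketch below is not part of the commit; it shows how the new fast path is exercised end to end, following the flag, benchmark, and test changes in this diff. jax.Array inputs are built with make_array_from_callback, the experimental_cpp_pjit flag (default taken from the JAX_CPP_PJIT environment variable) is turned on, and pjit dispatches through the C++ wrapper when the installed jaxlib reports xla_client._version >= 95. Module paths mirror the 2022-era jax.experimental layout used in this diff; treat the snippet as an assumption-level usage sketch.

    # Minimal sketch, assuming a jaxlib with xla_client._version >= 95.
    # Names mirror the benchmark and tests added in this commit.
    import numpy as np

    import jax
    from jax.config import config
    from jax.experimental import array, maps, pjit as pjit_lib, sharding

    config.update('jax_array', True)              # the fast path targets jax.Array inputs
    config.update('experimental_cpp_pjit', True)  # or export JAX_CPP_PJIT=1

    mesh = maps.Mesh(np.asarray(jax.devices()), ('x',))
    spec = pjit_lib.PartitionSpec('x')
    s = sharding.MeshPspecSharding(mesh, spec)

    inp = np.arange(jax.device_count(), dtype=np.float32)
    x = array.make_array_from_callback(inp.shape, s, lambda idx: inp[idx])

    f = pjit_lib.pjit(lambda a: a + 1, in_axis_resources=s, out_axis_resources=s)
    y = f(x)  # first call compiles; subsequent calls can hit the C++ fast path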


@@ -26,6 +26,7 @@ from jax._src.api_util import shaped_abstractify  # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
from jax.experimental import array
from jax.experimental import sharding
from jax.experimental import pjit as pjit_lib
import jax.numpy as jnp
@@ -628,5 +629,73 @@ def bench_slicing_compilation2(state):
  jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()


def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit):
  spec = pjit_lib.PartitionSpec('x')
  mesh = jtu.create_global_mesh((num_devices,), ('x',))
  s = sharding.MeshPspecSharding(mesh, spec)
  inp_data = np.arange(num_devices).astype(np.float32)
  x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])
  x = [x for _ in range(num_args)]

  prev_state = jax_config.FLAGS.experimental_cpp_pjit
  jax_config.FLAGS.experimental_cpp_pjit = cpp_jit

  in_axis_resources = sharding.MeshPspecSharding(mesh, spec)
  out_axis_resources = sharding.MeshPspecSharding(mesh, spec)

  f = pjit_lib.pjit(
      lambda x: jax.tree_map(lambda x: x + 1, x),
      in_axis_resources=in_axis_resources,
      out_axis_resources=out_axis_resources)

  x = f(x)

  while state:
    x = f(x)

  jax_config.FLAGS.experimental_cpp_pjit = prev_state


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_1_device(state):
  pjit_simple_benchmark(
      state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4_device(state):
  pjit_simple_benchmark(
      state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4000_device(state):
  pjit_simple_benchmark(
      state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))


if __name__ == "__main__":
  google_benchmark.main()
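In pjit_simple_benchmark above, FLAGS.experimental_cpp_pjit is saved, overridden for the timed loop, and restored by hand afterwards. A hypothetical helper (not in this change) that wraps the same save/restore in a context manager would keep the flag from leaking if the benchmark body raises; the getter and setter calls below are the same ones the new benchmark and tests use.

    import contextlib

    from jax.config import config

    @contextlib.contextmanager
    def cpp_pjit_enabled(enabled: bool):
      # Hypothetical convenience wrapper around the flag save/restore pattern.
      prev = config.FLAGS.experimental_cpp_pjit
      config.update('experimental_cpp_pjit', enabled)
      try:
        yield
      finally:
        config.update('experimental_cpp_pjit', prev)

    # Usage sketch:  with cpp_pjit_enabled(True): x = f(x)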


@@ -123,6 +123,11 @@ flags.DEFINE_bool(
    "A flag enabling the C++ jax.pmap fast path. Until the default "
    "is switched to True, the feature is not supported and possibly broken "
    "(e.g. it may use unreleased code from jaxlib.")

flags.DEFINE_bool(
    "experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", False),
    "A flag enabling the C++ pjit fast path. Until the default "
    "is switched to True, the feature is not supported and possibly broken "
    "(e.g. it may use unreleased code from jaxlib.")


def _nan_check_posthook(fun, args, kwargs, output):
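The default for the new flag comes from the JAX_CPP_PJIT environment variable via bool_env, so the fast path stays off unless explicitly requested. A small illustrative snippet (assumed usage, not taken from the diff) showing the environment-variable route and a flag read:

    import os
    os.environ['JAX_CPP_PJIT'] = '1'   # must be set before jax is imported

    from jax.config import config
    assert config.FLAGS.experimental_cpp_pjit   # reads the flag, as the new tests do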


@@ -16,9 +16,10 @@ import dataclasses
from enum import IntEnum
import numpy as np
from collections import OrderedDict, Counter
from typing import Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable
from typing import Any, Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable, NamedTuple
import itertools as it
from functools import partial, lru_cache
import threading
from jax.experimental import maps
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
@@ -28,7 +29,7 @@ from jax.experimental.sharding import (
from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg, local_devices
from jax._src.api import _check_callable, _check_arg, local_devices, FLAGS
from jax._src.config import config
from jax._src import dispatch
from jax._src import source_info_util
@@ -122,6 +123,73 @@ def _check_all_or_none_unspecified(axis_resources, name):
                       '`pjit._UNSPECIFIED`.')
  return unspecified


def _python_pjit_helper(infer_params, *args, **kwargs):
  args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
  for arg in args_flat:
    _check_arg(arg)
  out_flat = pjit_p.bind(*args_flat, **params)
  outs = tree_unflatten(out_tree, out_flat)
  return outs, out_flat, out_tree


def _python_pjit(fun: Callable, infer_params):

  @wraps(fun)
  def wrapped(*args, **kwargs):
    return _python_pjit_helper(infer_params, *args, **kwargs)[0]

  return wrapped


class _PjitFastpathData(NamedTuple):
  xla_executable: xla.XlaExecutable
  out_pytree_def: Any
  in_shardings: Sequence[Any]
  out_shardings: Sequence[Any]
  out_avals: Sequence[Any]
  out_committed: Sequence[bool]


class _MostRecentPjitCallExecutable(threading.local):
  def __init__(self):
    self.value = None


_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()


def _cpp_pjit(fun: Callable, infer_params, static_argnums):

  def cache_miss(*args, **kwargs):
    global _most_recent_pjit_call_executable
    outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)

    executable = _most_recent_pjit_call_executable.value
    _most_recent_pjit_call_executable.value = None

    use_fastpath = (
        executable is not None and
        isinstance(executable, pxla.MeshExecutable) and
        isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
        not executable.unsafe_call.has_unordered_effects and
        not executable.unsafe_call.has_host_callbacks and
        all(isinstance(x, xc.Array) for x in out_flat)
    )

    if use_fastpath:
      out_avals = [o.aval for o in out_flat]
      out_committed = [o._committed for o in out_flat]
      fastpath_data = _PjitFastpathData(executable.xla_executable,
                                        out_tree,
                                        executable._in_shardings,
                                        executable._out_shardings, out_avals,
                                        out_committed)
    else:
      fastpath_data = None

    return outs, fastpath_data

  cpp_pjit_f = xc._xla.pjit(fun, cache_miss, static_argnums)
  return wraps(fun)(cpp_pjit_f)


# TODO(yashkatariya): Add pjit microbenchmarks.
# in_axis_resources and out_axis_resources can't be None as the default value
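The split above is deliberate: _python_pjit_helper always runs the full Python dispatch, _pjit_call_impl (later in this diff) records the most recent compiled MeshExecutable in thread-local state, and cache_miss packages that executable plus the output tree, shardings, avals, and committed bits into _PjitFastpathData for the C++ wrapper returned by xc._xla.pjit. The self-contained toy below is an assumption-level illustration of that caching contract, not the real C++ wrapper: the slow path runs once, hands back optional fastpath data, and later calls reuse it.

    # Toy stand-in for the wrapper produced by xc._xla.pjit (illustrative only).
    def make_fastpath_wrapper(cache_miss):
      state = {'fastpath': None}   # the real cache also keys on signatures, shapes, etc.

      def wrapper(*args, **kwargs):
        fast = state['fastpath']
        if fast is not None:
          return fast(*args, **kwargs)          # cheap cached call, no Python re-dispatch
        outs, fastpath_data = cache_miss(*args, **kwargs)
        if fastpath_data is not None:
          state['fastpath'] = fastpath_data     # e.g. a callable built from the executable
        return outs

      return wrapper

    # Toy usage: the expensive "compile" happens only on the first call.
    def toy_cache_miss(x):
      compiled = lambda y: y + 1                # stands in for the compiled executable
      return compiled(x), compiled

    f = make_fastpath_wrapper(toy_cache_miss)
    assert f(1) == 2 and f(41) == 42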
@@ -359,13 +427,10 @@ def pjit(fun: Callable,
    return (args_flat, local_in_avals, params, in_tree, out_tree(),
            donate_argnums)

  @wraps(fun)
  def wrapped(*args, **kwargs):
    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
    for arg in args_flat:
      _check_arg(arg)
    out = pjit_p.bind(*args_flat, **params)
    return tree_unflatten(out_tree, out)

  if FLAGS.experimental_cpp_pjit and xc._version >= 95:
    wrapped = _cpp_pjit(fun, infer_params, static_argnums)
  else:
    wrapped = _python_pjit(fun, infer_params)

  def lower(*args, _global_avals=False, **kwargs):
    (_, flat_local_in_avals, params, in_tree, out_tree,
@@ -838,6 +903,9 @@ def _pjit_call_impl(*args, jaxpr,
                    in_shardings, out_shardings, resource_env,
                    donated_invars, name,
                    in_positional_semantics, out_positional_semantics):

  global _most_recent_pjit_call_executable

  if config.jax_array:
    in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
                                         resource_env.physical_mesh)
@@ -851,6 +919,7 @@ def _pjit_call_impl(*args, jaxpr,
      jaxpr, in_shardings, out_shardings, resource_env,
      donated_invars, name, in_is_global).compile(
          _allow_propagation_to_outputs=_allow_propagation_to_outputs)
  _most_recent_pjit_call_executable.value = compiled
  # This check is expensive so only do it if enable_checks is on.
  if compiled._auto_spmd_lowering and config.jax_enable_checks:
    pxla._check_gda_or_array_xla_sharding_match(args, compiled._in_shardings)
@@ -880,7 +949,7 @@ class SameDeviceAssignmentTuple:
  device_assignment: Optional[XLADeviceAssignment]

  def __hash__(self):
    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s
    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s  # type: ignore
                           for s in self.shardings)
    if self.device_assignment is None:
      return hash(shardings_hash)
@@ -935,14 +1004,14 @@ def _pjit_lower_cached(
    in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast(  # type:ignore[no-redef]
        Tuple[MeshShardingMinusUnspecified, ...], tuple(
            MeshPspecSharding._from_parsed_pspec(
                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])  # type: ignore
            if isinstance(i, OpShardingSharding) else i
            for i in in_shardings
        ))
    out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
        Tuple[MeshSharding, ...], tuple(
            MeshPspecSharding._from_parsed_pspec(
                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])  # type: ignore
            if isinstance(o, OpShardingSharding) else o
            for o in out_shardings
        ))


@@ -157,14 +157,19 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
  return out


@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
class MeshPspecSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(
      self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):
    self.mesh = mesh
    self.spec = spec
    self._parsed_pspec = _parsed_pspec
    self._preprocess()

  def _preprocess(self):
    # This split exists because you can pass `_parsed_pspec` that has been
    # modified from the original. For example: Adding extra dimension to
    # axis_resources for vmap handlers. In such cases you need to preserve the
@@ -172,12 +177,10 @@ class MeshPspecSharding(XLACompatibleSharding):
    # PartitionSpec is inferred from the parsed pspec in this case.
    # TODO(yashkatariya): Remove this and replace this with a normalized
    # representation of Parsed Pspec
    if _parsed_pspec is None:
    if self._parsed_pspec is None:
      from jax.experimental import pjit
      self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
          self.spec, "MeshPspecSharding spec")
    else:
      self._parsed_pspec = _parsed_pspec

    _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
@@ -256,8 +259,10 @@ def _get_replicated_op_sharding():
  return proto


@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(self, device: Device):
    self._device = device
@@ -349,8 +354,10 @@ def _hash_op_sharding(op: xc.OpSharding):
               op.type, op.replicate_on_last_tile_dim, tuple(op.last_tile_dims)))


@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
class OpShardingSharding(XLACompatibleSharding):

  @pxla.use_cpp_method
  def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
    self._devices = tuple(devices)
    self._op_sharding = op_sharding
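use_cpp_class swaps these Python sharding classes for their C++ counterparts when the installed jaxlib is new enough (xc._version >= 95), while use_cpp_method keeps the Python __init__ shared between the two; on older jaxlibs the pure-Python classes remain in place and pjit stays on the Python dispatch path. A small sketch of how this could be checked at runtime; the isinstance relationship is an assumption about how use_cpp_class registers the class, not something stated in the diff.

    import numpy as np

    import jax
    from jax._src.lib import xla_client as xc
    from jax.experimental import maps, sharding
    from jax.experimental.pjit import PartitionSpec as P

    mesh = maps.Mesh(np.asarray(jax.devices()), ('x',))
    s = sharding.MeshPspecSharding(mesh, P('x'))

    if xc._version >= 95:
      # Expected True if use_cpp_class backed the Python class with the C++ type.
      print(isinstance(s, xc.MeshPspecSharding))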


@@ -24,6 +24,8 @@ from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import concurrent.futures

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
@@ -2290,6 +2292,45 @@ class ArrayPjitTest(jtu.JaxTestCase):
    self.assertEqual(cache_info3.hits, cache_info2.hits + 1)


class ArrayCppPjitTest(ArrayPjitTest):

  def setUp(self):
    super().setUp()
    self.jax_array = config.jax_array
    self.cpp_pjit = config.FLAGS.experimental_cpp_pjit
    config.update('experimental_cpp_pjit', True)
    config.update('jax_array', True)

  def tearDown(self):
    config.update('experimental_cpp_pjit', self.cpp_pjit)
    config.update('jax_array', self.jax_array)
    super().tearDown()

  def test_concurrent_cpp_pjit(self):
    global_mesh = jtu.create_global_mesh((1,), ('x',))
    sharding = MeshPspecSharding(global_mesh, P('x',))
    n = 10
    with global_mesh:
      fs = [pjit(lambda x, i: x + i, static_argnums=1) for _ in range(n)]

    def _invoke_with_mesh_twice(arg_tuple):
      f, x, i = arg_tuple
      with global_mesh:
        f(x, i)
        return f(x, i)

    xs = [
        array.make_array_from_callback(
            (i,), sharding, lambda idx: np.arange(i, dtype=np.float32))
        for i in range(n)
    ]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      ys = executor.map(_invoke_with_mesh_twice,
                        [(fs[i], x, i) for i, x in enumerate(xs)])

    for i, x, y in zip(range(n), xs, ys):
      self.assertAllClose(x + i, y)


class TempSharding(Sharding):

  def __init__(self, devices):