Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
This commit is contained in:
parent
52476d1ab5
commit
405a2310ce
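In rough terms, the commit lets pjit dispatch through a C++ fast path in jaxlib (xc._xla.pjit) whenever the new experimental_cpp_pjit flag is on and all inputs and outputs are jax.Arrays. Below is a minimal sketch of how the feature is exercised; it is not part of the commit, it simply reuses the APIs that appear in the benchmark further down, and the `jax.config` import path and the environment-variable setup are assumptions about the surrounding code.

# Illustrative sketch only; mirrors the pjit benchmark added further down.
import os
os.environ["JAX_CPP_PJIT"] = "1"   # read by bool_env("JAX_CPP_PJIT", False) at import time

import numpy as np
from jax.config import config                 # assumed import path for the config object
from jax._src import test_util as jtu
from jax.experimental import array, sharding
from jax.experimental import pjit as pjit_lib

config.update('jax_array', True)              # the fast path targets jax.Array inputs

mesh = jtu.create_global_mesh((1,), ('x',))
s = sharding.MeshPspecSharding(mesh, pjit_lib.PartitionSpec('x'))
inp = np.arange(1, dtype=np.float32)
x = array.make_array_from_callback(inp.shape, s, lambda idx: inp[idx])

f = pjit_lib.pjit(lambda x: x + 1, in_axis_resources=s, out_axis_resources=s)
f(x)   # first call goes through the Python path, compiles, and records fastpath data
f(x)   # subsequent calls with the same signature can hit the C++ fast path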
@@ -26,6 +26,7 @@ from jax._src.api_util import shaped_abstractify  # technically not an api fn
 from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
 from jax._src.lib import xla_client as xc
 from jax.interpreters import pxla
 from jax.experimental import array
 from jax.experimental import sharding
+from jax.experimental import pjit as pjit_lib
 import jax.numpy as jnp
@@ -628,5 +629,73 @@ def bench_slicing_compilation2(state):
   jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()


+def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit):
+  spec = pjit_lib.PartitionSpec('x')
+  mesh = jtu.create_global_mesh((num_devices,), ('x',))
+  s = sharding.MeshPspecSharding(mesh, spec)
+  inp_data = np.arange(num_devices).astype(np.float32)
+  x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])
+
+  x = [x for _ in range(num_args)]
+
+  prev_state = jax_config.FLAGS.experimental_cpp_pjit
+  jax_config.FLAGS.experimental_cpp_pjit = cpp_jit
+
+  in_axis_resources = sharding.MeshPspecSharding(mesh, spec)
+  out_axis_resources = sharding.MeshPspecSharding(mesh, spec)
+
+  f = pjit_lib.pjit(
+      lambda x: jax.tree_map(lambda x: x + 1, x),
+      in_axis_resources=in_axis_resources,
+      out_axis_resources=out_axis_resources)
+
+  x = f(x)
+
+  while state:
+    x = f(x)
+
+  jax_config.FLAGS.experimental_cpp_pjit = prev_state
+
+
+@google_benchmark.register
+@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
+@google_benchmark.option.args([1, False])
+@google_benchmark.option.args([1, True])
+@google_benchmark.option.args([10, False])
+@google_benchmark.option.args([10, True])
+@google_benchmark.option.args([100, False])
+@google_benchmark.option.args([100, True])
+@jax_config.jax_array(True)
+def pjit_simple_1_device(state):
+  pjit_simple_benchmark(
+      state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))
+
+
+@google_benchmark.register
+@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
+@google_benchmark.option.args([1, False])
+@google_benchmark.option.args([1, True])
+@google_benchmark.option.args([10, False])
+@google_benchmark.option.args([10, True])
+@google_benchmark.option.args([100, False])
+@google_benchmark.option.args([100, True])
+@jax_config.jax_array(True)
+def pjit_simple_4_device(state):
+  pjit_simple_benchmark(
+      state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))
+
+
+@google_benchmark.register
+@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
+@google_benchmark.option.args([1, False])
+@google_benchmark.option.args([1, True])
+@google_benchmark.option.args([10, False])
+@google_benchmark.option.args([10, True])
+@google_benchmark.option.args([100, False])
+@google_benchmark.option.args([100, True])
+@jax_config.jax_array(True)
+def pjit_simple_4000_device(state):
+  pjit_simple_benchmark(
+      state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
+
+
 if __name__ == "__main__":
   google_benchmark.main()
@@ -123,6 +123,11 @@ flags.DEFINE_bool(
     "A flag enabling the C++ jax.pmap fast path. Until the default "
     "is switched to True, the feature is not supported and possibly broken "
     "(e.g. it may use unreleased code from jaxlib).")
+flags.DEFINE_bool(
+    "experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", False),
+    "A flag enabling the C++ pjit fast path. Until the default "
+    "is switched to True, the feature is not supported and possibly broken "
+    "(e.g. it may use unreleased code from jaxlib).")


 def _nan_check_posthook(fun, args, kwargs, output):
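The benchmark and the ArrayCppPjitTest added in this commit flip the flag at runtime rather than via the JAX_CPP_PJIT environment variable. A small illustration of that pattern follows; the `jax.config` import path is an assumption (the benchmark and test files already have a config object in scope):

# Toggle the new flag around a block of code and restore the previous value afterwards.
from jax.config import config   # assumed import; exposes update() and FLAGS

prev = config.FLAGS.experimental_cpp_pjit
config.update('experimental_cpp_pjit', True)   # same call ArrayCppPjitTest.setUp uses below
try:
  ...  # build and call pjit-wrapped functions here
finally:
  config.update('experimental_cpp_pjit', prev)

Note that pjit consults the flag when it builds its wrapper (the `if FLAGS.experimental_cpp_pjit and xc._version >= 95` branch later in this diff), so the flag has to be set before the pjit-wrapped callable is created.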
@@ -16,9 +16,10 @@ import dataclasses
 from enum import IntEnum
 import numpy as np
 from collections import OrderedDict, Counter
-from typing import Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable
+from typing import Any, Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable, NamedTuple
 import itertools as it
 from functools import partial, lru_cache
+import threading

 from jax.experimental import maps
 from jax.experimental.global_device_array import GlobalDeviceArray as GDA
@@ -28,7 +29,7 @@ from jax.experimental.sharding import (
 from jax import core
 from jax import linear_util as lu
 from jax import stages
-from jax._src.api import _check_callable, _check_arg, local_devices
+from jax._src.api import _check_callable, _check_arg, local_devices, FLAGS
 from jax._src.config import config
 from jax._src import dispatch
 from jax._src import source_info_util
@@ -122,6 +123,73 @@ def _check_all_or_none_unspecified(axis_resources, name):
                      '`pjit._UNSPECIFIED`.')
   return unspecified


+def _python_pjit_helper(infer_params, *args, **kwargs):
+  args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
+  for arg in args_flat:
+    _check_arg(arg)
+  out_flat = pjit_p.bind(*args_flat, **params)
+  outs = tree_unflatten(out_tree, out_flat)
+  return outs, out_flat, out_tree
+
+
+def _python_pjit(fun: Callable, infer_params):
+
+  @wraps(fun)
+  def wrapped(*args, **kwargs):
+    return _python_pjit_helper(infer_params, *args, **kwargs)[0]
+
+  return wrapped
+
+
+class _PjitFastpathData(NamedTuple):
+  xla_executable: xla.XlaExecutable
+  out_pytree_def: Any
+  in_shardings: Sequence[Any]
+  out_shardings: Sequence[Any]
+  out_avals: Sequence[Any]
+  out_committed: Sequence[bool]
+
+
+class _MostRecentPjitCallExecutable(threading.local):
+  def __init__(self):
+    self.value = None
+
+
+_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()
+
+
+def _cpp_pjit(fun: Callable, infer_params, static_argnums):
+
+  def cache_miss(*args, **kwargs):
+    global _most_recent_pjit_call_executable
+
+    outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)
+
+    executable = _most_recent_pjit_call_executable.value
+    _most_recent_pjit_call_executable.value = None
+
+    use_fastpath = (
+        executable is not None and
+        isinstance(executable, pxla.MeshExecutable) and
+        isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
+        not executable.unsafe_call.has_unordered_effects and
+        not executable.unsafe_call.has_host_callbacks and
+        all(isinstance(x, xc.Array) for x in out_flat)
+    )
+
+    if use_fastpath:
+      out_avals = [o.aval for o in out_flat]
+      out_committed = [o._committed for o in out_flat]
+      fastpath_data = _PjitFastpathData(executable.xla_executable,
+                                        out_tree,
+                                        executable._in_shardings,
+                                        executable._out_shardings, out_avals,
+                                        out_committed)
+    else:
+      fastpath_data = None
+
+    return outs, fastpath_data
+
+  cpp_pjit_f = xc._xla.pjit(fun, cache_miss, static_argnums)
+
+  return wraps(fun)(cpp_pjit_f)
+
+
 # TODO(yashkatariya): Add pjit microbenchmarks.
 # in_axis_resources and out_axis_resources can't be None as the default value
@@ -359,13 +427,10 @@ def pjit(fun: Callable,
     return (args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)

-  @wraps(fun)
-  def wrapped(*args, **kwargs):
-    args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
-    for arg in args_flat:
-      _check_arg(arg)
-    out = pjit_p.bind(*args_flat, **params)
-    return tree_unflatten(out_tree, out)
+  if FLAGS.experimental_cpp_pjit and xc._version >= 95:
+    wrapped = _cpp_pjit(fun, infer_params, static_argnums)
+  else:
+    wrapped = _python_pjit(fun, infer_params)

   def lower(*args, _global_avals=False, **kwargs):
     (_, flat_local_in_avals, params, in_tree, out_tree,
@@ -838,6 +903,9 @@ def _pjit_call_impl(*args, jaxpr,
                     in_shardings, out_shardings, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics):
+
+  global _most_recent_pjit_call_executable
+
   if config.jax_array:
     in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
                                          resource_env.physical_mesh)
@@ -851,6 +919,7 @@
       jaxpr, in_shardings, out_shardings, resource_env,
       donated_invars, name, in_is_global).compile(
           _allow_propagation_to_outputs=_allow_propagation_to_outputs)
+  _most_recent_pjit_call_executable.value = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.jax_enable_checks:
     pxla._check_gda_or_array_xla_sharding_match(args, compiled._in_shardings)
@@ -880,7 +949,7 @@ class SameDeviceAssignmentTuple:
   device_assignment: Optional[XLADeviceAssignment]

   def __hash__(self):
-    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s
+    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s  # type: ignore
                            for s in self.shardings)
     if self.device_assignment is None:
       return hash(shardings_hash)
@@ -935,14 +1004,14 @@ def _pjit_lower_cached(
     in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast(  # type:ignore[no-redef]
         Tuple[MeshShardingMinusUnspecified, ...], tuple(
             MeshPspecSharding._from_parsed_pspec(
-                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
+                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])  # type: ignore
             if isinstance(i, OpShardingSharding) else i
             for i in in_shardings
     ))
     out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
         Tuple[MeshSharding, ...], tuple(
             MeshPspecSharding._from_parsed_pspec(
-                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
+                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])  # type: ignore
             if isinstance(o, OpShardingSharding) else o
             for o in out_shardings
     ))
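The C++ half of the fast path lives in jaxlib behind `xc._xla.pjit` (hence the `xc._version >= 95` guard) and is not part of this diff. As a rough mental model only, with no claim to match the real implementation, the contract `_cpp_pjit` relies on can be pictured as a caching wrapper: the first call runs `cache_miss`, which returns the outputs together with an optional `_PjitFastpathData` bundle, and later calls with the same signature replay the cached executable instead of re-entering Python. The names `toy_cpp_pjit`, `run_cached`, and the signature key below are invented for illustration.

# Purely illustrative Python analogue of the caching contract behind xc._xla.pjit.
from typing import Any, Callable, Optional, Tuple


def toy_cpp_pjit(cache_miss: Callable[..., Tuple[Any, Optional[Any]]],
                 run_cached: Callable[[Any, tuple], Any]) -> Callable[..., Any]:
  cache = {}   # the real cache is keyed on argument signatures inside C++

  def wrapped(*args):
    key = tuple((type(a), getattr(a, 'shape', None)) for a in args)  # toy signature key
    fastpath_data = cache.get(key)
    if fastpath_data is not None:
      # Fast path: replay the cached executable without re-entering Python pjit.
      return run_cached(fastpath_data, args)
    # Slow path: full Python dispatch; it reports whether a fast path is safe to cache.
    outs, fastpath_data = cache_miss(*args)
    if fastpath_data is not None:
      cache[key] = fastpath_data
    return outs

  return wrapped

This also explains why `cache_miss` returns no fastpath data for executables with unordered effects or host callbacks: those calls simply keep taking the Python path.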
@@ -157,14 +157,19 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out


+@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
 class MeshPspecSharding(XLACompatibleSharding):

+  @pxla.use_cpp_method
   def __init__(
       self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):

     self.mesh = mesh
     self.spec = spec
+    self._parsed_pspec = _parsed_pspec
+    self._preprocess()

+  def _preprocess(self):
     # This split exists because you can pass `_parsed_pspec` that has been
     # modified from the original. For example: Adding extra dimension to
     # axis_resources for vmap handlers. In such cases you need to preserve the
@@ -172,12 +177,10 @@ class MeshPspecSharding(XLACompatibleSharding):
     # PartitionSpec is inferred from the parsed pspec in this case.
     # TODO(yashkatariya): Remove this and replace this with a normalized
     # representation of Parsed Pspec
-    if _parsed_pspec is None:
+    if self._parsed_pspec is None:
       from jax.experimental import pjit
       self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
           self.spec, "MeshPspecSharding spec")
-    else:
-      self._parsed_pspec = _parsed_pspec

     _check_mesh_resource_axis(self.mesh, self._parsed_pspec)

@@ -256,8 +259,10 @@ def _get_replicated_op_sharding():
   return proto


+@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
 class SingleDeviceSharding(XLACompatibleSharding):

+  @pxla.use_cpp_method
   def __init__(self, device: Device):
     self._device = device

@@ -349,8 +354,10 @@ def _hash_op_sharding(op: xc.OpSharding):
       op.type, op.replicate_on_last_tile_dim, tuple(op.last_tile_dims)))


+@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
 class OpShardingSharding(XLACompatibleSharding):

+  @pxla.use_cpp_method
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
     self._devices = tuple(devices)
     self._op_sharding = op_sharding

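The `pxla.use_cpp_class` / `pxla.use_cpp_method` decorators used above are defined elsewhere in JAX, not in this diff. Roughly, they substitute a C++ implementation of the class when the installed jaxlib provides one (the `xc._version >= 95` checks), keeping the pure-Python class as a fallback and carrying Python-only helpers over. The sketch below is schematic only and does not claim to match the real decorator:

# Schematic idea only: prefer the C++ class when available, else keep the Python one.
def use_cpp_class_sketch(cpp_class):
  def wrapper(py_class):
    if cpp_class is None:      # jaxlib too old to provide a C++ counterpart
      return py_class
    # Copy Python-only attributes onto the C++ class so behaviour stays the same.
    skip = {'__init__', '__module__', '__dict__', '__doc__'}
    for name, attr in py_class.__dict__.items():
      if name not in skip and not hasattr(cpp_class, name):
        setattr(cpp_class, name, attr)
    return cpp_class
  return wrapper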
@@ -24,6 +24,8 @@ from absl.testing import absltest
 from absl.testing import parameterized
 import numpy as np

+import concurrent.futures
+
 import jax
 import jax.numpy as jnp
 from jax._src import test_util as jtu
@@ -2290,6 +2292,45 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)


+class ArrayCppPjitTest(ArrayPjitTest):
+
+  def setUp(self):
+    super().setUp()
+    self.jax_array = config.jax_array
+    self.cpp_pjit = config.FLAGS.experimental_cpp_pjit
+    config.update('experimental_cpp_pjit', True)
+    config.update('jax_array', True)
+
+  def tearDown(self):
+    config.update('experimental_cpp_pjit', self.cpp_pjit)
+    config.update('jax_array', self.jax_array)
+    super().tearDown()
+
+  def test_concurrent_cpp_pjit(self):
+    global_mesh = jtu.create_global_mesh((1,), ('x',))
+    sharding = MeshPspecSharding(global_mesh, P('x',))
+    n = 10
+    with global_mesh:
+      fs = [pjit(lambda x, i: x + i, static_argnums=1) for _ in range(n)]
+
+      def _invoke_with_mesh_twice(arg_tuple):
+        f, x, i = arg_tuple
+        with global_mesh:
+          f(x, i)
+          return f(x, i)
+
+      xs = [
+          array.make_array_from_callback(
+              (i,), sharding, lambda idx: np.arange(i, dtype=np.float32))
+          for i in range(n)
+      ]
+      with concurrent.futures.ThreadPoolExecutor() as executor:
+        ys = executor.map(_invoke_with_mesh_twice,
+                          [(fs[i], x, i) for i, x in enumerate(xs)])
+      for i, x, y in zip(range(n), xs, ys):
+        self.assertAllClose(x + i, y)
+
+
 class TempSharding(Sharding):

   def __init__(self, devices):