Make apply_primitive go via C++ fast dispatch.

This leads to a ~30% faster dispatch time. Ideally, we should replace this with jit, but that has its own set of problems that I will look into later.

```
eager_unary_dispatch                                  40.3µs ± 2%             29.2µs ± 9%  -27.51%          (p=0.008 n=5+5)
eager_unary                                           40.6µs ± 0%             31.1µs ±11%  -23.41%          (p=0.016 n=4+5)
eager_binary_dispatch                                 49.6µs ± 0%             34.5µs ± 8%  -30.58%          (p=0.016 n=4+5)
eager_binary                                          50.2µs ± 1%             35.4µs ± 9%  -29.38%          (p=0.016 n=4+5)
bench_remat_eager_retracing_overheads                 13.0ms ± 1%             11.3ms ± 8%  -13.26%          (p=0.008 n=5+5)
bench_remat_eager_retracing_overheads_static_argnums  13.3ms ± 0%             12.3ms ± 6%   -7.34%          (p=0.016 n=4+5)
bench_repeated_static_indexing                         112ms ± 2%               82ms ± 5%  -26.46%          (p=0.008 n=5+5)
bench_repeated_static_slicing                         90.5ms ± 1%             68.3ms ± 5%  -24.54%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 561774696
This commit is contained in:
Yash Katariya 2023-08-31 15:17:57 -07:00 committed by jax authors
parent 6c1b4b9f3d
commit ccb88140ec
4 changed files with 84 additions and 43 deletions

View File

@ -498,6 +498,7 @@ class Config:
self.jax_threefry_partitionable,
self.jax_softmax_custom_jvp,
self.jax_enable_memories,
self.jax_disable_jit,
# Technically this affects jaxpr->MHLO lowering, not tracing.
self.jax_hlo_source_file_canonicalization_regex)

View File

@ -33,6 +33,7 @@ from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import api_util
from jax._src import tree_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@ -44,6 +45,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
@ -113,8 +115,9 @@ def apply_primitive(prim, *args, **params):
try:
in_avals, in_shardings = util.unzip2([_arg_spec(a) for a in args])
in_tree = tree_util.tree_structure(args)
compiled_fun = xla_primitive_callable(
prim, in_avals, OrigShardings(in_shardings), **params)
prim, in_avals, in_tree, OrigShardings(in_shardings), **params)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
# TODO(yashkatariya): Thread through a signature_fun via every primitive
@ -128,6 +131,56 @@ def apply_primitive(prim, *args, **params):
return compiled_fun(*args)
@util.cache()
def xla_primitive_callable(
    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree,
    orig_in_shardings: OrigShardings, **params,
) -> Callable:
  """Lower, compile, and cache a callable that executes `prim` eagerly.

  Returns a function of the primitive's arguments. Cached via `util.cache()`,
  keyed on all arguments, so repeated eager applications of the same primitive
  at the same avals/shardings/params reuse the compiled executable.
  """
  def prim_fun(*args):
    # Normalize the output to a sequence: single-result primitives are wrapped
    # in a 1-tuple so flattening always sees a tuple of outputs.
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out
    else:
      return out,
  # Eager dispatch marks no inputs as donated.
  donated_invars = (False,) * len(in_avals)
  wrapped_fun = lu.wrap_init(prim_fun)
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
  computation = sharded_lowering(
      flat_fun, prim.name, donated_invars, keep_unused=False,
      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
      lowering_platform=None)
  compiled = computation.compile()
  if xla_extension_version >= 192:
    if config.jax_disable_jit:
      # jit disabled: skip the C++ fast-dispatch path.
      call = compiled.unsafe_call
    else:
      # Prefer the C++ fast-dispatch path; it may be unavailable (returns
      # None), in which case fall back to the Python call.
      call = compiled.create_cpp_call_for_apply_primitive(out_tree())
      if call is None:
        call = compiled.unsafe_call
  else:
    # Older jaxlib without the C++ apply_primitive fast path.
    call = compiled.unsafe_call
  if not prim.multiple_results:
    # Unwrap the singleton tuple introduced by prim_fun above.
    return lambda *args, **kw: call(*args, **kw)[0]
  else:
    return call
def sharded_lowering(
    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
) -> pxla.MeshComputation:
  """Lower `fun` to a MeshComputation, mapping None shardings to UNSPECIFIED."""
  normalized_in_shardings = [
      UNSPECIFIED if s is None else s for s in in_shardings]
  # A singleton `UNSPECIFIED` stands in for out_shardings here: the number of
  # output avals is not known at this stage, and lower_sharding_computation
  # applies it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, normalized_in_shardings, UNSPECIFIED, donated_invars,
      in_avals, keep_unused=keep_unused, inline=inline,
      devices_from_context=None, lowering_platform=lowering_platform)
def simple_impl(prim):
  """Register `apply_primitive` as `prim`'s default implementation rule."""
  impl = partial(apply_primitive, prim)
  prim.def_impl(impl)
@ -196,45 +249,6 @@ def wait_for_tokens():
runtime_tokens.block_until_ready()
@util.cache()
def xla_primitive_callable(
    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...],
    orig_in_shardings: OrigShardings, **params,
) -> Callable:
  """Lower and compile a callable executing `prim` at the given avals/shardings.

  Cached via `util.cache()`, keyed on all arguments.
  """
  def prim_fun(*args):
    # Wrap single results in a 1-tuple so the compiled computation always
    # returns a sequence of outputs.
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out
    else:
      return out,
  # Eager dispatch marks no inputs as donated.
  donated_invars = (False,) * len(in_avals)
  computation = sharded_lowering(
      lu.wrap_init(prim_fun), prim.name, donated_invars, keep_unused=False,
      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
      lowering_platform=None)
  compiled = computation.compile().unsafe_call
  if not prim.multiple_results:
    # Unwrap the singleton tuple introduced by prim_fun above.
    return lambda *args, **kw: compiled(*args, **kw)[0]
  else:
    return compiled
def sharded_lowering(
    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
) -> pxla.MeshComputation:
  """Lower `fun` to a MeshComputation, mapping None shardings to UNSPECIFIED."""
  in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
  # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
  # the number of output avals at this stage. lower_sharding_computation will
  # apply it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
      in_avals, keep_unused=keep_unused, inline=inline,
      devices_from_context=None, lowering_platform=lowering_platform)
def is_single_device_sharding(sharding: Sharding) -> bool:
# Special case PmapSharding here because PmapSharding maps away an axis
and needs to be handled separately.

View File

@ -2751,6 +2751,33 @@ class MeshExecutable(stages.XlaExecutable):
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.default_registry)
def create_cpp_call_for_apply_primitive(self, out_tree):
  """Return a C++ pjit fast-dispatch callable for apply_primitive, or None.

  Returns None when `unsafe_call` is not an ExecuteReplicated free of
  unordered effects and host callbacks; the caller is expected to fall back
  to `unsafe_call` in that case.
  """
  # unsafe_call can be different than ExecuteReplicated for pathways.
  if not (isinstance(self.unsafe_call, ExecuteReplicated) and
          not self.unsafe_call.has_unordered_effects and
          not self.unsafe_call.has_host_callbacks):
    return None

  def apply_primitive_cache_miss(*args):
    # Slow path run on a C++ cache miss: execute the computation and, when
    # possible, hand back fastpath data so later calls stay in C++.
    out_flat = self.unsafe_call(*args)
    outs = tree_util.tree_unflatten(out_tree, out_flat)
    # The fast path requires every output to be an ArrayImpl.
    use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
    if use_fastpath:
      out_avals = [o.aval for o in out_flat]
      out_committed = [o._committed for o in out_flat]
      kept_var_bitvec = [i in self._kept_var_idx
                         for i in range(len(args))]
      fastpath_data = MeshExecutableFastpathData(
          self.xla_executable, out_tree, self._in_shardings,
          self._out_shardings, out_avals, out_committed, kept_var_bitvec)
    else:
      fastpath_data = None
    return outs, fastpath_data

  return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
                      [], [], [], tree_util.default_registry)
def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):

View File

@ -517,9 +517,8 @@ class Compiled(Stage):
def __call__(self, *args, **kwargs):
if self._call is None:
self._call = self._executable.create_cpp_call(self._no_kwargs,
self.in_tree,
self.out_tree)
self._call = self._executable.create_cpp_call(
self._no_kwargs, self.in_tree, self.out_tree)
if self._call is None:
params = self._params
def cpp_call_fallback(*args, **kwargs):