mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make apply_primitive go via C++ fast dispatch.
This leads to a ~30% faster dispatch time. Ideally, we should replace this with jit, but that has its own set of problems that I will look into later. ``` eager_unary_dispatch 40.3µs ± 2% 29.2µs ± 9% -27.51% (p=0.008 n=5+5) eager_unary 40.6µs ± 0% 31.1µs ±11% -23.41% (p=0.016 n=4+5) eager_binary_dispatch 49.6µs ± 0% 34.5µs ± 8% -30.58% (p=0.016 n=4+5) eager_binary 50.2µs ± 1% 35.4µs ± 9% -29.38% (p=0.016 n=4+5) bench_remat_eager_retracing_overheads 13.0ms ± 1% 11.3ms ± 8% -13.26% (p=0.008 n=5+5) bench_remat_eager_retracing_overheads_static_argnums 13.3ms ± 0% 12.3ms ± 6% -7.34% (p=0.016 n=4+5) bench_repeated_static_indexing 112ms ± 2% 82ms ± 5% -26.46% (p=0.008 n=5+5) bench_repeated_static_slicing 90.5ms ± 1% 68.3ms ± 5% -24.54% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 561774696
This commit is contained in:
parent
6c1b4b9f3d
commit
ccb88140ec
@ -498,6 +498,7 @@ class Config:
|
||||
self.jax_threefry_partitionable,
|
||||
self.jax_softmax_custom_jvp,
|
||||
self.jax_enable_memories,
|
||||
self.jax_disable_jit,
|
||||
# Technically this affects jaxpr->MHLO lowering, not tracing.
|
||||
self.jax_hlo_source_file_canonicalization_regex)
|
||||
|
||||
|
@ -33,6 +33,7 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import api_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -44,6 +45,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.monitoring import record_event_duration_secs
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding import Sharding
|
||||
@ -113,8 +115,9 @@ def apply_primitive(prim, *args, **params):
|
||||
|
||||
try:
|
||||
in_avals, in_shardings = util.unzip2([_arg_spec(a) for a in args])
|
||||
in_tree = tree_util.tree_structure(args)
|
||||
compiled_fun = xla_primitive_callable(
|
||||
prim, in_avals, OrigShardings(in_shardings), **params)
|
||||
prim, in_avals, in_tree, OrigShardings(in_shardings), **params)
|
||||
except pxla.DeviceAssignmentMismatchError as e:
|
||||
fails, = e.args
|
||||
# TODO(yashkatariya): Thread through a signature_fun via every primitive
|
||||
@ -128,6 +131,56 @@ def apply_primitive(prim, *args, **params):
|
||||
return compiled_fun(*args)
|
||||
|
||||
|
||||
@util.cache()
def xla_primitive_callable(
    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree,
    orig_in_shardings: OrigShardings, **params,
) -> Callable:
  """Lowers, compiles, and caches an executor for eagerly applying `prim`.

  Returns a callable taking the primitive's flat arguments. For a
  single-result primitive the callable unwraps the one-element output list.
  """
  def bind_prim(*bind_args):
    result = prim.bind(*bind_args, **params)
    # Normalize to a flat sequence so the flattened wrapper always sees a list.
    return result if prim.multiple_results else (result,)

  flat_fun, out_tree = api_util.flatten_fun_nokwargs(
      lu.wrap_init(bind_prim), in_tree)
  compiled = sharded_lowering(
      flat_fun, prim.name, (False,) * len(in_avals), keep_unused=False,
      inline=True, in_avals=in_avals,
      in_shardings=orig_in_shardings.shardings,
      lowering_platform=None).compile()

  call = None
  if xla_extension_version >= 192 and not config.jax_disable_jit:
    # Prefer the C++ fast-dispatch path; it may be unavailable for some
    # executables (returns None), in which case we fall back below.
    call = compiled.create_cpp_call_for_apply_primitive(out_tree())
  if call is None:
    call = compiled.unsafe_call

  if prim.multiple_results:
    return call
  return lambda *args, **kw: call(*args, **kw)[0]
|
||||
|
||||
|
||||
def sharded_lowering(
    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
) -> pxla.MeshComputation:
  """Lowers `fun` to a MeshComputation with the given input shardings.

  `None` entries in `in_shardings` are treated as unspecified shardings.
  """
  resolved_shardings = [UNSPECIFIED if s is None else s for s in in_shardings]
  # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
  # the number of output avals at this stage. lower_sharding_computation will
  # apply it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, resolved_shardings, UNSPECIFIED, donated_invars,
      in_avals, keep_unused=keep_unused, inline=inline,
      devices_from_context=None, lowering_platform=lowering_platform)
|
||||
|
||||
|
||||
def simple_impl(prim):
  """Registers `apply_primitive` as the eager implementation of `prim`."""
  prim.def_impl(lambda *args, **params: apply_primitive(prim, *args, **params))
|
||||
|
||||
@ -196,45 +249,6 @@ def wait_for_tokens():
|
||||
runtime_tokens.block_until_ready()
|
||||
|
||||
|
||||
@util.cache()
def xla_primitive_callable(
    prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...],
    orig_in_shardings: OrigShardings, **params,
) -> Callable:
  """Compiles `prim` for the given avals/shardings and returns an executor.

  For a single-result primitive the returned callable unwraps the
  one-element output list produced by the compiled computation.
  """
  def run_prim(*call_args):
    res = prim.bind(*call_args, **params)
    # Always hand a flat sequence of outputs to the lowering machinery.
    return res if prim.multiple_results else (res,)

  executable = sharded_lowering(
      lu.wrap_init(run_prim), prim.name, (False,) * len(in_avals),
      keep_unused=False, inline=True, in_avals=in_avals,
      in_shardings=orig_in_shardings.shardings,
      lowering_platform=None).compile().unsafe_call
  if prim.multiple_results:
    return executable
  return lambda *args, **kw: executable(*args, **kw)[0]
|
||||
|
||||
|
||||
def sharded_lowering(
    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
    keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
    in_shardings: Sequence[Sharding | None], lowering_platform: str | None
) -> pxla.MeshComputation:
  """Lowers `fun` under the given input shardings, yielding a MeshComputation."""
  normalized = [s if s is not None else UNSPECIFIED for s in in_shardings]
  # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
  # the number of output avals at this stage. lower_sharding_computation will
  # apply it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, normalized, UNSPECIFIED, donated_invars, in_avals,
      keep_unused=keep_unused, inline=inline, devices_from_context=None,
      lowering_platform=lowering_platform)
|
||||
|
||||
|
||||
def is_single_device_sharding(sharding: Sharding) -> bool:
|
||||
# Special case PmapSharding here because PmapSharding maps away an axis
|
||||
# and needs to be handled separately.
|
||||
|
@ -2751,6 +2751,33 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry)
|
||||
|
||||
def create_cpp_call_for_apply_primitive(self, out_tree):
|
||||
# unsafe_call can be different than ExecuteReplicated for pathways.
|
||||
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
|
||||
not self.unsafe_call.has_unordered_effects and
|
||||
not self.unsafe_call.has_host_callbacks):
|
||||
return None
|
||||
|
||||
def apply_primitive_cache_miss(*args):
|
||||
out_flat = self.unsafe_call(*args)
|
||||
outs = tree_util.tree_unflatten(out_tree, out_flat)
|
||||
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))
|
||||
|
||||
if use_fastpath:
|
||||
out_avals = [o.aval for o in out_flat]
|
||||
out_committed = [o._committed for o in out_flat]
|
||||
kept_var_bitvec = [i in self._kept_var_idx
|
||||
for i in range(len(args))]
|
||||
fastpath_data = MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree, self._in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
else:
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
|
||||
[], [], [], tree_util.default_registry)
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
|
||||
|
@ -517,9 +517,8 @@ class Compiled(Stage):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self._call is None:
|
||||
self._call = self._executable.create_cpp_call(self._no_kwargs,
|
||||
self.in_tree,
|
||||
self.out_tree)
|
||||
self._call = self._executable.create_cpp_call(
|
||||
self._no_kwargs, self.in_tree, self.out_tree)
|
||||
if self._call is None:
|
||||
params = self._params
|
||||
def cpp_call_fallback(*args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user