Remove trivial execution from JAX, since it leads to roughly 100x slower dispatch times.

Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces far fewer trivial computations, so there is no longer a need for this path to exist.
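
For context, a "trivial" computation here is a jaxpr with no equations that only forwards its inputs and/or returns constants. A quick way to see one (an illustrative sketch, not part of this change) is to print the jaxpr of an identity function:

```
# Sketch (not from this commit): inspecting a trivial jaxpr.
import jax

# An identity function stages out to a jaxpr with no equations; this is the
# kind of computation the removed fast path used to special-case.
print(jax.make_jaxpr(lambda x: x)(1.0))
# { lambda ; a:f32[]. let  in (a,) }
```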

In the future, if we want to support forwarding of inputs to outputs, it will have to be done via a different mechanism that the C++ dispatch path knows about.
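
One user-visible consequence, reflected in the test updates below, is that an identity `jit` no longer forwards the input buffer to the output. A small sketch of the new behavior (assuming a single default device):

```
# Sketch of the behavior change checked by the updated tests.
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jax.jit(lambda x: x)(x)

# The trivial-execution path used to return the input buffer itself, so the
# pointers compared equal; the output now comes from a compiled executable.
print(x.unsafe_buffer_pointer() == y.unsafe_buffer_pointer())  # now False
```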

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```
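
The table above appears to come from JAX's google-benchmark-based suite. A rough, hand-rolled way to eyeball dispatch overhead (an illustrative sketch, not the harness that produced these numbers) is:

```
# Sketch: measuring per-call dispatch overhead of an already-compiled jit.
import timeit
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x)
x = jnp.array([1, 2, 3])
f(x).block_until_ready()  # warm up so compilation is excluded

n = 10_000
t = timeit.timeit(lambda: f(x).block_until_ready(), number=n)
print(f"~{t / n * 1e6:.2f} µs per call")
```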

PiperOrigin-RevId: 560141018
Yash Katariya 2023-08-25 10:59:10 -07:00 committed by jax authors
parent c71eedf529
commit 970f4c9d4d
6 changed files with 28 additions and 150 deletions


@@ -2016,22 +2016,6 @@ def lower_sharding_computation(
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
kept_outputs = [True] * len(global_out_avals)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_shardings=in_shardings,
backend=backend, da_object=da_object,
committed=committed, kept_var_idx=kept_var_idx, keepalive=None)
# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
@@ -2049,7 +2033,6 @@ def lower_sharding_computation(
return MeshComputation(
str(name_stack),
module,
False,
donated_invars,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
@@ -2223,7 +2206,6 @@ def lower_mesh_computation(
return MeshComputation(
str(name_stack),
lowering_result.module,
False,
donated_invars,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
@@ -2248,10 +2230,9 @@ class MeshComputation(stages.XlaLowering):
_executable: MeshExecutable | None
def __init__(self, name: str, hlo: ir.Module | None,
is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
self.is_trivial = is_trivial
self._donated_invars = donated_invars
self.compile_args = compile_args
self._executable = None
@@ -2259,24 +2240,13 @@ class MeshComputation(stages.XlaLowering):
# -- stages.XlaLowering overrides
def stablehlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no HLO")
return self._hlo
def compile(
self,
compiler_options=None,
) -> MeshExecutable:
def compile(self, compiler_options=None) -> MeshExecutable:
if self._executable is None or compiler_options is not None:
if self.is_trivial:
executable = MeshExecutable.from_trivial_jaxpr(
**self.compile_args)
else:
executable = UnloadedMeshExecutable.from_hlo(
self._name,
self._hlo,
**self.compile_args,
compiler_options=compiler_options)
executable = UnloadedMeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
compiler_options=compiler_options)
if compiler_options is None:
self._executable = executable
return executable
@@ -2735,32 +2705,6 @@ class MeshExecutable(stages.XlaExecutable):
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
in_shardings, backend, da_object,
committed, kept_var_idx, keepalive) -> MeshExecutable:
assert keepalive is None
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
backend, da_object, committed, kept_var_idx, 1)
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object)
indices = _get_input_indices(global_out_avals, out_shardings, da_object)
# TODO(yashkatariya): Make local_device_assignment directly usable in the
# downstream code without tuple conversion.
local_device_assignment = tuple(da_object.addressable_device_list)
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
handle_outs, kept_var_idx)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
None)
# -- stages.XlaExecutable overrides
def xla_extension_executable(self):
@@ -2853,47 +2797,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)
def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
device_assignment: Sequence[xc.Device],
) -> list[sharding_impls.XLACompatibleSharding]:
# For each jaxpr output, compute a Sharding by:
# * if the output is a forwarded input, get the corresponding in_sharding;
# * if the output is a constant Array, get its .sharding attribute;
# * otherwise, the output is a literal or numpy.ndarray constant, so give it
# a replicated sharding
from jax._src import array
if len(device_assignment) > 1:
rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(
i._original_sharding if hasattr(i, '_original_sharding') else i
for i in in_shardings)
else:
dev, = device_assignment
rep = sharding_impls.SingleDeviceSharding(dev)
in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)
shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
shardings[constvar] = constval.sharding
map(shardings.setdefault, jaxpr.invars, in_shardings)
return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
for x in jaxpr.outvars]
def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
env: dict[core.Var, Any] = {}
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
map(env.setdefault, jaxpr.invars, pruned_args)
map(env.setdefault, jaxpr.constvars, consts)
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
for v in jaxpr.outvars]
return out_handler(in_handler(outs))
@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
@@ -2926,28 +2829,6 @@ def _compile_replicated_mesh_executable_from_hlo(
kept_var_idx, jaxpr_debug_info, None)
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
da_object, committed, kept_var_idx, pmap_nreps):
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object)
input_indices = _get_input_indices(global_in_avals, in_shardings, da_object) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
# Use the standard out_handler.
unsafe_call = backend.compile_replicated(
is_trivial=True, jaxpr=jaxpr, consts=consts,
device_assignment=da_object, in_avals=global_in_avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=kept_var_idx, out_handler=handle_outs,
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
None)
@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,


@@ -661,15 +661,15 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = self.jit(lambda x: x)(x)
self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
z1, z2 = self.jit(lambda x: (x, x))(x)
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)
def test_trivial_computations_with_tokens(self):
@@ -1176,7 +1176,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertLen(compiled._executable.in_avals, 2)
# Also works with jax.jit
jitted_f = self.jit(lambda x, y: x, keep_unused=True)
with jtu.count_device_put() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
_ = jitted_f(1, 2)
self.assertEqual(count[0], 1)
@@ -3273,15 +3273,15 @@ class APITest(jtu.JaxTestCase):
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = api.jit(lambda x: x)(x)
self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
z1, z2 = api.jit(lambda x: (x, x))(x)
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
self.assertEqual(z2, 1)
def test_nested_jit_hoisting(self):
@@ -5455,19 +5455,19 @@ class RematTest(jtu.JaxTestCase):
# https://github.com/google/jax/issues/9661
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_vjp = jax.vjp(identity, 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
def test_fwd_caching(self):
# see above test also


@@ -3119,14 +3119,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
_ptr = lambda x: x.unsafe_buffer_pointer()
self.assertEqual(_ptr(x), _ptr(x_view))
self.assertEqual(_ptr(x), _ptr(x_view_jit))
self.assertNotEqual(_ptr(x), _ptr(x_view_jit))
self.assertNotEqual(_ptr(x), _ptr(x_copy))
self.assertNotEqual(_ptr(x), _ptr(x_copy_jit))
x.delete()
self.assertTrue(x_view.is_deleted())
self.assertTrue(x_view_jit.is_deleted())
self.assertFalse(x_view_jit.is_deleted())
self.assertFalse(x_copy.is_deleted())
self.assertFalse(x_copy_jit.is_deleted())


@@ -2995,8 +2995,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
return pxla.batched_device_put(
aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
def foo_array_constant_handler(x, c):
return array._array_mlir_constant_handler(x.data, c)
def foo_array_constant_handler(x):
return array._array_mlir_constant_handler(x.data)
def make_lowering(*, shape):
return jnp.zeros((*shape, 2), 'uint32')


@@ -116,7 +116,8 @@ class MultiDeviceTest(jtu.JaxTestCase):
z1, z2 = jax.jit(lambda x: (x, x))(x_uncommitted)
self.assert_uncommitted_to_device(z1, devices[0])
self.assert_uncommitted_to_device(z2, devices[0])
self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
# trivial computation does not exist in JAX anymore.
self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
x2_uncommitted = jnp.array([2, 3])
z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)


@@ -1987,14 +1987,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
a = jnp.arange(16).reshape((8, 2))
f = pjit(lambda x: x)
out = f(a)
cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
_ = f(out)
cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(a)
_ = f(out)
self.assertEqual(count[0], 1)
def test_pjit_different_device_recompilation(self):
if jax.device_count() < 2: