Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces far fewer trivial computations, so there is no need for this path to exist. In the future, if we want to support forwarding of inputs to outputs, it will need to be done in a different way that the C++ dispatch path knows about.

```
jit_trivial_dispatch    246µs ± 3%    4µs ± 1%   -98.52%  (p=0.008 n=5+5)
jit_trivial             250µs ± 3%    5µs ± 1%   -98.19%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
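For context on what these benchmark names measure, here is a rough, hypothetical micro-benchmark sketch (not from the commit or the JAX benchmark suite); the iteration count and printed numbers are illustrative only:

```python
# Hypothetical sketch: compare steady-state dispatch time of a "trivial"
# (identity) jitted function against a non-trivial one. The timings printed
# here will not match the table above; this only illustrates what is measured.
import timeit

import jax
import jax.numpy as jnp

x = jnp.arange(1024.0)

trivial = jax.jit(lambda x: x)         # jaxpr has no equations
nontrivial = jax.jit(lambda x: x + 1)  # one add equation, compiled by XLA

# Warm up so compilation cost is excluded from the timed loop.
trivial(x).block_until_ready()
nontrivial(x).block_until_ready()

for name, f in [("trivial", trivial), ("nontrivial", nontrivial)]:
  t = timeit.timeit(lambda: f(x).block_until_ready(), number=1000) / 1000
  print(f"{name}: {t * 1e6:.1f} µs per dispatch")
```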
Parent: c71eedf529
Commit: 970f4c9d4d
@@ -2016,22 +2016,6 @@ def lower_sharding_computation(
         "To fix this error, run your `jitted` computation inside "
         "`with jax.spmd_mode('allow_all'):` context manager.")
 
-  has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
-  kept_outputs = [True] * len(global_out_avals)
-
-  # Computations that only produce constants and/or only rearrange their inputs,
-  # which are often produced from partial evaluation, don't need compilation,
-  # and don't need to evaluate their arguments.
-  if (not always_lower and not (jaxpr.effects or has_outfeed) and
-      (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
-      all(is_unspecified(o) for o in out_shardings)):
-    return MeshComputation(
-        str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
-        consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
-        global_out_avals=global_out_avals, in_shardings=in_shardings,
-        backend=backend, da_object=da_object,
-        committed=committed, kept_var_idx=kept_var_idx, keepalive=None)
-
   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
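For reference (an illustrative sketch, not part of the diff): the branch deleted above matched computations whose jaxpr has no equations, i.e. functions that only forward or rearrange their inputs and/or return constants. Such jaxprs are easy to inspect:

```python
# Sketch: jaxprs the removed fast path would have considered "trivial".
import jax

identity = jax.make_jaxpr(lambda x: x)(1.0)
print(identity)             # roughly: { lambda ; a:f32[]. let  in (a,) }
print(identity.jaxpr.eqns)  # []  -- no equations, the output is a forwarded input

swap = jax.make_jaxpr(lambda x, y: (y, x))(1.0, 2.0)
print(swap.jaxpr.eqns)      # []  -- rearranging inputs also adds no equations

add = jax.make_jaxpr(lambda x: x + 1)(1.0)
print(len(add.jaxpr.eqns))  # 1   -- a real computation, never matched this path
```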
@@ -2049,7 +2033,6 @@ def lower_sharding_computation(
   return MeshComputation(
       str(name_stack),
       module,
-      False,
       donated_invars,
       global_in_avals=global_in_avals,
       global_out_avals=global_out_avals,
@@ -2223,7 +2206,6 @@ def lower_mesh_computation(
   return MeshComputation(
       str(name_stack),
       lowering_result.module,
-      False,
       donated_invars,
       global_in_avals=global_in_avals,
       global_out_avals=global_out_avals,
@@ -2248,10 +2230,9 @@ class MeshComputation(stages.XlaLowering):
   _executable: MeshExecutable | None
 
   def __init__(self, name: str, hlo: ir.Module | None,
-               is_trivial: bool, donated_invars: Sequence[bool], **compile_args):
+               donated_invars: Sequence[bool], **compile_args):
     self._name = name
     self._hlo = hlo
-    self.is_trivial = is_trivial
     self._donated_invars = donated_invars
     self.compile_args = compile_args
     self._executable = None
@@ -2259,24 +2240,13 @@ class MeshComputation(stages.XlaLowering):
   # -- stages.XlaLowering overrides
 
   def stablehlo(self) -> ir.Module:
-    if self.is_trivial:
-      raise ValueError("A trivial computation has no HLO")
     return self._hlo
 
-  def compile(
-      self,
-      compiler_options=None,
-  ) -> MeshExecutable:
+  def compile(self, compiler_options=None) -> MeshExecutable:
     if self._executable is None or compiler_options is not None:
-      if self.is_trivial:
-        executable = MeshExecutable.from_trivial_jaxpr(
-            **self.compile_args)
-      else:
-        executable = UnloadedMeshExecutable.from_hlo(
-            self._name,
-            self._hlo,
-            **self.compile_args,
-            compiler_options=compiler_options)
+      executable = UnloadedMeshExecutable.from_hlo(
+          self._name, self._hlo, **self.compile_args,
+          compiler_options=compiler_options)
       if compiler_options is None:
         self._executable = executable
       return executable
@@ -2735,32 +2705,6 @@ class MeshExecutable(stages.XlaExecutable):
       self._unsafe_call = self.build_unsafe_call()
     return self._unsafe_call
 
-  @staticmethod
-  def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
-                         in_shardings, backend, da_object,
-                         committed, kept_var_idx, keepalive) -> MeshExecutable:
-    assert keepalive is None
-    if hasattr(backend, "compile_replicated"):
-      return _compile_replicated_mesh_executable_from_trivial_jaxpr(
-          jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
-          backend, da_object, committed, kept_var_idx, 1)
-
-    out_shardings = _out_shardings_for_trivial(
-        jaxpr, consts, in_shardings, da_object)
-    indices = _get_input_indices(global_out_avals, out_shardings, da_object)
-    # TODO(yashkatariya): Make local_device_assignment directly usable in the
-    # downstream code without tuple conversion.
-    local_device_assignment = tuple(da_object.addressable_device_list)
-    handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
-    handle_outs = global_avals_to_results_handler(
-        global_out_avals, out_shardings, committed,
-        [False] * len(global_out_avals))
-    unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
-                          handle_outs, kept_var_idx)
-    return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
-                          in_shardings, out_shardings, False, kept_var_idx,
-                          None)
-
   # -- stages.XlaExecutable overrides
 
   def xla_extension_executable(self):
@@ -2853,47 +2797,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
   return in_shardings, out_shardings, committed, tuple(local_devices)
 
 
-def _out_shardings_for_trivial(
-    jaxpr: core.Jaxpr, consts: Sequence[Any],
-    in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    device_assignment: Sequence[xc.Device],
-) -> list[sharding_impls.XLACompatibleSharding]:
-  # For each jaxpr output, compute a Sharding by:
-  #   * if the output is a forwarded input, get the corresponding in_sharding;
-  #   * if the output is a constant Array, get its .sharding attribute;
-  #   * otherwise, the output is a literal or numpy.ndarray constant, so give it
-  #     a replicated sharding
-  from jax._src import array
-
-  if len(device_assignment) > 1:
-    rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
-    in_shardings = tuple(
-        i._original_sharding if hasattr(i, '_original_sharding') else i
-        for i in in_shardings)
-  else:
-    dev, = device_assignment
-    rep = sharding_impls.SingleDeviceSharding(dev)
-    in_shardings = (sharding_impls.SingleDeviceSharding(dev),) * len(in_shardings)
-
-  shardings: dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
-  for constvar, constval in zip(jaxpr.constvars, consts):
-    if isinstance(constval, array.ArrayImpl):
-      shardings[constvar] = constval.sharding
-  map(shardings.setdefault, jaxpr.invars, in_shardings)
-  return [rep if isinstance(x, core.Literal) else shardings.get(x, rep)
-          for x in jaxpr.outvars]
-
-
-def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args):
-  env: dict[core.Var, Any] = {}
-  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
-  map(env.setdefault, jaxpr.invars, pruned_args)
-  map(env.setdefault, jaxpr.constvars, consts)
-  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
-          for v in jaxpr.outvars]
-  return out_handler(in_handler(outs))
-
-
 @weakref_lru_cache
 def _compile_replicated_mesh_executable_from_hlo(
     computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
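The helpers removed above evaluated trivial jaxprs entirely in Python and derived output shardings from forwarded inputs or constants. After this change every jitted call is lowered and compiled, so output shardings come from the compiled executable. A minimal sanity check of the new behaviour (an illustrative sketch under single-device assumptions, not part of the commit):

```python
# Sketch: even an identity function now goes through lowering and compilation,
# and the output reports its sharding like any other compiled result.
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jax.jit(lambda x: x)(x)
print(y.sharding)  # e.g. SingleDeviceSharding(device=...)

compiled = jax.jit(lambda x: x).lower(x).compile()
print(type(compiled))  # a stages Compiled object -- no trivial-executable special case
```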
@@ -2926,28 +2829,6 @@ def _compile_replicated_mesh_executable_from_hlo(
       kept_var_idx, jaxpr_debug_info, None)
 
 
-def _compile_replicated_mesh_executable_from_trivial_jaxpr(
-    jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
-    da_object, committed, kept_var_idx, pmap_nreps):
-  out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, da_object)
-
-  input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
-  handle_outs = global_avals_to_results_handler(
-      global_out_avals, out_shardings, committed,
-      [False] * len(global_out_avals))
-  # Use the standard out_handler.
-  unsafe_call = backend.compile_replicated(
-      is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=da_object, in_avals=global_in_avals,
-      in_indices=input_indices, in_shardings=in_shardings,
-      kept_var_idx=kept_var_idx, out_handler=handle_outs,
-      out_shardings=out_shardings, pmap_nreps=pmap_nreps)
-  return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
-                        in_shardings, out_shardings, False, kept_var_idx,
-                        None)
-
-
 @lru_cache
 def create_mesh_pspec_sharding(
     mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None,
@@ -661,15 +661,15 @@ class CPPJitTest(jtu.BufferDonationTestCase):
   def test_trivial_computations(self):
     x = jnp.array([1, 2, 3])
     y = self.jit(lambda x: x)(x)
-    self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
+    self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
 
     z1, z2 = self.jit(lambda x: (x, x))(x)
-    self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
+    self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
 
     x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
     z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
-    self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
-    self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
+    self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
+    self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
     self.assertEqual(z2, 1)
 
   def test_trivial_computations_with_tokens(self):
@@ -1176,7 +1176,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     self.assertLen(compiled._executable.in_avals, 2)
     # Also works with jax.jit
     jitted_f = self.jit(lambda x, y: x, keep_unused=True)
-    with jtu.count_device_put() as count:
+    with jtu.count_pjit_cpp_cache_miss() as count:
       _ = jitted_f(1, 2)
     self.assertEqual(count[0], 1)
 
@@ -3273,15 +3273,15 @@ class APITest(jtu.JaxTestCase):
   def test_trivial_computations(self):
     x = jnp.array([1, 2, 3])
     y = api.jit(lambda x: x)(x)
-    self.assertEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
+    self.assertNotEqual(x.unsafe_buffer_pointer(), y.unsafe_buffer_pointer())
 
     z1, z2 = api.jit(lambda x: (x, x))(x)
-    self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
+    self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
 
     x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
     z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
-    self.assertEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
-    self.assertEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
+    self.assertNotEqual(z1.unsafe_buffer_pointer(), x2.unsafe_buffer_pointer())
+    self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
     self.assertEqual(z2, 1)
 
   def test_nested_jit_hoisting(self):
@@ -5455,19 +5455,19 @@ class RematTest(jtu.JaxTestCase):
     # https://github.com/google/jax/issues/9661
     identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
     _, f_vjp = jax.vjp(identity, 1.)
-    with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
+    with jtu.count_pjit_cpp_cache_miss() as count:  # noqa: F841
       for _ in range(20):
        f_vjp(1.)[0].block_until_ready()
-    self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd
+    self.assertEqual(count[0], 2)  # fwd execute_trivial, backward_pass on bwd
 
   def test_vjp_caching_static_argnums(self):
     identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
                          static_argnums=(1,))
     _, f_vjp = jax.vjp(identity, 1., True)
-    with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
+    with jtu.count_pjit_cpp_cache_miss() as count:  # noqa: F841
       for _ in range(20):
        f_vjp(1.)[0].block_until_ready()
-    self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd
+    self.assertEqual(count[0], 2)  # fwd execute_trivial, backward_pass on bwd
 
   def test_fwd_caching(self):
     # see above test also
@@ -3119,14 +3119,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     _ptr = lambda x: x.unsafe_buffer_pointer()
 
     self.assertEqual(_ptr(x), _ptr(x_view))
-    self.assertEqual(_ptr(x), _ptr(x_view_jit))
+    self.assertNotEqual(_ptr(x), _ptr(x_view_jit))
     self.assertNotEqual(_ptr(x), _ptr(x_copy))
     self.assertNotEqual(_ptr(x), _ptr(x_copy_jit))
 
     x.delete()
 
     self.assertTrue(x_view.is_deleted())
-    self.assertTrue(x_view_jit.is_deleted())
+    self.assertFalse(x_view_jit.is_deleted())
 
     self.assertFalse(x_copy.is_deleted())
     self.assertFalse(x_copy_jit.is_deleted())
@@ -2995,8 +2995,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
   return pxla.batched_device_put(
       aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
 
-def foo_array_constant_handler(x, c):
-  return array._array_mlir_constant_handler(x.data, c)
+def foo_array_constant_handler(x):
+  return array._array_mlir_constant_handler(x.data)
 
 def make_lowering(*, shape):
   return jnp.zeros((*shape, 2), 'uint32')
@@ -116,7 +116,8 @@ class MultiDeviceTest(jtu.JaxTestCase):
     z1, z2 = jax.jit(lambda x: (x, x))(x_uncommitted)
     self.assert_uncommitted_to_device(z1, devices[0])
     self.assert_uncommitted_to_device(z2, devices[0])
-    self.assertEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
+    # trivial computation does not exist in JAX anymore.
+    self.assertNotEqual(z1.unsafe_buffer_pointer(), z2.unsafe_buffer_pointer())
 
     x2_uncommitted = jnp.array([2, 3])
     z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)
@@ -1987,14 +1987,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
     a = jnp.arange(16).reshape((8, 2))
     f = pjit(lambda x: x)
 
-    out = f(a)
-    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
-
-    _ = f(out)
-    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
-
-    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
-    self.assertEqual(cache_info2.misses, cache_info1.misses)
+    with jtu.count_pjit_cpp_cache_miss() as count:
+      out = f(a)
+      _ = f(out)
+    self.assertEqual(count[0], 1)
 
   def test_pjit_different_device_recompilation(self):
     if jax.device_count() < 2: