[Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases

PiperOrigin-RevId: 515659122
Author: Yash Katariya
Date: 2023-03-10 09:35:39 -08:00
Committed by: jax authors
Parent: d58be3d4df
Commit: 00b90e9073
3 changed files with 45 additions and 28 deletions
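
For context, the jit(pmap) pattern this rollback concerns is the composition exercised by the test updated below; a minimal sketch (the function body and input are illustrative only):

import jax
import jax.numpy as jnp

# Wrapping a pmapped function in jit produces a jit-level executable whose
# replica count is greater than one; that is the pmap_nreps > 1 case the
# re-added branch below handles.
@jax.jit
@jax.pmap
def f(x):
  return x + 1

f(jnp.arange(jax.device_count()))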

View File

@@ -931,7 +931,8 @@ def _execute_replicated(name: str,
ordered_effects: List[core.Effect],
kept_var_idx,
has_host_callbacks: bool,
*args):
*args,
from_lower_sharding_computation: bool = False):
if has_unordered_effects or ordered_effects:
# TODO(sharadmv): support jit-of-pmap with effects
raise NotImplementedError(
@@ -945,6 +946,8 @@ def _execute_replicated(name: str,
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if from_lower_sharding_computation:
return result_handler(out_bufs)
return result_handler(None, out_bufs)
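
The new keyword-only flag only selects the result-handler calling convention: the sharded-computation path calls the handler with the output buffers alone, while the existing pmap path passes an extra leading argument. On the caller side (wired up later in this commit) the flag is pre-bound with functools.partial. A simplified sketch of that pattern, using hypothetical stand-ins rather than the real dispatch machinery:

from functools import partial

def execute_replicated(name, handler, *args,
                       from_lower_sharding_computation=False):
  out_bufs = list(args)  # stand-in for the flattened output buffers
  if from_lower_sharding_computation:
    return handler(out_bufs)       # sharded-computation handler: buffers only
  return handler(None, out_bufs)   # pmap-style handler: extra leading argument

# jit(pmap) path: everything except the runtime arguments is pre-bound.
unsafe_call = partial(execute_replicated, "f", lambda bufs: bufs,
                      from_lower_sharding_computation=True)
print(unsafe_call(1, 2, 3))  # [1, 2, 3]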

View File

@@ -1751,21 +1751,6 @@ class UnloadedPmapExecutable:
self.local_input_avals)
def _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
ordered_effects=ordered_effects, in_avals=pci.avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
@@ -3524,6 +3509,7 @@ class UnloadedMeshExecutable:
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
pmap_nreps: int
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
@@ -3540,10 +3526,22 @@ class UnloadedMeshExecutable:
self.output_avals, self.output_shardings, self.committed,
self.are_out_shardings_from_xla) # type: ignore # arg-type
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
# This path is taken for `jit(pmap)` cases. Nothing else should flow
# through this path. This is exactly the same as what happens in `jit`.
if self.pmap_nreps > 1:
has_unordered_effects = bool(self.unordered_effects)
buffer_counts = dispatch.get_buffer_counts(
self.output_avals, self.ordered_effects, has_unordered_effects)
unsafe_call = partial(
dispatch._execute_replicated, self.name, self.xla_executable, None,
buffer_counts, handle_outs, has_unordered_effects, self.ordered_effects,
self.kept_var_idx, bool(self.host_callbacks),
from_lower_sharding_computation=True)
else:
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
self.input_shardings, self.output_shardings,
@@ -3676,9 +3674,8 @@ class UnloadedMeshExecutable:
local_devices = xla_executable.local_devices()
# Create replicated in_shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
input_shardings = [
sharding_internal.GSPMDSharding.get_replicated(local_devices)
] * len(input_shardings)
input_shardings = [sharding_internal.GSPMDSharding.get_replicated(
local_devices) for _ in input_shardings]
return UnloadedMeshExecutable(
xla_executable=xla_executable,
@@ -3690,6 +3687,7 @@ class UnloadedMeshExecutable:
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
pmap_nreps=pmap_nreps,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
@@ -3841,6 +3839,21 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
return out_handler(in_handler(outs))
def _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
ordered_effects=ordered_effects, in_avals=pci.avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
def _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, in_is_global, auto_spmd_lowering, compile_options,

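One small change above swaps list multiplication for a list comprehension when building the replicated in_shardings, so the list holds a distinct replicated sharding per input rather than the same object repeated. The generic Python difference, shown with a plain placeholder class instead of the internal GSPMDSharding (whether the distinction matters downstream is not stated in the diff):

class Sharding:
  pass

n = 3
aliased = [Sharding()] * n                 # one object, referenced n times
distinct = [Sharding() for _ in range(n)]  # n separate objects

print(len({id(s) for s in aliased}))   # 1
print(len({id(s) for s in distinct}))  # 3
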
View File

@@ -538,16 +538,17 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_cant_jit_and_pmap_function_with_unordered_effects(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
if not jax.config.jax_array:
self.skipTest("Only works with jax.Array")
@jax.jit
@jax.pmap
def f(x):
effect_p.bind(effect=bar_effect)
return x + 1
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(jnp.arange(jax.device_count())) # doesn't crash
with self.assertRaisesRegex(
NotImplementedError,
"Cannot execute replicated computation with effects."):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(jnp.arange(jax.device_count()))
def test_cant_jit_and_pmap_function_with_ordered_effects(self):
@jax.jit