mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Rollback] Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 515659122
This commit is contained in:
parent
d58be3d4df
commit
00b90e9073
@ -931,7 +931,8 @@ def _execute_replicated(name: str,
|
||||
ordered_effects: List[core.Effect],
|
||||
kept_var_idx,
|
||||
has_host_callbacks: bool,
|
||||
*args):
|
||||
*args,
|
||||
from_lower_sharding_computation: bool = False):
|
||||
if has_unordered_effects or ordered_effects:
|
||||
# TODO(sharadmv): support jit-of-pmap with effects
|
||||
raise NotImplementedError(
|
||||
@ -945,6 +946,8 @@ def _execute_replicated(name: str,
|
||||
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
|
||||
check_special(name, out_flat)
|
||||
out_bufs = unflatten(out_flat, output_buffer_counts)
|
||||
if from_lower_sharding_computation:
|
||||
return result_handler(out_bufs)
|
||||
return result_handler(None, out_bufs)
|
||||
|
||||
|
||||
|
@ -1751,21 +1751,6 @@ class UnloadedPmapExecutable:
|
||||
self.local_input_avals)
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
xla_computation, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=xla_computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, execute_fun, None, pci.avals)
|
||||
|
||||
|
||||
class PmapExecutable(stages.XlaExecutable):
|
||||
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
|
||||
|
||||
@ -3524,6 +3509,7 @@ class UnloadedMeshExecutable:
|
||||
output_shardings: Sequence[sharding_internal.XLACompatibleSharding]
|
||||
committed: bool
|
||||
are_out_shardings_from_xla: Sequence[bool]
|
||||
pmap_nreps: int
|
||||
name: str
|
||||
unordered_effects: List[core.Effect]
|
||||
ordered_effects: List[core.Effect]
|
||||
@ -3540,10 +3526,22 @@ class UnloadedMeshExecutable:
|
||||
self.output_avals, self.output_shardings, self.committed,
|
||||
self.are_out_shardings_from_xla) # type: ignore # arg-type
|
||||
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
# This path is taken for `jit(pmap)` cases. Nothing else should flow
|
||||
# through this path. This is exactly same to what happens in `jit`.
|
||||
if self.pmap_nreps > 1:
|
||||
has_unordered_effects = bool(self.unordered_effects)
|
||||
buffer_counts = dispatch.get_buffer_counts(
|
||||
self.output_avals, self.ordered_effects, has_unordered_effects)
|
||||
unsafe_call = partial(
|
||||
dispatch._execute_replicated, self.name, self.xla_executable, None,
|
||||
buffer_counts, handle_outs, has_unordered_effects, self.ordered_effects,
|
||||
self.kept_var_idx, bool(self.host_callbacks),
|
||||
from_lower_sharding_computation=True)
|
||||
else:
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
|
||||
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
@ -3676,9 +3674,8 @@ class UnloadedMeshExecutable:
|
||||
local_devices = xla_executable.local_devices()
|
||||
# Create replicated in_shardings for jit(pmap) path with local devices
|
||||
# because multihost jit(pmap) is not allowed.
|
||||
input_shardings = [
|
||||
sharding_internal.GSPMDSharding.get_replicated(local_devices)
|
||||
] * len(input_shardings)
|
||||
input_shardings = [sharding_internal.GSPMDSharding.get_replicated(
|
||||
local_devices) for _ in input_shardings]
|
||||
|
||||
return UnloadedMeshExecutable(
|
||||
xla_executable=xla_executable,
|
||||
@ -3690,6 +3687,7 @@ class UnloadedMeshExecutable:
|
||||
output_shardings=out_shardings, # type: ignore # arg-type
|
||||
committed=committed,
|
||||
are_out_shardings_from_xla=are_out_shardings_from_xla,
|
||||
pmap_nreps=pmap_nreps,
|
||||
name=name,
|
||||
unordered_effects=unordered_effects,
|
||||
ordered_effects=ordered_effects,
|
||||
@ -3841,6 +3839,21 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
|
||||
return out_handler(in_handler(outs))
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
xla_computation, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=xla_computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, execute_fun, None, pci.avals)
|
||||
|
||||
|
||||
def _compile_replicated_mesh_executable_from_hlo(
|
||||
name, computation, global_in_avals, global_out_avals, in_shardings,
|
||||
out_shardings, in_is_global, auto_spmd_lowering, compile_options,
|
||||
|
@ -538,16 +538,17 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_cant_jit_and_pmap_function_with_unordered_effects(self):
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
if not jax.config.jax_array:
|
||||
self.skipTest("Only works with jax.Array")
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
f(jnp.arange(jax.device_count())) # doesn't crash
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Cannot execute replicated computation with effects."):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
f(jnp.arange(jax.device_count()))
|
||||
|
||||
def test_cant_jit_and_pmap_function_with_ordered_effects(self):
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user