Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases

PiperOrigin-RevId: 516317920
This commit is contained in:
Parker Schuh 2023-03-13 14:08:48 -07:00 committed by jax authors
parent 42ef649e65
commit 5aa74acbcd
4 changed files with 59 additions and 52 deletions

View File

@ -932,8 +932,7 @@ def _execute_replicated(name: str,
ordered_effects: List[core.Effect],
kept_var_idx,
has_host_callbacks: bool,
*args,
from_lower_sharding_computation: bool = False):
*args):
if has_unordered_effects or ordered_effects:
# TODO(sharadmv): support jit-of-pmap with effects
raise NotImplementedError(
@ -947,8 +946,6 @@ def _execute_replicated(name: str,
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
check_special(name, out_flat)
out_bufs = unflatten(out_flat, output_buffer_counts)
if from_lower_sharding_computation:
return result_handler(out_bufs)
return result_handler(None, out_bufs)

View File

@ -1766,6 +1766,21 @@ class UnloadedPmapExecutable:
self.local_input_avals)
def _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
ordered_effects=ordered_effects, in_avals=pci.avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
@ -3264,7 +3279,8 @@ def get_gspmd_shardings_from_executable(
# TODO(b/245667823): Remove this when XLA fixes this.
if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
out_shardings_xla = out_shardings_xla * num_out_avals
assert len(out_shardings_xla) == num_out_avals
assert len(out_shardings_xla) == num_out_avals, (
len(out_shardings_xla), num_out_avals)
return in_shardings_xla, out_shardings_xla
@ -3292,7 +3308,6 @@ class UnloadedMeshExecutable:
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
pmap_nreps: int
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
@ -3309,22 +3324,10 @@ class UnloadedMeshExecutable:
self.output_avals, self.output_shardings, self.committed,
self.are_out_shardings_from_xla) # type: ignore # arg-type
# This path is taken for `jit(pmap)` cases. Nothing else should flow
# through this path. This is exactly same to what happens in `jit`.
if self.pmap_nreps > 1:
has_unordered_effects = bool(self.unordered_effects)
buffer_counts = dispatch.get_buffer_counts(
self.output_avals, self.ordered_effects, has_unordered_effects)
unsafe_call = partial(
dispatch._execute_replicated, self.name, self.xla_executable, None,
buffer_counts, handle_outs, has_unordered_effects, self.ordered_effects,
self.kept_var_idx, bool(self.host_callbacks),
from_lower_sharding_computation=True)
else:
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
unsafe_call = ExecuteReplicated( # type: ignore # assignment
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
self.input_shardings, self.output_shardings,
@ -3425,7 +3428,8 @@ class UnloadedMeshExecutable:
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment,
@ -3449,17 +3453,25 @@ class UnloadedMeshExecutable:
else:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if pmap_nreps > 1:
local_devices = xla_executable.local_devices()
# Create replicated shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
in_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(in_shardings)
out_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(out_shardings)
# jit(pmap) will generate Arrays with multi-device sharding.
# It is unsupported for these shardings to be uncommited, so force
# the outputs to be committed.
committed = True
input_avals, input_shardings = (
_get_normalized_avals_and_shardings(
global_in_avals, in_shardings, in_is_global)) # type: ignore # arg-type
if pmap_nreps > 1:
local_devices = xla_executable.local_devices()
# Create replicated in_shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
input_shardings = [sharding_impls.GSPMDSharding.get_replicated(
local_devices) for _ in input_shardings]
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=device_assignment,
@ -3470,7 +3482,6 @@ class UnloadedMeshExecutable:
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
pmap_nreps=pmap_nreps,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
@ -3622,21 +3633,6 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
return out_handler(in_handler(outs))
def _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
compile_options=compile_options, host_callbacks=host_callbacks,
has_unordered_effects=has_unordered_effects,
ordered_effects=ordered_effects, in_avals=pci.avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
def _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, in_is_global, auto_spmd_lowering, compile_options,

View File

@ -537,17 +537,16 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def test_cant_jit_and_pmap_function_with_unordered_effects(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
if not jax.config.jax_array:
self.skipTest("Only works with jax.Array")
@jax.jit
@jax.pmap
def f(x):
effect_p.bind(effect=bar_effect)
return x + 1
with self.assertRaisesRegex(
NotImplementedError,
"Cannot execute replicated computation with effects."):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(jnp.arange(jax.device_count()))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(jnp.arange(jax.device_count())) # doesn't crash
def test_cant_jit_and_pmap_function_with_ordered_effects(self):
@jax.jit

View File

@ -1834,6 +1834,21 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIn("The jitted function foo includes a pmap",
str(w[-1].message))
def testJitOfPmapOutputSharding(self):
device_count = jax.device_count()
if device_count == 1 or config.jax_disable_jit:
raise SkipTest("test requires at least two devices")
@jax.jit
@jax.pmap
def foo(x): return x + x
x = np.ones((2,2,2), dtype=np.float32)
for _ in range(10):
# Does not crash.
x = foo(x)
def testPsumZeroCotangents(self):
# https://github.com/google/jax/issues/3651
def loss(params, meta_params):