mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rollforward with fixes: Remove _execute_replicated from UnloadedMeshExecutable.load since it is not required anymore for jit(pmap) cases
PiperOrigin-RevId: 516317920
This commit is contained in:
parent
42ef649e65
commit
5aa74acbcd
@ -932,8 +932,7 @@ def _execute_replicated(name: str,
|
||||
ordered_effects: List[core.Effect],
|
||||
kept_var_idx,
|
||||
has_host_callbacks: bool,
|
||||
*args,
|
||||
from_lower_sharding_computation: bool = False):
|
||||
*args):
|
||||
if has_unordered_effects or ordered_effects:
|
||||
# TODO(sharadmv): support jit-of-pmap with effects
|
||||
raise NotImplementedError(
|
||||
@ -947,8 +946,6 @@ def _execute_replicated(name: str,
|
||||
out_flat = [bufs[0] for bufs in out_bufs_flat_rep] # type: ignore
|
||||
check_special(name, out_flat)
|
||||
out_bufs = unflatten(out_flat, output_buffer_counts)
|
||||
if from_lower_sharding_computation:
|
||||
return result_handler(out_bufs)
|
||||
return result_handler(None, out_bufs)
|
||||
|
||||
|
||||
|
@ -1766,6 +1766,21 @@ class UnloadedPmapExecutable:
|
||||
self.local_input_avals)
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
xla_computation, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=xla_computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, execute_fun, None, pci.avals)
|
||||
|
||||
|
||||
class PmapExecutable(stages.XlaExecutable):
|
||||
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
|
||||
|
||||
@ -3264,7 +3279,8 @@ def get_gspmd_shardings_from_executable(
|
||||
# TODO(b/245667823): Remove this when XLA fixes this.
|
||||
if len(out_shardings_xla) == 1 and len(out_shardings_xla) < num_out_avals:
|
||||
out_shardings_xla = out_shardings_xla * num_out_avals
|
||||
assert len(out_shardings_xla) == num_out_avals
|
||||
assert len(out_shardings_xla) == num_out_avals, (
|
||||
len(out_shardings_xla), num_out_avals)
|
||||
return in_shardings_xla, out_shardings_xla
|
||||
|
||||
|
||||
@ -3292,7 +3308,6 @@ class UnloadedMeshExecutable:
|
||||
output_shardings: Sequence[sharding_impls.XLACompatibleSharding]
|
||||
committed: bool
|
||||
are_out_shardings_from_xla: Sequence[bool]
|
||||
pmap_nreps: int
|
||||
name: str
|
||||
unordered_effects: List[core.Effect]
|
||||
ordered_effects: List[core.Effect]
|
||||
@ -3309,22 +3324,10 @@ class UnloadedMeshExecutable:
|
||||
self.output_avals, self.output_shardings, self.committed,
|
||||
self.are_out_shardings_from_xla) # type: ignore # arg-type
|
||||
|
||||
# This path is taken for `jit(pmap)` cases. Nothing else should flow
|
||||
# through this path. This is exactly same to what happens in `jit`.
|
||||
if self.pmap_nreps > 1:
|
||||
has_unordered_effects = bool(self.unordered_effects)
|
||||
buffer_counts = dispatch.get_buffer_counts(
|
||||
self.output_avals, self.ordered_effects, has_unordered_effects)
|
||||
unsafe_call = partial(
|
||||
dispatch._execute_replicated, self.name, self.xla_executable, None,
|
||||
buffer_counts, handle_outs, has_unordered_effects, self.ordered_effects,
|
||||
self.kept_var_idx, bool(self.host_callbacks),
|
||||
from_lower_sharding_computation=True)
|
||||
else:
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
|
||||
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
@ -3425,7 +3428,8 @@ class UnloadedMeshExecutable:
|
||||
for x, o in safe_zip(out_shardings_xla, out_shardings)
|
||||
]
|
||||
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
|
||||
elif out_shardings and any(_is_unspecified(o) for o in out_shardings):
|
||||
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
|
||||
and pmap_nreps == 1):
|
||||
assert mesh is None
|
||||
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
||||
xla_executable, device_assignment,
|
||||
@ -3449,17 +3453,25 @@ class UnloadedMeshExecutable:
|
||||
else:
|
||||
are_out_shardings_from_xla = (False,) * len(global_out_avals)
|
||||
|
||||
if pmap_nreps > 1:
|
||||
local_devices = xla_executable.local_devices()
|
||||
# Create replicated shardings for jit(pmap) path with local devices
|
||||
# because multihost jit(pmap) is not allowed.
|
||||
in_shardings = [
|
||||
sharding_impls.GSPMDSharding.get_replicated(local_devices)
|
||||
] * len(in_shardings)
|
||||
out_shardings = [
|
||||
sharding_impls.GSPMDSharding.get_replicated(local_devices)
|
||||
] * len(out_shardings)
|
||||
# jit(pmap) will generate Arrays with multi-device sharding.
|
||||
# It is unsupported for these shardings to be uncommited, so force
|
||||
# the outputs to be committed.
|
||||
committed = True
|
||||
|
||||
input_avals, input_shardings = (
|
||||
_get_normalized_avals_and_shardings(
|
||||
global_in_avals, in_shardings, in_is_global)) # type: ignore # arg-type
|
||||
|
||||
if pmap_nreps > 1:
|
||||
local_devices = xla_executable.local_devices()
|
||||
# Create replicated in_shardings for jit(pmap) path with local devices
|
||||
# because multihost jit(pmap) is not allowed.
|
||||
input_shardings = [sharding_impls.GSPMDSharding.get_replicated(
|
||||
local_devices) for _ in input_shardings]
|
||||
|
||||
return UnloadedMeshExecutable(
|
||||
xla_executable=xla_executable,
|
||||
device_assignment=device_assignment,
|
||||
@ -3470,7 +3482,6 @@ class UnloadedMeshExecutable:
|
||||
output_shardings=out_shardings, # type: ignore # arg-type
|
||||
committed=committed,
|
||||
are_out_shardings_from_xla=are_out_shardings_from_xla,
|
||||
pmap_nreps=pmap_nreps,
|
||||
name=name,
|
||||
unordered_effects=unordered_effects,
|
||||
ordered_effects=ordered_effects,
|
||||
@ -3622,21 +3633,6 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
|
||||
return out_handler(in_handler(outs))
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
xla_computation, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=xla_computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, execute_fun, None, pci.avals)
|
||||
|
||||
|
||||
def _compile_replicated_mesh_executable_from_hlo(
|
||||
name, computation, global_in_avals, global_out_avals, in_shardings,
|
||||
out_shardings, in_is_global, auto_spmd_lowering, compile_options,
|
||||
|
@ -537,17 +537,16 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def test_cant_jit_and_pmap_function_with_unordered_effects(self):
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
if not jax.config.jax_array:
|
||||
self.skipTest("Only works with jax.Array")
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def f(x):
|
||||
effect_p.bind(effect=bar_effect)
|
||||
return x + 1
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Cannot execute replicated computation with effects."):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
f(jnp.arange(jax.device_count()))
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
f(jnp.arange(jax.device_count())) # doesn't crash
|
||||
|
||||
def test_cant_jit_and_pmap_function_with_ordered_effects(self):
|
||||
@jax.jit
|
||||
|
@ -1834,6 +1834,21 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertIn("The jitted function foo includes a pmap",
|
||||
str(w[-1].message))
|
||||
|
||||
def testJitOfPmapOutputSharding(self):
|
||||
device_count = jax.device_count()
|
||||
|
||||
if device_count == 1 or config.jax_disable_jit:
|
||||
raise SkipTest("test requires at least two devices")
|
||||
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def foo(x): return x + x
|
||||
|
||||
x = np.ones((2,2,2), dtype=np.float32)
|
||||
for _ in range(10):
|
||||
# Does not crash.
|
||||
x = foo(x)
|
||||
|
||||
def testPsumZeroCotangents(self):
|
||||
# https://github.com/google/jax/issues/3651
|
||||
def loss(params, meta_params):
|
||||
|
Loading…
x
Reference in New Issue
Block a user