mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Account for tokens in allow_spmd_sharding_propagation_to_parameters
and allow_spmd_sharding_propagation_to_output
compile options
PiperOrigin-RevId: 707723232
This commit is contained in:
parent
46b18d272c
commit
9041b02dff
@ -2909,9 +2909,9 @@ class UnloadedMeshExecutable:
|
||||
da = _create_da_object(tuple(device_assignment))
|
||||
del device_assignment
|
||||
|
||||
allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
|
||||
for i in in_shardings)
|
||||
allow_prop_to_outputs = tuple(
|
||||
allow_prop_to_inputs = (False,) * len(ordered_effects) + tuple(
|
||||
isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings)
|
||||
allow_prop_to_outputs = (False,) * len(ordered_effects) + tuple(
|
||||
isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
|
||||
for o in out_shardings)
|
||||
|
||||
|
@ -4305,6 +4305,24 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertLen(traced.in_avals[0], 1)
|
||||
self.assertLen(traced.in_avals[1], 0) # empty kwarg
|
||||
|
||||
def test_empty_io_callback_under_shard_map(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("Shardy errors out on empty callbacks.")
|
||||
mesh = jtu.create_mesh((4,), 'i')
|
||||
|
||||
def empty_callback(x):
|
||||
return
|
||||
|
||||
def _f(x, y):
|
||||
jax.experimental.io_callback(
|
||||
empty_callback, (), x, ordered=True)
|
||||
return x + y[..., jnp.newaxis]
|
||||
|
||||
f = jax.jit(shard_map(
|
||||
_f, mesh, in_specs=(P(None, 'i'), P(None)),
|
||||
out_specs=P(None, 'i')))
|
||||
f(jnp.zeros((2, 16)), jnp.ones(2))
|
||||
|
||||
def test_jit_trace_lower_and_compile(self):
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user