mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix handling of infeed token inside sharded_jit (#3313)
This commit is contained in:
parent
c77c0838fe
commit
5ad9feda5f
@ -5118,7 +5118,11 @@ def infeed(token, shape=None, partitions=None):
|
||||
"ShapedArray values, got {}".format(shape))
|
||||
if partitions is not None:
|
||||
# Always replicate token.
|
||||
partitions = (partitions, None)
|
||||
# We specifically use type() to raise an error for PartitionSpecs.
|
||||
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
|
||||
raise ValueError(f"'partitions' argument to infeed should be a tuple, "
|
||||
f"got {partitions}")
|
||||
partitions = partitions + (None,)
|
||||
xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
|
||||
partitions=partitions)
|
||||
return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])
|
||||
|
@ -40,6 +40,7 @@ config.parse_flags_with_absl()
|
||||
class ShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ShardedJitTest, self).setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
|
||||
@ -192,25 +193,29 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
shape = (jax.local_device_count() * 2, 4)
|
||||
# Run computation across all devices so we know which devices to feed.
|
||||
parts = P(jax.local_device_count(), 1)
|
||||
infeed_shape = jax.ShapedArray(shape, np.float32)
|
||||
in_parts = parts if partition_input else None
|
||||
infeed_shapes = (jax.ShapedArray(shape, np.float32),
|
||||
jax.ShapedArray((1,), np.float32))
|
||||
infeed_parts = (parts, None)
|
||||
|
||||
@partial(sharded_jit, in_parts=in_parts, out_parts=None)
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
(y,), token = lax.infeed(token, (infeed_shape,), partitions=parts)
|
||||
return x @ y.T
|
||||
(y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
|
||||
return x @ y.T + z
|
||||
|
||||
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
y = x + 1
|
||||
shard_size = shape[0] // jax.local_device_count()
|
||||
y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
|
||||
z = jnp.array([3.], dtype=np.float32)
|
||||
|
||||
result = f(x)
|
||||
for device, shard in zip(jax.local_devices(), y_shards):
|
||||
device.transfer_to_infeed((shard,))
|
||||
assert len(jax.local_devices()) == len(y_shards)
|
||||
for device, y_shard in zip(jax.local_devices(), y_shards):
|
||||
device.transfer_to_infeed((y_shard, z))
|
||||
|
||||
expected = x @ y.T
|
||||
expected = x @ y.T + z
|
||||
self.assertAllClose(result, expected, check_dtypes=False)
|
||||
|
||||
|
||||
@ -218,6 +223,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
class ShardedJitErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ShardedJitErrorsTest, self).setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
|
||||
@ -267,6 +273,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
|
||||
class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(PmapOfShardedJitTest, self).setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user