Fix handling of infeed token inside sharded_jit (#3313)

This commit is contained in:
Skye Wanderman-Milne 2020-06-03 15:23:49 -07:00 committed by GitHub
parent c77c0838fe
commit 5ad9feda5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 7 deletions

View File

@ -5118,7 +5118,11 @@ def infeed(token, shape=None, partitions=None):
"ShapedArray values, got {}".format(shape))
if partitions is not None:
# Always replicate token.
partitions = (partitions, None)
# We specifically use type() to raise an error for PartitionSpecs.
if type(partitions) != tuple: # pylint: disable=unidiomatic-typecheck
raise ValueError(f"'partitions' argument to infeed should be a tuple, "
f"got {partitions}")
partitions = partitions + (None,)
xs_and_token = infeed_p.bind(token, shapes=tuple(flat_shapes),
partitions=partitions)
return (treedef.unflatten(xs_and_token[:-1]), xs_and_token[-1])

View File

@ -40,6 +40,7 @@ config.parse_flags_with_absl()
class ShardedJitTest(jtu.JaxTestCase):
def setUp(self):
super(ShardedJitTest, self).setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest
@ -192,25 +193,29 @@ class ShardedJitTest(jtu.JaxTestCase):
shape = (jax.local_device_count() * 2, 4)
# Run computation across all devices so we know which devices to feed.
parts = P(jax.local_device_count(), 1)
infeed_shape = jax.ShapedArray(shape, np.float32)
in_parts = parts if partition_input else None
infeed_shapes = (jax.ShapedArray(shape, np.float32),
jax.ShapedArray((1,), np.float32))
infeed_parts = (parts, None)
@partial(sharded_jit, in_parts=in_parts, out_parts=None)
def f(x):
token = lax.create_token(x)
(y,), token = lax.infeed(token, (infeed_shape,), partitions=parts)
return x @ y.T
(y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
return x @ y.T + z
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
y = x + 1
shard_size = shape[0] // jax.local_device_count()
y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
z = jnp.array([3.], dtype=np.float32)
result = f(x)
for device, shard in zip(jax.local_devices(), y_shards):
device.transfer_to_infeed((shard,))
assert len(jax.local_devices()) == len(y_shards)
for device, y_shard in zip(jax.local_devices(), y_shards):
device.transfer_to_infeed((y_shard, z))
expected = x @ y.T
expected = x @ y.T + z
self.assertAllClose(result, expected, check_dtypes=False)
@ -218,6 +223,7 @@ class ShardedJitTest(jtu.JaxTestCase):
class ShardedJitErrorsTest(jtu.JaxTestCase):
def setUp(self):
super(ShardedJitErrorsTest, self).setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest
@ -267,6 +273,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
class PmapOfShardedJitTest(jtu.JaxTestCase):
def setUp(self):
super(PmapOfShardedJitTest, self).setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest