Merge pull request #14030 from gnecula:poly_vmap_error

PiperOrigin-RevId: 502546564
This commit is contained in:
jax authors 2023-01-17 04:19:16 -08:00
commit 7ce9fa2f87
2 changed files with 22 additions and 2 deletions

View File

@ -1685,8 +1685,13 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
for x, d in zip(vals, dims)]
size_counts = collections.Counter(s for s in all_sizes if s is not None)
(sz, ct), *other_counts = counts = size_counts.most_common()
ex, *examples = [key_paths[all_sizes.index(sz)] for sz, _ in counts]
ax, *axs = [dims[all_sizes.index(sz)] for sz, _ in counts]
def _all_sizes_index(sz):
for i, isz in enumerate(all_sizes):
if core.symbolic_equal_dim(isz, sz): return i
assert False, (sz, all_sizes)
ex, *examples = [key_paths[_all_sizes_index(sz)] for sz, _ in counts]
ax, *axs = [dims[_all_sizes_index(sz)] for sz, _ in counts]
if ct == 1:
msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
else:

View File

@ -1461,6 +1461,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
)
def test_vmap_error(self):
# vmap is careful to give nice error messages when mapped axes have
# different sizes, but this can be foiled by InconsistentDimensionOperation
x = y = np.ones((3, 5), dtype=np.float32)
with self.assertRaisesRegex(ValueError,
"vmap got inconsistent sizes for array axes to be mapped"):
jax2tf.convert(jax.vmap(lambda x, y: x + y),
polymorphic_shapes=["b, ...", None])(x, y)
z = x
with self.assertRaisesRegex(ValueError,
"vmap got inconsistent sizes for array axes to be mapped"):
jax2tf.convert(jax.vmap(lambda x, y, z: x + y + z),
polymorphic_shapes=["b, ...", "c, ...", None])(x, y, z)
def test_reshape_compiled(self):
# We compile the result of conversion for two shapes, hence we need to
# involve the TF compiler twice, but we trace only once with shape polymorphism