[shape_poly] Improve error message from vmap axis size inconsistency

vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
This commit is contained in:
George Necula 2023-01-17 10:42:20 +02:00
parent 469a8eb520
commit cf4e568e21
2 changed files with 22 additions and 2 deletions

View File

@ -1685,8 +1685,13 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
for x, d in zip(vals, dims)]
size_counts = collections.Counter(s for s in all_sizes if s is not None)
(sz, ct), *other_counts = counts = size_counts.most_common()
ex, *examples = [key_paths[all_sizes.index(sz)] for sz, _ in counts]
ax, *axs = [dims[all_sizes.index(sz)] for sz, _ in counts]
def _all_sizes_index(sz):
for i, isz in enumerate(all_sizes):
if core.symbolic_equal_dim(isz, sz): return i
assert False, (sz, all_sizes)
ex, *examples = [key_paths[_all_sizes_index(sz)] for sz, _ in counts]
ax, *axs = [dims[_all_sizes_index(sz)] for sz, _ in counts]
if ct == 1:
msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
else:

View File

@ -1461,6 +1461,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
)
def test_vmap_error(self):
# vmap is careful to give nice error messages when mapped axes have
# different sizes, but this can be foiled by InconsistentDimensionOperation
x = y = np.ones((3, 5), dtype=np.float32)
with self.assertRaisesRegex(ValueError,
"vmap got inconsistent sizes for array axes to be mapped"):
jax2tf.convert(jax.vmap(lambda x, y: x + y),
polymorphic_shapes=["b, ...", None])(x, y)
z = x
with self.assertRaisesRegex(ValueError,
"vmap got inconsistent sizes for array axes to be mapped"):
jax2tf.convert(jax.vmap(lambda x, y, z: x + y + z),
polymorphic_shapes=["b, ...", "c, ...", None])(x, y, z)
def test_reshape_compiled(self):
# We compile the result of conversion for two shapes, hence we need to
# involve the TF compiler twice, but we trace only once with shape polymorphism