mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Improve error message from vmap axis size inconsistency
vmap tries hard to give nice error messages when the mapped axes for different arguments have different sizes, but the code to compute the error message can run into InconsistentDimensionOperation in presence of dimension polynomials. Ensure that the comparisons are done symbolically.
This commit is contained in:
parent
469a8eb520
commit
cf4e568e21
@ -1685,8 +1685,13 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
for x, d in zip(vals, dims)]
|
||||
size_counts = collections.Counter(s for s in all_sizes if s is not None)
|
||||
(sz, ct), *other_counts = counts = size_counts.most_common()
|
||||
ex, *examples = [key_paths[all_sizes.index(sz)] for sz, _ in counts]
|
||||
ax, *axs = [dims[all_sizes.index(sz)] for sz, _ in counts]
|
||||
def _all_sizes_index(sz):
|
||||
for i, isz in enumerate(all_sizes):
|
||||
if core.symbolic_equal_dim(isz, sz): return i
|
||||
assert False, (sz, all_sizes)
|
||||
|
||||
ex, *examples = [key_paths[_all_sizes_index(sz)] for sz, _ in counts]
|
||||
ax, *axs = [dims[_all_sizes_index(sz)] for sz, _ in counts]
|
||||
if ct == 1:
|
||||
msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
|
||||
else:
|
||||
|
@ -1461,6 +1461,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
|
||||
)
|
||||
|
||||
def test_vmap_error(self):
|
||||
# vmap is careful to give nice error messages when mapped axes have
|
||||
# different sizes, but this can be foiled by InconsistentDimensionOperation
|
||||
x = y = np.ones((3, 5), dtype=np.float32)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped"):
|
||||
jax2tf.convert(jax.vmap(lambda x, y: x + y),
|
||||
polymorphic_shapes=["b, ...", None])(x, y)
|
||||
|
||||
z = x
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped"):
|
||||
jax2tf.convert(jax.vmap(lambda x, y, z: x + y + z),
|
||||
polymorphic_shapes=["b, ...", "c, ...", None])(x, y, z)
|
||||
|
||||
def test_reshape_compiled(self):
|
||||
# We compile the result of conversion for two shapes, hence we need to
|
||||
# involve the TF compiler twice, but we trace only once with shape polymorphism
|
||||
|
Loading…
x
Reference in New Issue
Block a user