mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14030 from gnecula:poly_vmap_error
PiperOrigin-RevId: 502546564
This commit is contained in:
commit
7ce9fa2f87
@ -1685,8 +1685,13 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
for x, d in zip(vals, dims)]
|
||||
size_counts = collections.Counter(s for s in all_sizes if s is not None)
|
||||
(sz, ct), *other_counts = counts = size_counts.most_common()
|
||||
ex, *examples = [key_paths[all_sizes.index(sz)] for sz, _ in counts]
|
||||
ax, *axs = [dims[all_sizes.index(sz)] for sz, _ in counts]
|
||||
def _all_sizes_index(sz):
|
||||
for i, isz in enumerate(all_sizes):
|
||||
if core.symbolic_equal_dim(isz, sz): return i
|
||||
assert False, (sz, all_sizes)
|
||||
|
||||
ex, *examples = [key_paths[_all_sizes_index(sz)] for sz, _ in counts]
|
||||
ax, *axs = [dims[_all_sizes_index(sz)] for sz, _ in counts]
|
||||
if ct == 1:
|
||||
msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
|
||||
else:
|
||||
|
@ -1461,6 +1461,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expected_output_signature=tf.TensorSpec((None, 3), dtype=tf.float32)
|
||||
)
|
||||
|
||||
def test_vmap_error(self):
|
||||
# vmap is careful to give nice error messages when mapped axes have
|
||||
# different sizes, but this can be foiled by InconsistentDimensionOperation
|
||||
x = y = np.ones((3, 5), dtype=np.float32)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped"):
|
||||
jax2tf.convert(jax.vmap(lambda x, y: x + y),
|
||||
polymorphic_shapes=["b, ...", None])(x, y)
|
||||
|
||||
z = x
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"vmap got inconsistent sizes for array axes to be mapped"):
|
||||
jax2tf.convert(jax.vmap(lambda x, y, z: x + y + z),
|
||||
polymorphic_shapes=["b, ...", "c, ...", None])(x, y, z)
|
||||
|
||||
def test_reshape_compiled(self):
|
||||
# We compile the result of conversion for two shapes, hence we need to
|
||||
# involve the TF compiler twice, but we trace only once with shape polymorphism
|
||||
|
Loading…
x
Reference in New Issue
Block a user