Fix wrong results in multidimensional pad.

When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
This commit is contained in:
Peter Hawkins 2025-03-03 15:16:21 -05:00
parent 5179642eb5
commit 7f05b74bca
2 changed files with 3 additions and 7 deletions

View File

@ -3986,12 +3986,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array
if constant_values.shape[-1] == 1:
widths = [(low, high, 0) for (low, high) in pad_width]
return lax.pad(array, squeeze(constant_values), widths)
elif constant_values.shape[-1] == 2:
widths = [(low, 0, 0) for (low, _) in pad_width]
array = lax.pad(array, constant_values[0], widths)
widths = [(0, high, 0) for (_, high) in pad_width]
return lax.pad(array, constant_values[1], widths)
else:
elif constant_values.shape[-1] != 2:
raise ValueError("jnp.pad: constant_values has unsupported shape "
f"{constant_values.shape}. If the shape is 1D or 2D, the "
"last dimension must be of size 1 or 2.")

View File

@ -1143,7 +1143,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# (constant,)
(0,), (2.718,),
# ((before_const, after_const),)
((0, 2),), ((-1, 3.14),),
(0, 2),
(-1, 3.14),
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
]