From 7f05b74bca207855f58ea5adb12b3c8ef3a6fc51 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 3 Mar 2025 15:16:21 -0500 Subject: [PATCH] Fix wrong results in multidimensional pad. When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order. We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths. Fixes https://github.com/jax-ml/jax/issues/26888 --- jax/_src/numpy/lax_numpy.py | 7 +------ tests/lax_numpy_test.py | 3 ++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a50576720..b2b828cf3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3986,12 +3986,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array if constant_values.shape[-1] == 1: widths = [(low, high, 0) for (low, high) in pad_width] return lax.pad(array, squeeze(constant_values), widths) - elif constant_values.shape[-1] == 2: - widths = [(low, 0, 0) for (low, _) in pad_width] - array = lax.pad(array, constant_values[0], widths) - widths = [(0, high, 0) for (_, high) in pad_width] - return lax.pad(array, constant_values[1], widths) - else: + elif constant_values.shape[-1] != 2: raise ValueError("jnp.pad: constant_values has unsupported shape " f"{constant_values.shape}. If the shape is 1D or 2D, the " "last dimension must be of size 1 or 2.") diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 45698d3b1..09a43dafe 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1143,7 +1143,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # (constant,) (0,), (2.718,), # ((before_const, after_const),) - ((0, 2),), ((-1, 3.14),), + (0, 2), + (-1, 3.14), # ((before_1, after_1), ..., (before_N, after_N)) tuple((i / 2, -3.14 * i) for i in range(len(shape))), ]