Merge pull request #26895 from hawkinsp:pad2

PiperOrigin-RevId: 733034756
This commit is contained in:
jax authors 2025-03-03 13:15:07 -08:00
commit 439c412cd4
2 changed files with 3 additions and 7 deletions

View File

@ -3986,12 +3986,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array
if constant_values.shape[-1] == 1:
widths = [(low, high, 0) for (low, high) in pad_width]
return lax.pad(array, squeeze(constant_values), widths)
elif constant_values.shape[-1] == 2:
widths = [(low, 0, 0) for (low, _) in pad_width]
array = lax.pad(array, constant_values[0], widths)
widths = [(0, high, 0) for (_, high) in pad_width]
return lax.pad(array, constant_values[1], widths)
else:
elif constant_values.shape[-1] != 2:
raise ValueError("jnp.pad: constant_values has unsupported shape "
f"{constant_values.shape}. If the shape is 1D or 2D, the "
"last dimension must be of size 1 or 2.")

View File

@ -1143,7 +1143,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# (constant,)
(0,), (2.718,),
# ((before_const, after_const),)
((0, 2),), ((-1, 3.14),),
(0, 2),
(-1, 3.14),
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
]