mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order. We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths. Fixes https://github.com/jax-ml/jax/issues/26888
This commit is contained in:
parent
5179642eb5
commit
7f05b74bca
@ -3986,12 +3986,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array
|
||||
if constant_values.shape[-1] == 1:
|
||||
widths = [(low, high, 0) for (low, high) in pad_width]
|
||||
return lax.pad(array, squeeze(constant_values), widths)
|
||||
elif constant_values.shape[-1] == 2:
|
||||
widths = [(low, 0, 0) for (low, _) in pad_width]
|
||||
array = lax.pad(array, constant_values[0], widths)
|
||||
widths = [(0, high, 0) for (_, high) in pad_width]
|
||||
return lax.pad(array, constant_values[1], widths)
|
||||
else:
|
||||
elif constant_values.shape[-1] != 2:
|
||||
raise ValueError("jnp.pad: constant_values has unsupported shape "
|
||||
f"{constant_values.shape}. If the shape is 1D or 2D, the "
|
||||
"last dimension must be of size 1 or 2.")
|
||||
|
@ -1143,7 +1143,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# (constant,)
|
||||
(0,), (2.718,),
|
||||
# ((before_const, after_const),)
|
||||
((0, 2),), ((-1, 3.14),),
|
||||
(0, 2),
|
||||
(-1, 3.14),
|
||||
# ((before_1, after_1), ..., (before_N, after_N))
|
||||
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user