mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26895 from hawkinsp:pad2
PiperOrigin-RevId: 733034756
This commit is contained in:
commit
439c412cd4
@ -3986,12 +3986,7 @@ def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array
|
||||
if constant_values.shape[-1] == 1:
|
||||
widths = [(low, high, 0) for (low, high) in pad_width]
|
||||
return lax.pad(array, squeeze(constant_values), widths)
|
||||
elif constant_values.shape[-1] == 2:
|
||||
widths = [(low, 0, 0) for (low, _) in pad_width]
|
||||
array = lax.pad(array, constant_values[0], widths)
|
||||
widths = [(0, high, 0) for (_, high) in pad_width]
|
||||
return lax.pad(array, constant_values[1], widths)
|
||||
else:
|
||||
elif constant_values.shape[-1] != 2:
|
||||
raise ValueError("jnp.pad: constant_values has unsupported shape "
|
||||
f"{constant_values.shape}. If the shape is 1D or 2D, the "
|
||||
"last dimension must be of size 1 or 2.")
|
||||
|
@ -1143,7 +1143,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# (constant,)
|
||||
(0,), (2.718,),
|
||||
# ((before_const, after_const),)
|
||||
((0, 2),), ((-1, 3.14),),
|
||||
(0, 2),
|
||||
(-1, 3.14),
|
||||
# ((before_1, after_1), ..., (before_N, after_N))
|
||||
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user