Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Commit 5216719996: Merge pull request #22405 from gnecula:poly_pad
PiperOrigin-RevId: 652511693
@@ -886,8 +886,10 @@ class _DimExpr:
       if config.enable_checks.value:
         v1 = divisor * quotient
         v2 = v1 + remainder
-        assert self == v2, (self, v2, type(self), type(v2))
-        assert self == divisor * quotient + remainder, (self, divisor, quotient, remainder)
+        assert self == _ensure_poly(v2, "check", self.scope), (
+            self, v2, type(self), type(v2))
+        assert self == _ensure_poly(divisor * quotient + remainder, "test", self.scope), (
+            self, divisor, quotient, remainder)
       return quotient, remainder
     except InconclusiveDimensionOperation:
       return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor,
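The checks above assert the divmod invariant, self == divisor * quotient + remainder, now routing the comparison value through _ensure_poly before asserting. A minimal sketch of the invariant through the public symbolic-dimension API (the jax.export usage is this note's assumption, not part of the diff):

    from jax import export

    a, b = export.symbolic_shape("a, b")

    # divmod on a symbolic dimension expression returns (quotient, remainder)
    # such that dividend == divisor * quotient + remainder.
    q, r = divmod(3 * a + 2, 3)
    assert q == a and r == 2
    assert 3 * q + r == 3 * a + 2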
@@ -2765,13 +2765,14 @@ def _pad_wrap(array: Array, pad_width: PadValue[int]) -> Array:
       _check_no_padding(pad_width[i], "wrap")
       continue
     size = array.shape[i]
-    repeats, (left_remainder, right_remainder) = np.divmod(pad_width[i], size)
-    total_repeats = repeats.sum() + 1
+    left_repeats, left_remainder = divmod(pad_width[i][0], size)
+    right_repeats, right_remainder = divmod(pad_width[i][1], size)
+    total_repeats = left_repeats + right_repeats + 1
     parts = []
-    if left_remainder:
+    if left_remainder > 0:
       parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)]
     parts += total_repeats * [array]
-    if right_remainder:
+    if right_remainder > 0:
       parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)]
     array = lax.concatenate(parts, dimension=i)
   return array
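np.divmod only accepts concrete integers, so the rewrite applies Python's divmod to each side separately, which also works for symbolic dimension expressions. A concrete check of the construction the new code performs (a sketch using only public jnp APIs):

    import numpy as np
    import jax.numpy as jnp

    x = jnp.arange(5)                # axis size 5
    left, right = 7, 3               # pad widths
    l_rep, l_rem = divmod(left, 5)   # (1, 2)
    r_rep, r_rem = divmod(right, 5)  # (0, 3)
    parts = []
    if l_rem > 0:
      parts.append(x[5 - l_rem:])    # trailing l_rem elements wrap to the front
    parts.extend([x] * (l_rep + r_rep + 1))
    if r_rem > 0:
      parts.append(x[:r_rem])        # leading r_rem elements wrap to the back
    manual = jnp.concatenate(parts)
    np.testing.assert_array_equal(manual, jnp.pad(x, (left, right), mode="wrap"))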
@@ -2787,8 +2788,7 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
       _check_no_padding(pad_width[i], mode)
       continue

-    n = array.shape[i]
-    offset = 1 if (mode == "reflect" and n > 1) else 0
+    axis_size = array.shape[i]

     def build_padding(array, padding, before):
       if before:
@@ -2796,23 +2796,41 @@ def _pad_symmetric_or_reflect(array: Array, pad_width: PadValue[int],
       else:
         edge = lax.slice_in_dim(array, -1, None, axis=i)

+      # Try to give nicer error messages for unsupported shape polymorphic uses
+      shape_poly_error_msg = lambda: (
+          "Shape polymorphism is supported for jnp.pad with 'reflect' or "
+          "'symmetric' padding mode only when it is possible to determine "
+          f"at lowering time that the axis size (= {axis_size}) is larger than 1 "
+          f"and larger or equal than the padding length (= {padding}). "
+          f"Error while handling {'left' if before else 'right'} padding on axis {i}.")
+      try:
+        # We check that we can determine all comparisons.
+        offset = 1 if (mode == "reflect" and axis_size > 1) else 0
+        has_poly_dim = not core.is_constant_shape((axis_size, padding))
+        # For shape polymorphism, ensure the loop below ends after 1 iteration
+        if has_poly_dim and not (axis_size > 1 and axis_size - offset >= padding):
+          raise ValueError(shape_poly_error_msg())
+      except core.InconclusiveDimensionOperation as e:
+        raise ValueError(shape_poly_error_msg()) from e
+
       while padding > 0:
-        curr_pad = min(padding, n - offset)
+        curr_pad = min(padding, axis_size - offset)
         padding -= curr_pad
+        if has_poly_dim: assert padding == 0

         if before:
           start = offset
           stop = offset + curr_pad
         else:
           start = -(curr_pad + offset)
-          stop = None if (mode == "symmetric" or n == 1) else -1
+          stop = None if (mode == "symmetric" or axis_size == 1) else -1

         x = lax.slice_in_dim(array, start, stop, axis=i)
         x = flip(x, axis=i)

         if reflect_type == 'odd':
           x = 2 * edge - x
-          if n > 1:
+          if axis_size > 1:
             if before:
               edge = lax.slice_in_dim(x, 0, 1, axis=i)
             else:
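The new try block decides at lowering time whether axis_size > 1 and axis_size - offset >= padding hold; for symbolic dimensions an inconclusive comparison raises core.InconclusiveDimensionOperation, which is converted into the friendlier ValueError. A sketch of both outcomes via the public export API (assumed here, not part of the diff):

    import jax
    import jax.numpy as jnp
    from jax import export

    def f(x):  # left pad width b - 1 on an axis of symbolic size b
      return jnp.pad(x, ((x.shape[0] - 1, 0),), mode="reflect")

    # With b >= 2 declared, b > 1 and b - 1 >= b - 1 are decidable, so this lowers.
    b, = export.symbolic_shape("b", constraints=["b >= 2"])
    export.export(jax.jit(f))(jax.ShapeDtypeStruct((b,), jnp.float32))

    # Without the constraint, b > 1 is inconclusive and the new ValueError fires.
    b2, = export.symbolic_shape("b")
    try:
      export.export(jax.jit(f))(jax.ShapeDtypeStruct((b2,), jnp.float32))
    except ValueError as e:
      assert "Shape polymorphism is supported for jnp.pad" in str(e)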
@@ -4308,7 +4326,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
              endpoint: bool = True, retstep: bool = False,
              dtype: DTypeLike | None = None,
              axis: int = 0) -> Array | tuple[Array, Array]:
-  num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
+  num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace")
   axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
   return _linspace(start, stop, num, endpoint, retstep, dtype, axis)

@@ -4337,13 +4355,13 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
   bounds_shape.insert(axis, 1)
   div = (num - 1) if endpoint else num
   if num > 1:
-    delta: Array = lax.convert_element_type(stop - start, computation_dtype) / div
+    delta: Array = lax.convert_element_type(stop - start, computation_dtype) / array(div, dtype=computation_dtype)
     iota_shape = [1,] * len(bounds_shape)
     iota_shape[axis] = div
     # This approach recovers the endpoints with float32 arithmetic,
     # but can lead to rounding errors for integer outputs.
     real_dtype = finfo(computation_dtype).dtype
-    step = reshape(lax.iota(real_dtype, div), iota_shape) / div
+    step = reshape(lax.iota(real_dtype, div), iota_shape) / array(div, real_dtype)
     step = step.astype(computation_dtype)
     out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
            reshape(broadcast_stop, bounds_shape) * step)
@@ -4355,7 +4373,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
   elif num == 1:
     delta = asarray(nan if endpoint else stop - start, dtype=computation_dtype)
     out = reshape(broadcast_start, bounds_shape)
-  else: # num == 0 degenerate case, match numpy behavior
+  else:  # num == 0 degenerate case, match numpy behavior
     empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
     empty_shape.insert(axis, 0)
     delta = asarray(nan, dtype=computation_dtype)
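Taken together, the linspace hunks allow num to be a symbolic dimension: concrete_dim_or_error accepts dimension expressions where operator.index would fail, and dividing by array(div, ...) keeps the division in floating point instead of attempting integer division by a symbolic dimension. A sketch (the jax.export usage is an assumption of this note):

    import jax
    import jax.numpy as jnp
    from jax import export

    # num = b needs b >= 2 so the `if num > 1` branch is decidable at trace time.
    b, = export.symbolic_shape("b", constraints=["b >= 2"])
    f = lambda x: jnp.linspace(0.0, 100.0, x.shape[0])
    exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((b,), jnp.float32))
    print(exported.out_avals)  # expect a float32 aval of symbolic shape (b,)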
@@ -697,6 +697,7 @@ class DimExprTest(jtu.JaxTestCase):
       (3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
       (a * a - b * b, a + b, a - b, 0),
       (256 * a * b, 32, 8 * a * b, 0),
+      (0, b, 0, 0),
       (a, b, "floordiv(a, b)", "mod(a, b)"),
       (3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
       (2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, b + a)", "mod(2*a*b + b^2, b + a)"),
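Each tuple reads (dividend, divisor, expected quotient, expected remainder); string entries mean the result stays symbolic as floordiv/mod operations, and the new (0, b, 0, 0) row covers a zero dividend. The exact-division rows can be checked directly (a sketch, assuming the public symbolic-shape API):

    from jax import export

    a, b = export.symbolic_shape("a, b")
    # (a*a - b*b) divides exactly by (a + b): quotient a - b, remainder 0.
    assert (a * a - b * b) // (a + b) == a - b
    assert (a * a - b * b) % (a + b) == 0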
@@ -2532,6 +2533,15 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x: x + lax.iota(_f32, x.shape[0]),
                 arg_descriptors=[RandArg((3,), _f32)],
                 polymorphic_shapes=["b, ..."]),
+    PolyHarness("linspace", "",
+                lambda x: jnp.linspace(0, x.shape[0], 4),
+                arg_descriptors=[RandArg((30,), _f32)],
+                polymorphic_shapes=["b, ..."]),
+    PolyHarness("linspace", "num_poly",
+                lambda x: jnp.linspace(0, 100, x.shape[0]),
+                arg_descriptors=[RandArg((30,), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
     PolyHarness("matmul", "0",
                 jnp.matmul,
                 arg_descriptors=[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
@@ -2613,6 +2623,58 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 mode="edge"),
                 arg_descriptors=[RandArg((3, 5), _f32)],
                 polymorphic_shapes=["b, ..."]),
+    PolyHarness("jnp.pad", "mode=maximum",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="maximum"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."]),
+    PolyHarness("jnp.pad", "mode=maximum_stat_length=b",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="maximum", stat_length=((x.shape[0] // 2, 2), (2, 2))),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=linear_ramp",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="linear_ramp"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=reflect_odd",
+                lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
+                                  mode="reflect", reflect_type="odd"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=reflect_odd_error",
+                lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
+                                  mode="reflect", reflect_type="odd"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                expect_error=(ValueError, "Shape polymorphism is supported for jnp.pad")),
+    PolyHarness("jnp.pad", "mode=reflect_even",
+                lambda x: jnp.pad(x, [[x.shape[0] - 1, 0], [x.shape[1], 1]],
+                                  mode="reflect", reflect_type="even"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=symmetric_odd",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="symmetric", reflect_type="odd"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=symmetric_even",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="symmetric", reflect_type="even"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."],
+                symbolic_constraints=["b >= 2"]),
+    PolyHarness("jnp.pad", "mode=wrap",
+                lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]],
+                                  mode="wrap"),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                polymorphic_shapes=["b, ..."]),
     PolyHarness("percentile", "axis=None",
                 lambda x: jnp.percentile(x, 50, axis=None),
                 arg_descriptors=[RandArg((3, 5), _f32)],
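Written out by hand, the mode="wrap" harness exercises roughly the following (a sketch; the PolyHarness machinery additionally compares against the concrete result):

    import jax
    import jax.numpy as jnp
    from jax import export

    f = lambda x: jnp.pad(x, [[x.shape[0], 0], [x.shape[1], 1]], mode="wrap")
    b, = export.symbolic_shape("b")
    exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((b, 5), jnp.float32))
    # Padding by b on an axis of size b doubles it; axis 1 grows from 5 to 11.
    print(exported.out_avals)  # expect symbolic shape (2*b, 11)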