mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Extend the handling of jnp.arange with shape polymorphism.
Previously, only `arange(stop, dtype=...)` was being handled in presence of shape polymorphism. Here we extend to add support for `start` and `step` to be also present. There are still plenty of restrictions: * no floating point constants are allowed among start, stop and step * we must resolve statically if step is positive or negative * we must resolve statically if the distance between start and stop is negative or positive.
This commit is contained in:
parent
76b922aade
commit
c368c69625
@ -2232,12 +2232,13 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
|
||||
if val is not None and np.ndim(val) != 0:
|
||||
raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
|
||||
if _any(core.is_special_dim_size(d) for d in (start, stop, step)):
|
||||
if stop is not None or step is not None:
|
||||
raise ValueError(
|
||||
"jax.numpy.arange supports non-constant arguments only in "
|
||||
"single-argument form. Found "
|
||||
f"jax.numpy.arange({start=}, {stop=}, {step=})")
|
||||
return lax.iota(dtype or int_, start)
|
||||
if stop is None and step is None:
|
||||
stop = start
|
||||
start = 0
|
||||
step = 1
|
||||
elif stop is not None and step is None:
|
||||
step = 1
|
||||
return _arange_dynamic(start, stop, step, dtype or int_)
|
||||
if dtype is None:
|
||||
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
|
||||
dtype = _jnp_dtype(dtype)
|
||||
@ -2255,6 +2256,46 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
|
||||
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
|
||||
|
||||
|
||||
def _arange_dynamic(
|
||||
start: DimSize, stop: DimSize, step: DimSize, dtype: DTypeLike) -> Array:
|
||||
# Here if at least one of start, stop, step are dynamic.
|
||||
if any([(not core.is_special_dim_size(v) and not isinstance(v, int))
|
||||
for v in (start, stop, step)]):
|
||||
raise ValueError(
|
||||
"In arange with non-constant arguments all of start, stop, and step "
|
||||
f"must be either dimension expressions or integers: start={start}, "
|
||||
f"stop={stop}, step={step}")
|
||||
# Must resolve statically if step is {<0, ==0, >0}
|
||||
try:
|
||||
if step == 0:
|
||||
raise ValueError("arange has step == 0")
|
||||
step_gt_0 = (step > 0)
|
||||
except core.InconclusiveDimensionOperation as e:
|
||||
raise core.InconclusiveDimensionOperation(
|
||||
f"In arange with non-constant arguments the step ({step}) must " +
|
||||
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
|
||||
gap = step if step_gt_0 else - step
|
||||
distance = (stop - start) if step_gt_0 else (start - stop)
|
||||
try:
|
||||
if distance >= 1 - gap:
|
||||
size = (distance + gap - 1) // gap
|
||||
else:
|
||||
size = 0
|
||||
except core.InconclusiveDimensionOperation:
|
||||
# Cannot resolve "distance >= 1 - gap". Perhaps we can resolve "distance >= 1"
|
||||
try:
|
||||
if distance >= 1:
|
||||
assert False
|
||||
else:
|
||||
size = 0
|
||||
except core.InconclusiveDimensionOperation:
|
||||
raise core.InconclusiveDimensionOperation(
|
||||
"In arange with non-constant dimensions the distance between "
|
||||
f"start ({start}) and stop ({stop}) must be resolved statically "
|
||||
f"if it is >= {1 - gap} or >= 1.")
|
||||
return (array(start, dtype=dtype) +
|
||||
array(step, dtype=dtype) * lax.iota(dtype, size))
|
||||
|
||||
@overload
|
||||
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
endpoint: bool = True, retstep: Literal[False] = False,
|
||||
|
@ -635,13 +635,71 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=PS("h", "h"),
|
||||
expected_output_signature=tf.TensorSpec([None, None]))
|
||||
|
||||
def test_arange(self):
|
||||
def f_jax(x):
|
||||
return x + jnp.arange(x.shape[0], dtype=np.float32)
|
||||
x = np.ones((3,), dtype=np.float32)
|
||||
@parameterized.named_parameters([
|
||||
dict(testcase_name=f"_{name}",
|
||||
make_args=make_args)
|
||||
for name, make_args in [
|
||||
# make_args invoked with op.shape[0]: start, stop, step, dtype
|
||||
("b", lambda b: (b, None, None, None)),
|
||||
("0_b+1", lambda b: (0, b + 1, None, None)),
|
||||
("0_5b_2", lambda b: (0, 5 * b, 2, None)),
|
||||
("0_5b+1_2", lambda b: (0, 5 * b + 1, 2, None)),
|
||||
("b_5b+2_2", lambda b: (b, 5 * b + 2, 2, None)),
|
||||
("0_b-1_2", lambda b: (0, b - 1, 2, None)),
|
||||
("0_b-2_2", lambda b: (0, b - 2, 2, None)),
|
||||
("0_-b_2", lambda b: (0, -b, 2, None)),
|
||||
("0_1-b_2", lambda b: (0, 1 - b, 2, None)),
|
||||
# Negative step
|
||||
("b_0_-1", lambda b: (b, 0, -1, None)),
|
||||
("b_1_-2", lambda b: (b, 1, -2, None)),
|
||||
("b_-1_-1", lambda b: (b, -1, -1, None)),
|
||||
("5b+1_0_-2", lambda b: (5 * b + 1, 0, -2, None)),
|
||||
("5b+2_0_-2", lambda b: (5 * b + 2, 0, -2, None)),
|
||||
# Symbolic step
|
||||
("0_10_b", lambda b: (0, 10, b)),
|
||||
("0_0_b", lambda b: (0, 0, b)),
|
||||
("10_0_-b", lambda b: (10, 0, -b)),
|
||||
("b_1_-b", lambda b: (b, 1, -b)),
|
||||
# Float return type
|
||||
("0_b_1_f32", lambda b: (0, b, 1, np.float32))
|
||||
]
|
||||
])
|
||||
def test_arange(self, make_args=lambda b: (0, 0, b)):
|
||||
def f_jax(x): # x: i32[b]
|
||||
return x[0] + jnp.arange(*(make_args(x.shape[0])))
|
||||
x = np.ones((3,), dtype=np.int32)
|
||||
self.assertAllClose(jax2tf.convert(f_jax, polymorphic_shapes="b")(x),
|
||||
f_jax(x))
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(testcase_name=f"_{name}",
|
||||
make_args=make_args,
|
||||
expect_error=expect_error, expect_msg=expect_msg)
|
||||
for name, make_args, expect_error, expect_msg in [
|
||||
# make_args invoked with op.shape[0]: start, stop, step, dtype
|
||||
("float_start", lambda b: (0., b, None),
|
||||
ValueError, "must be either dimension expressions or integers"),
|
||||
("float_step", lambda b: (0, b, 0.5),
|
||||
ValueError, "must be either dimension expressions or integers"),
|
||||
("step_0", lambda b: (0, b, 0),
|
||||
ValueError, "has step == 0"),
|
||||
("inconclusive_step_sign", lambda b: (0, b, b - 2),
|
||||
core.InconclusiveDimensionOperation,
|
||||
"must be resolved statically if it is > 0 or < 0"),
|
||||
("inconclusive_distance", lambda b: (0, b - 3, 2),
|
||||
core.InconclusiveDimensionOperation,
|
||||
"must be resolved statically if it is >= -1 or >= 1"),
|
||||
]
|
||||
])
|
||||
def test_arange_error(self, make_args=lambda b: (0., b, 2),
|
||||
expect_error=ValueError,
|
||||
expect_msg="must be either dimension expressions or integers"):
|
||||
def f_jax(x): # x: i32[b]
|
||||
return x[0] + jnp.arange(*(make_args(x.shape[0])))
|
||||
x = np.ones((3,), dtype=np.int32)
|
||||
with self.assertRaisesRegex(expect_error, expect_msg):
|
||||
jax2tf.convert(f_jax, polymorphic_shapes="b")(x)
|
||||
|
||||
def test_argmax(self):
|
||||
def f_jax(x): # x: f32[b, 4, 5]
|
||||
return lax.argmax(x, axis=1, index_dtype=np.int32)
|
||||
@ -1687,29 +1745,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
PolyHarness("arange", "start",
|
||||
lambda op: jnp.arange(2 * op.shape[0], dtype=_f32) + op[0],
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0]).both_enable_and_disable_xla(),
|
||||
PolyHarness("arange", "start_no_dtype",
|
||||
lambda op: jnp.arange(op.shape[0]) + op[0],
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0]),
|
||||
PolyHarness("arange", "error1",
|
||||
lambda op: jnp.arange(op.shape[0], 10),
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
|
||||
PolyHarness("arange", "error2",
|
||||
lambda op: jnp.arange(1, op.shape[0]),
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
|
||||
PolyHarness("arange", "error3",
|
||||
lambda op: jnp.arange(1, 5, op.shape[0]),
|
||||
arg_descriptors=[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
|
||||
# Reduce the poly dimension
|
||||
PolyHarness("argmax", "0",
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user