[shape_poly] Extend the handling of jnp.arange with shape polymorphism.

Previously, only `arange(stop, dtype=...)` was handled in the presence
of shape polymorphism. This change adds support for `start` and `step`
as well. There are still plenty of restrictions:

   * no floating point constants are allowed among start, stop and step
   * we must be able to resolve statically whether step is positive or negative
   * we must be able to resolve statically whether the distance between
     start and stop is negative or positive.
This commit is contained in:
George Necula 2023-03-27 09:27:10 +02:00
parent 76b922aade
commit c368c69625
2 changed files with 109 additions and 33 deletions

View File

@ -2232,12 +2232,13 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
if val is not None and np.ndim(val) != 0:
raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
if _any(core.is_special_dim_size(d) for d in (start, stop, step)):
if stop is not None or step is not None:
raise ValueError(
"jax.numpy.arange supports non-constant arguments only in "
"single-argument form. Found "
f"jax.numpy.arange({start=}, {stop=}, {step=})")
return lax.iota(dtype or int_, start)
if stop is None and step is None:
stop = start
start = 0
step = 1
elif stop is not None and step is None:
step = 1
return _arange_dynamic(start, stop, step, dtype or int_)
if dtype is None:
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
dtype = _jnp_dtype(dtype)
@ -2255,6 +2256,46 @@ def arange(start: DimSize, stop: Optional[DimSize] = None,
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
def _arange_dynamic(
    start: DimSize, stop: DimSize, step: DimSize, dtype: DTypeLike) -> Array:
  """Builds `arange(start, stop, step)` when some argument is a dimension expression.

  The result is computed as `start + step * iota(dtype, size)`, where `size`
  is the number of elements in the range. Because `size` is passed to
  `lax.iota`, the sign of `step` and the emptiness of the range must be
  decidable at trace time.

  Args:
    start: first value; a Python `int` or a dimension expression.
    stop: exclusive end of the range; a Python `int` or a dimension expression.
    step: increment; a Python `int` or a dimension expression, must not be 0.
    dtype: the dtype of the result.

  Returns:
    The array `start + step * iota(dtype, size)`.

  Raises:
    ValueError: if an argument is neither an `int` nor a dimension expression,
      or if `step == 0`.
    core.InconclusiveDimensionOperation: if the sign of `step`, or whether the
      range is empty, cannot be decided statically.
  """
  # We get here only if at least one of start, stop, step is dynamic.
  for v in (start, stop, step):
    if not core.is_special_dim_size(v) and not isinstance(v, int):
      raise ValueError(
          "In arange with non-constant arguments all of start, stop, and step "
          f"must be either dimension expressions or integers: start={start}, "
          f"stop={stop}, step={step}")

  # Must resolve statically if step is {<0, ==0, >0}
  try:
    if step == 0:
      raise ValueError("arange has step == 0")
    step_gt_0 = (step > 0)
  except core.InconclusiveDimensionOperation as e:
    raise core.InconclusiveDimensionOperation(
        f"In arange with non-constant arguments the step ({step}) must " +
        f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") from e

  # Normalize to a positive stride (`gap`) and to the signed distance to be
  # covered in the direction of the stride, so that a single formula handles
  # both increasing and decreasing ranges.
  gap = step if step_gt_0 else -step
  distance = (stop - start) if step_gt_0 else (start - stop)
  try:
    if distance >= 1 - gap:
      # Ceiling division of distance by gap.
      size = (distance + gap - 1) // gap
    else:
      size = 0
  except core.InconclusiveDimensionOperation:
    # Cannot resolve "distance >= 1 - gap". Perhaps we can resolve "distance >= 1"
    try:
      if distance >= 1:
        # step is a non-zero integer, hence gap >= 1, and distance >= 1
        # implies distance >= 1 - gap: the range is non-empty and the ceiling
        # division above applies. (This branch used to be `assert False`,
        # which is stripped under `python -O` and then leaves `size` unbound.)
        size = (distance + gap - 1) // gap
      else:
        size = 0
    except core.InconclusiveDimensionOperation as e:
      raise core.InconclusiveDimensionOperation(
          "In arange with non-constant dimensions the distance between "
          f"start ({start}) and stop ({stop}) must be resolved statically "
          f"if it is >= {1 - gap} or >= 1.") from e
  return (array(start, dtype=dtype) +
          array(step, dtype=dtype) * lax.iota(dtype, size))
@overload
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
endpoint: bool = True, retstep: Literal[False] = False,

View File

@ -635,13 +635,71 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=PS("h", "h"),
expected_output_signature=tf.TensorSpec([None, None]))
def test_arange(self):
def f_jax(x):
return x + jnp.arange(x.shape[0], dtype=np.float32)
x = np.ones((3,), dtype=np.float32)
@parameterized.named_parameters([
    {"testcase_name": f"_{case_name}", "make_args": args_fn}
    for case_name, args_fn in [
        # Each function is called with x.shape[0] (the symbolic dimension
        # "b") and returns the positional arguments for jnp.arange, in the
        # order: start, stop, step, dtype. Trailing arguments may be omitted.
        ("b", lambda b: (b, None, None, None)),
        ("0_b+1", lambda b: (0, b + 1, None, None)),
        ("0_5b_2", lambda b: (0, 5 * b, 2, None)),
        ("0_5b+1_2", lambda b: (0, 5 * b + 1, 2, None)),
        ("b_5b+2_2", lambda b: (b, 5 * b + 2, 2, None)),
        ("0_b-1_2", lambda b: (0, b - 1, 2, None)),
        ("0_b-2_2", lambda b: (0, b - 2, 2, None)),
        ("0_-b_2", lambda b: (0, -b, 2, None)),
        ("0_1-b_2", lambda b: (0, 1 - b, 2, None)),
        # Negative step
        ("b_0_-1", lambda b: (b, 0, -1, None)),
        ("b_1_-2", lambda b: (b, 1, -2, None)),
        ("b_-1_-1", lambda b: (b, -1, -1, None)),
        ("5b+1_0_-2", lambda b: (5 * b + 1, 0, -2, None)),
        ("5b+2_0_-2", lambda b: (5 * b + 2, 0, -2, None)),
        # Symbolic step
        ("0_10_b", lambda b: (0, 10, b)),
        ("0_0_b", lambda b: (0, 0, b)),
        ("10_0_-b", lambda b: (10, 0, -b)),
        ("b_1_-b", lambda b: (b, 1, -b)),
        # Float return type
        ("0_b_1_f32", lambda b: (0, b, 1, np.float32)),
    ]
])
def test_arange(self, make_args=lambda b: (0, 0, b)):
  # Checks that the shape-polymorphic conversion of jnp.arange produces the
  # same values as running the JAX function directly on a concrete input.
  def f_jax(x):  # x: i32[b]
    arange_args = make_args(x.shape[0])
    return x[0] + jnp.arange(*arange_args)

  x = np.ones((3,), dtype=np.int32)
  converted = jax2tf.convert(f_jax, polymorphic_shapes="b")
  self.assertAllClose(converted(x), f_jax(x))
@parameterized.named_parameters([
    {
        "testcase_name": f"_{case_name}",
        "make_args": args_fn,
        "expect_error": error_type,
        "expect_msg": error_regex,
    }
    for case_name, args_fn, error_type, error_regex in [
        # Each function is called with x.shape[0] (the symbolic dimension
        # "b") and returns the positional arguments for jnp.arange, in the
        # order: start, stop, step.
        ("float_start", lambda b: (0., b, None),
         ValueError, "must be either dimension expressions or integers"),
        ("float_step", lambda b: (0, b, 0.5),
         ValueError, "must be either dimension expressions or integers"),
        ("step_0", lambda b: (0, b, 0),
         ValueError, "has step == 0"),
        ("inconclusive_step_sign", lambda b: (0, b, b - 2),
         core.InconclusiveDimensionOperation,
         "must be resolved statically if it is > 0 or < 0"),
        ("inconclusive_distance", lambda b: (0, b - 3, 2),
         core.InconclusiveDimensionOperation,
         "must be resolved statically if it is >= -1 or >= 1"),
    ]
])
def test_arange_error(self, make_args=lambda b: (0., b, 2),
                      expect_error=ValueError,
                      expect_msg="must be either dimension expressions or integers"):
  # Checks that unsupported uses of jnp.arange with a symbolic dimension are
  # rejected during conversion with the expected exception type and message.
  def f_jax(x):  # x: i32[b]
    arange_args = make_args(x.shape[0])
    return x[0] + jnp.arange(*arange_args)

  x = np.ones((3,), dtype=np.int32)
  with self.assertRaisesRegex(expect_error, expect_msg):
    converted = jax2tf.convert(f_jax, polymorphic_shapes="b")
    converted(x)
def test_argmax(self):
def f_jax(x): # x: f32[b, 4, 5]
return lax.argmax(x, axis=1, index_dtype=np.int32)
@ -1687,29 +1745,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + jnp.sin(x))),
arg_descriptors=[RandArg((3, 4), _f32)],
poly_axes=[0]),
PolyHarness("arange", "start",
lambda op: jnp.arange(2 * op.shape[0], dtype=_f32) + op[0],
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0]).both_enable_and_disable_xla(),
PolyHarness("arange", "start_no_dtype",
lambda op: jnp.arange(op.shape[0]) + op[0],
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0]),
PolyHarness("arange", "error1",
lambda op: jnp.arange(op.shape[0], 10),
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0],
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
PolyHarness("arange", "error2",
lambda op: jnp.arange(1, op.shape[0]),
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0],
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
PolyHarness("arange", "error3",
lambda op: jnp.arange(1, 5, op.shape[0]),
arg_descriptors=[RandArg((3,), _f32)],
poly_axes=[0],
expect_error=(ValueError, "jax.numpy.arange supports non-constant arguments only in single-argument form")),
# Reduce the poly dimension
PolyHarness("argmax", "0",
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),