[shape_poly] Adding shape polymorphism support for the state primitives.

This commit is contained in:
George Necula 2024-11-20 20:50:37 -08:00
parent 8d84f28373
commit 0831e2e340
6 changed files with 100 additions and 68 deletions

View File

@ -17,7 +17,6 @@ import google_benchmark as benchmark
import jax
from jax import core
from jax._src.numpy import lax_numpy
from jax import export
jax.config.parse_flags_with_absl()
@ -76,7 +75,7 @@ def inequalities_slice(state):
while state:
for _ in range(30):
a.scope._clear_caches()
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
_ = 0 <= slice_size <= b
_ = start >= 0
_ = start + slice_size <= b

View File

@ -2047,6 +2047,70 @@ def dimension_as_value(d: DimSize):
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
return operator.index(d)
def canonicalize_slice(
    s: slice,
    axis_size: DimSize
  ) -> tuple[DimSize, DimSize, DimSize]:
  """Computes the start index, step, and size of the slice `x[s]`.

  This is similar to `s.indices(axis_size)`, except that it returns
  `(start, step, size)`, and it works when the slice and/or the
  `axis_size` are symbolic.

  See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding

  Args:
    s: the slice to canonicalize. Each of `start`, `stop` and `step` may be
      `None`, something accepted by `operator.index`, or a symbolic dimension.
    axis_size: the size of the axis being sliced.

  Returns:
    A tuple `(start, step, size)`: the clamped start index, the step
    (1 when `s.step` is `None`), and the number of selected elements.

  Raises:
    ValueError: if the step is zero.
    InconclusiveDimensionOperation: if comparing the step with 0, or the
      start/stop with 0, raises `InconclusiveDimensionOperation`.
  """
  def convert_to_index(d: DimSize) -> DimSize:
    # Convert np.array and jax.Array to int, leave symbolic dimensions alone.
    # Catch `Exception` rather than everything, so that KeyboardInterrupt and
    # SystemExit are never swallowed here.
    try:
      return operator.index(d)
    except Exception:
      return d

  # Must resolve statically if step is {<0, ==0, >0}
  step = convert_to_index(s.step) if s.step is not None else 1
  try:
    if step == 0:
      raise ValueError("slice step cannot be zero")
    step_gt_0 = (step > 0)
  except InconclusiveDimensionOperation as e:
    raise InconclusiveDimensionOperation(
        f"In slice with non-constant elements the step ({step}) must " +
        f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") from e

  def clamp_index(i: DimSize, which: str):
    # The sign of the index selects between "count from the front" and
    # "count from the back", so it has to be decidable.
    try:
      i_ge_0 = (i >= 0)
    except InconclusiveDimensionOperation as e:
      raise InconclusiveDimensionOperation(
          f"In slice with non-constant elements the {which} ({i}) must " +
          f"be resolved statically if it is >= 0.\nDetails: {e}") from e
    if i_ge_0:
      # Non-negative index: clamp from above. For a negative step the largest
      # usable position is the last element, `axis_size - 1`.
      if step_gt_0:
        return min_dim(axis_size, i)
      else:
        return min_dim(axis_size - 1, i)
    else:
      # Negative index: offset by `axis_size`, then clamp from below. For a
      # negative step, -1 stands for "one before the first element".
      if step_gt_0:
        return max_dim(0, axis_size + i)
      else:
        return max_dim(-1, axis_size + i)

  if s.start is None:
    start = 0 if step_gt_0 else axis_size - 1
  else:
    start = clamp_index(convert_to_index(s.start), "start")

  if s.stop is None:
    stop = axis_size if step_gt_0 else -1
  else:
    stop = clamp_index(convert_to_index(s.stop), "stop")

  # Ceiling division of the (non-negative) distance by the absolute step.
  gap = step if step_gt_0 else - step
  distance = (stop - start) if step_gt_0 else (start - stop)
  slice_size = max_dim(0, distance + gap - 1) // gap
  return start, step, slice_size
class SomeTracer:
__slots__ = ()
def __repr__(self): return "[dynamic]"

View File

@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
"arrays within JIT compiled functions).")
raise IndexError(msg)
start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
slice_shape.append(slice_size)
if core.definitely_equal(step, 1):
@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx):
idx = tuple(idx) + colons
return idx
def _preprocess_slice(
    s: slice,
    axis_size: core.DimSize
  ) -> tuple[core.DimSize, core.DimSize, core.DimSize]:
  """Computes the start index, step, and size of the slice `x[s]`.

  Args:
    s: the slice to process. Each of `start`, `stop` and `step` may be
      `None`, something accepted by `operator.index`, or a symbolic dimension.
    axis_size: the size of the axis being sliced.

  Returns:
    A tuple `(start, step, size)`: the clamped start index, the step
    (1 when `s.step` is `None`), and the number of selected elements.

  Raises:
    ValueError: if the step is zero.
    core.InconclusiveDimensionOperation: if comparing the step with 0, or the
      start/stop with 0, raises `core.InconclusiveDimensionOperation`.
  """
  # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
  # "this is harder to get right than you may think"
  # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
  def convert_to_index(d: core.DimSize) -> core.DimSize:
    # Convert np.array and jax.Array to int, leave symbolic dimensions alone.
    # Catch `Exception` rather than everything, so that KeyboardInterrupt and
    # SystemExit are never swallowed here.
    try:
      return operator.index(d)
    except Exception:
      return d

  # Must resolve statically if step is {<0, ==0, >0}
  step = convert_to_index(s.step) if s.step is not None else 1
  try:
    if step == 0:
      raise ValueError("slice step cannot be zero")
    step_gt_0 = (step > 0)
  except core.InconclusiveDimensionOperation as e:
    raise core.InconclusiveDimensionOperation(
        f"In slice with non-constant elements the step ({step}) must " +
        f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") from e

  def clamp_index(i: core.DimSize, which: str):
    # The sign of the index selects between "count from the front" and
    # "count from the back", so it has to be decidable.
    try:
      i_ge_0 = (i >= 0)
    except core.InconclusiveDimensionOperation as e:
      raise core.InconclusiveDimensionOperation(
          f"In slice with non-constant elements the {which} ({i}) must " +
          f"be resolved statically if it is >= 0.\nDetails: {e}") from e
    if i_ge_0:
      # Non-negative index: clamp from above. For a negative step the largest
      # usable position is the last element, `axis_size - 1`.
      if step_gt_0:
        return core.min_dim(axis_size, i)
      else:
        return core.min_dim(axis_size - 1, i)
    else:
      # Negative index: offset by `axis_size`, then clamp from below. For a
      # negative step, -1 stands for "one before the first element".
      if step_gt_0:
        return core.max_dim(0, axis_size + i)
      else:
        return core.max_dim(-1, axis_size + i)

  if s.start is None:
    start = 0 if step_gt_0 else axis_size - 1
  else:
    start = clamp_index(convert_to_index(s.start), "start")

  if s.stop is None:
    stop = axis_size if step_gt_0 else -1
  else:
    stop = clamp_index(convert_to_index(s.stop), "stop")

  # Ceiling division of the (non-negative) distance by the absolute step.
  gap = step if step_gt_0 else - step
  distance = (stop - start) if step_gt_0 else (start - stop)
  slice_size = core.max_dim(0, distance + gap - 1) // gap
  return start, step, slice_size
@export
def blackman(M: int) -> Array:

View File

@ -46,11 +46,11 @@ class Slice:
@property
def is_dynamic_start(self):
return not isinstance(self.start, int)
return not core.is_dim(self.start)
@property
def is_dynamic_size(self):
return not isinstance(self.size, int)
return not core.is_dim(self.size)
def tree_flatten(self):
# If `start` is statically known, we treat it as static information
@ -72,10 +72,10 @@ class Slice:
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
start, stop, step = slc.indices(size)
start, step, size = core.canonicalize_slice(slc, size)
if step < 1:
raise ValueError(f"slice must have a step >= 1 (found: {step})")
return cls(start, max((stop - start + step - 1) // step, 0), step)
return cls(start, size, step)
def dslice(

View File

@ -48,6 +48,9 @@ from jax._src.export import shape_poly
from jax._src.export import shape_poly_decision
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.state import discharge
from jax._src.state import primitives as ref_primitives
import numpy as np
config.parse_flags_with_absl()
@ -2062,6 +2065,31 @@ class ShapePolyTest(jtu.JaxTestCase):
polymorphic_shapes=["b, ...", "c, ...", None])
@jtu.parameterized_filterable(
  kwargs=[
    {"slc": candidate}
    for candidate in (
      slice(None, None, None),
      slice(2, 5),
    )
  ])
def test_stateful(self, slc: slice):
  """Exports a state-discharged function whose ref has a symbolic length.

  The traced body adds one to `x_ref[slc]`, reads the slice back, adds one
  more and stores the result. An all-ones input is therefore expected to come
  back as 3 inside `slc` and unchanged elsewhere.
  """
  (w,) = export.symbolic_shape("w", constraints=["w >= 3"])

  def f(x_ref):
    ones = jnp.ones_like(x_ref)[slc]
    ref_primitives.ref_addupdate(x_ref, slc, ones)
    bumped = ref_primitives.ref_get(x_ref, slc) + ones
    ref_primitives.ref_set(x_ref, slc, bumped)

  exporter = export.export(jax.jit(discharge.run_state(f)))
  exp = exporter(jax.ShapeDtypeStruct((w,), dtype=_f32))

  # Run the exported function on a concrete length that satisfies `w >= 3`.
  x = np.ones((32,), dtype=_f32)
  expected = x.copy()
  expected[slc] = 3.
  self.assertAllClose(exp.call(x), expected)
# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
PolyHarness("add", "",
@ -3603,7 +3631,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
not harness.polymorphic_shapes[0].endswith("...") and
jtu.test_device_matches(["tpu"])):
raise unittest.SkipTest(
"Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.")
"Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.")
config_flags = harness.override_jax_config_flags
# Update this here rather than in harness object because vmap_random_gamma is derived

View File

@ -752,7 +752,7 @@ class StateDischargeTest(jtu.JaxTestCase):
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))