From 0831e2e3401dfde3b12e407cb4c366b420b16348 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 20 Nov 2024 20:50:37 -0800 Subject: [PATCH] [shape_poly] Adding shape polymorphism support for the state primitives. --- benchmarks/shape_poly_benchmark.py | 3 +- jax/_src/core.py | 64 ++++++++++++++++++++++++++++++ jax/_src/numpy/lax_numpy.py | 61 +--------------------------- jax/_src/state/indexing.py | 8 ++-- tests/shape_poly_test.py | 30 +++++++++++++- tests/state_test.py | 2 +- 6 files changed, 100 insertions(+), 68 deletions(-) diff --git a/benchmarks/shape_poly_benchmark.py b/benchmarks/shape_poly_benchmark.py index d26801d8d..d365a6fac 100644 --- a/benchmarks/shape_poly_benchmark.py +++ b/benchmarks/shape_poly_benchmark.py @@ -17,7 +17,6 @@ import google_benchmark as benchmark import jax from jax import core -from jax._src.numpy import lax_numpy from jax import export jax.config.parse_flags_with_absl() @@ -76,7 +75,7 @@ def inequalities_slice(state): while state: for _ in range(30): a.scope._clear_caches() - start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b) + start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b) _ = 0 <= slice_size <= b _ = start >= 0 _ = start + slice_size <= b diff --git a/jax/_src/core.py b/jax/_src/core.py index cbf3282fb..faf33f00b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2047,6 +2047,70 @@ def dimension_as_value(d: DimSize): if hasattr(d, "dimension_as_value"): return d.dimension_as_value() return operator.index(d) +def canonicalize_slice( + s: slice, + axis_size: DimSize + ) -> tuple[DimSize, DimSize, DimSize]: + """Computes the start index, step, and size of the slice `x[s]`. + + This is similar to `s.indices(axis_size)`, except that it returns + `(start, step, size)`, and it works when the slice and/or the + `axis_size` are symbolic. + + See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding + """ + def convert_to_index(d: DimSize) -> DimSize: + # Convert np.array and jax.Array to int, leave symbolic dimensions alone + try: + return operator.index(d) + except: + return d + + # Must resolve statically if step is {<0, ==0, >0} + step = convert_to_index(s.step) if s.step is not None else 1 + try: + if step == 0: + raise ValueError("slice step cannot be zero") + step_gt_0 = (step > 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the step ({step}) must " + + f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") + + def clamp_index(i: DimSize, which: str): + try: + i_ge_0 = (i >= 0) + except InconclusiveDimensionOperation as e: + raise InconclusiveDimensionOperation( + f"In slice with non-constant elements the {which} ({i}) must " + + f"be resolved statically if it is >= 0.\nDetails: {e}") + if i_ge_0: + if step_gt_0: + return min_dim(axis_size, i) + else: + return min_dim(axis_size - 1, i) + else: + if step_gt_0: + return max_dim(0, axis_size + i) + else: + return max_dim(-1, axis_size + i) + + if s.start is None: + start = 0 if step_gt_0 else axis_size - 1 + else: + start = clamp_index(convert_to_index(s.start), "start") + + if s.stop is None: + stop = axis_size if step_gt_0 else -1 + else: + stop = clamp_index(convert_to_index(s.stop), "stop") + + gap = step if step_gt_0 else - step + distance = (stop - start) if step_gt_0 else (start - stop) + slice_size = max_dim(0, distance + gap - 1) // gap + return start, step, slice_size + + class SomeTracer: __slots__ = () def __repr__(self): return "[dynamic]" diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 898e4255d..5f380fad9 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], "arrays within JIT compiled functions).") raise IndexError(msg) - start, step, slice_size = _preprocess_slice(i, x_shape[x_axis]) + start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) slice_shape.append(slice_size) if core.definitely_equal(step, 1): @@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx): idx = tuple(idx) + colons return idx -def _preprocess_slice( - s: slice, - axis_size: core.DimSize - ) -> tuple[core.DimSize, core.DimSize, core.DimSize]: - """Computes the start index, step, and size of the slice `x[s]`.""" - # See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - # "this is harder to get right than you may think" - # (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275) - def convert_to_index(d: DimSize) -> DimSize: - # Convert np.array and jax.Array to int, leave symbolic dimensions alone - try: - return operator.index(d) - except: - return d - - # Must resolve statically if step is {<0, ==0, >0} - step = convert_to_index(s.step) if s.step is not None else 1 - try: - if step == 0: - raise ValueError("slice step cannot be zero") - step_gt_0 = (step > 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the step ({step}) must " + - f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") - - def clamp_index(i: DimSize, which: str): - try: - i_ge_0 = (i >= 0) - except core.InconclusiveDimensionOperation as e: - raise core.InconclusiveDimensionOperation( - f"In slice with non-constant elements the {which} ({i}) must " + - f"be resolved statically if it is >= 0.\nDetails: {e}") - if i_ge_0: - if step_gt_0: - return core.min_dim(axis_size, i) - else: - return core.min_dim(axis_size - 1, i) - else: - if step_gt_0: - return core.max_dim(0, axis_size + i) - else: - return core.max_dim(-1, axis_size + i) - - if s.start is None: - start = 0 if step_gt_0 else axis_size - 1 - else: - start = clamp_index(convert_to_index(s.start), "start") - - if s.stop is None: - stop = axis_size if step_gt_0 else -1 - else: - stop = clamp_index(convert_to_index(s.stop), "stop") - - gap = step if step_gt_0 else - step - distance = (stop - start) if step_gt_0 else (start - stop) - slice_size = core.max_dim(0, distance + gap - 1) // gap - return start, step, slice_size - @export def blackman(M: int) -> Array: diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 538f3f8e4..2da93e3d8 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -46,11 +46,11 @@ class Slice: @property def is_dynamic_start(self): - return not isinstance(self.start, int) + return not core.is_dim(self.start) @property def is_dynamic_size(self): - return not isinstance(self.size, int) + return not core.is_dim(self.size) def tree_flatten(self): # If `start` is statically known, we treat it as static information @@ -72,10 +72,10 @@ class Slice: @classmethod def from_slice(cls, slc: slice, size: int) -> Slice: - start, stop, step = slc.indices(size) + start, step, size = core.canonicalize_slice(slc, size) if step < 1: raise ValueError(f"slice must have a step >= 1 (found: {step})") - return cls(start, max((stop - start + step - 1) // step, 0), step) + return cls(start, size, step) def dslice( diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index eda4c4309..668907ffe 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -48,6 +48,9 @@ from jax._src.export import shape_poly from jax._src.export import shape_poly_decision from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow +from jax._src.state import discharge +from jax._src.state import primitives as ref_primitives + import numpy as np config.parse_flags_with_absl() @@ -2062,6 +2065,31 @@ class ShapePolyTest(jtu.JaxTestCase): polymorphic_shapes=["b, ...", "c, ...", None]) + @jtu.parameterized_filterable( + kwargs=[ + dict(slc=slc) + for slc in [ + slice(None, None, None), + slice(2, 5), + ] + ]) + def test_stateful(self, slc: slice): + w, = export.symbolic_shape("w", constraints=["w >= 3"]) + def f(x_ref): + ones = jnp.ones_like(x_ref)[slc] + ref_primitives.ref_addupdate(x_ref, slc, ones) + x1 = ref_primitives.ref_get(x_ref, slc) + x2 = x1 + ones + ref_primitives.ref_set(x_ref, slc, x2) + + exp = export.export(jax.jit(discharge.run_state(f)))( + jax.ShapeDtypeStruct((w,), dtype=_f32)) + x = np.ones((32,), dtype=_f32) + expected = np.copy(x) + expected[slc] = 3. + self.assertAllClose(exp.call(x), expected) + + # List containing either harnesses, or lists of harnesses _POLY_SHAPE_TEST_HARNESSES = [ PolyHarness("add", "", @@ -3603,7 +3631,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase): not harness.polymorphic_shapes[0].endswith("...") and jtu.test_device_matches(["tpu"])): raise unittest.SkipTest( - "Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.") + "Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.") config_flags = harness.override_jax_config_flags # Update this here rather than in harness object because vmap_random_gamma is derived diff --git a/tests/state_test.py b/tests/state_test.py index c84587426..44caded0c 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -752,7 +752,7 @@ class StateDischargeTest(jtu.JaxTestCase): lu.wrap_init(f), [scalar_ref_1, scalar_ref_2]) discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) - prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns) + prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns) self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr)) self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))