mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[shape_poly] Adding shape polymorphism support for the state primitives.
This commit is contained in:
parent
8d84f28373
commit
0831e2e340
@ -17,7 +17,6 @@ import google_benchmark as benchmark
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax import export
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
@ -76,7 +75,7 @@ def inequalities_slice(state):
|
||||
while state:
|
||||
for _ in range(30):
|
||||
a.scope._clear_caches()
|
||||
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
|
||||
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
|
||||
_ = 0 <= slice_size <= b
|
||||
_ = start >= 0
|
||||
_ = start + slice_size <= b
|
||||
|
@ -2047,6 +2047,70 @@ def dimension_as_value(d: DimSize):
|
||||
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
|
||||
return operator.index(d)
|
||||
|
||||
def canonicalize_slice(
    s: slice,
    axis_size: DimSize
  ) -> tuple[DimSize, DimSize, DimSize]:
  """Computes the start index, step, and size of the slice `x[s]`.

  This is similar to `s.indices(axis_size)`, except that it returns
  `(start, step, size)`, and it works when the slice and/or the
  `axis_size` are symbolic.

  See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding

  Args:
    s: the slice to canonicalize; `start`, `stop` and `step` may each be
      `None`, an integer-like value, or a symbolic dimension.
    axis_size: the size of the axis being sliced.

  Returns:
    A tuple `(start, step, slice_size)`.

  Raises:
    ValueError: if the step is 0.
    InconclusiveDimensionOperation: if comparing the step, start, or stop
      against 0 raises `InconclusiveDimensionOperation`, i.e. their sign
      cannot be resolved statically.
  """
  def convert_to_index(d: DimSize) -> DimSize:
    # Convert np.array and jax.Array to int, leave symbolic dimensions alone.
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt/SystemExit raised during conversion still propagate.
    try:
      return operator.index(d)
    except Exception:
      return d

  # Must resolve statically if step is {<0, ==0, >0}
  step = convert_to_index(s.step) if s.step is not None else 1
  try:
    if step == 0:
      raise ValueError("slice step cannot be zero")
    step_gt_0 = (step > 0)
  except InconclusiveDimensionOperation as e:
    raise InconclusiveDimensionOperation(
        f"In slice with non-constant elements the step ({step}) must " +
        f"be resolved statically if it is > 0 or < 0.\nDetails: {e}") from e

  def clamp_index(i: DimSize, which: str):
    # `which` is only used to name the offending bound in the error message.
    try:
      i_ge_0 = (i >= 0)
    except InconclusiveDimensionOperation as e:
      raise InconclusiveDimensionOperation(
          f"In slice with non-constant elements the {which} ({i}) must " +
          f"be resolved statically if it is >= 0.\nDetails: {e}") from e
    if i_ge_0:
      # Non-negative bound: cap it at the last valid position for the
      # direction of travel.
      if step_gt_0:
        return min_dim(axis_size, i)
      else:
        return min_dim(axis_size - 1, i)
    else:
      # Negative bound: offset by the axis size, then clamp from below.
      if step_gt_0:
        return max_dim(0, axis_size + i)
      else:
        return max_dim(-1, axis_size + i)

  if s.start is None:
    start = 0 if step_gt_0 else axis_size - 1
  else:
    start = clamp_index(convert_to_index(s.start), "start")

  if s.stop is None:
    stop = axis_size if step_gt_0 else -1
  else:
    stop = clamp_index(convert_to_index(s.stop), "stop")

  # Ceiling division of the travelled distance by |step|, floored at 0.
  gap = step if step_gt_0 else - step
  distance = (stop - start) if step_gt_0 else (start - stop)
  slice_size = max_dim(0, distance + gap - 1) // gap
  return start, step, slice_size
|
||||
|
||||
|
||||
class SomeTracer:
  """Stateless placeholder whose `repr` is the fixed text "[dynamic]"."""

  # No per-instance attributes are stored.
  __slots__ = ()

  def __repr__(self):
    return "[dynamic]"
|
||||
|
@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
|
||||
"arrays within JIT compiled functions).")
|
||||
raise IndexError(msg)
|
||||
|
||||
start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
|
||||
start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
|
||||
slice_shape.append(slice_size)
|
||||
|
||||
if core.definitely_equal(step, 1):
|
||||
@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx):
|
||||
idx = tuple(idx) + colons
|
||||
return idx
|
||||
|
||||
def _preprocess_slice(
|
||||
s: slice,
|
||||
axis_size: core.DimSize
|
||||
) -> tuple[core.DimSize, core.DimSize, core.DimSize]:
|
||||
"""Computes the start index, step, and size of the slice `x[s]`."""
|
||||
# See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
|
||||
# "this is harder to get right than you may think"
|
||||
# (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
|
||||
def convert_to_index(d: DimSize) -> DimSize:
|
||||
# Convert np.array and jax.Array to int, leave symbolic dimensions alone
|
||||
try:
|
||||
return operator.index(d)
|
||||
except:
|
||||
return d
|
||||
|
||||
# Must resolve statically if step is {<0, ==0, >0}
|
||||
step = convert_to_index(s.step) if s.step is not None else 1
|
||||
try:
|
||||
if step == 0:
|
||||
raise ValueError("slice step cannot be zero")
|
||||
step_gt_0 = (step > 0)
|
||||
except core.InconclusiveDimensionOperation as e:
|
||||
raise core.InconclusiveDimensionOperation(
|
||||
f"In slice with non-constant elements the step ({step}) must " +
|
||||
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
|
||||
|
||||
def clamp_index(i: DimSize, which: str):
|
||||
try:
|
||||
i_ge_0 = (i >= 0)
|
||||
except core.InconclusiveDimensionOperation as e:
|
||||
raise core.InconclusiveDimensionOperation(
|
||||
f"In slice with non-constant elements the {which} ({i}) must " +
|
||||
f"be resolved statically if it is >= 0.\nDetails: {e}")
|
||||
if i_ge_0:
|
||||
if step_gt_0:
|
||||
return core.min_dim(axis_size, i)
|
||||
else:
|
||||
return core.min_dim(axis_size - 1, i)
|
||||
else:
|
||||
if step_gt_0:
|
||||
return core.max_dim(0, axis_size + i)
|
||||
else:
|
||||
return core.max_dim(-1, axis_size + i)
|
||||
|
||||
if s.start is None:
|
||||
start = 0 if step_gt_0 else axis_size - 1
|
||||
else:
|
||||
start = clamp_index(convert_to_index(s.start), "start")
|
||||
|
||||
if s.stop is None:
|
||||
stop = axis_size if step_gt_0 else -1
|
||||
else:
|
||||
stop = clamp_index(convert_to_index(s.stop), "stop")
|
||||
|
||||
gap = step if step_gt_0 else - step
|
||||
distance = (stop - start) if step_gt_0 else (start - stop)
|
||||
slice_size = core.max_dim(0, distance + gap - 1) // gap
|
||||
return start, step, slice_size
|
||||
|
||||
|
||||
@export
|
||||
def blackman(M: int) -> Array:
|
||||
|
@ -46,11 +46,11 @@ class Slice:
|
||||
|
||||
@property
|
||||
def is_dynamic_start(self):
|
||||
return not isinstance(self.start, int)
|
||||
return not core.is_dim(self.start)
|
||||
|
||||
@property
|
||||
def is_dynamic_size(self):
|
||||
return not isinstance(self.size, int)
|
||||
return not core.is_dim(self.size)
|
||||
|
||||
def tree_flatten(self):
|
||||
# If `start` is statically known, we treat it as static information
|
||||
@ -72,10 +72,10 @@ class Slice:
|
||||
|
||||
@classmethod
|
||||
def from_slice(cls, slc: slice, size: int) -> Slice:
|
||||
start, stop, step = slc.indices(size)
|
||||
start, step, size = core.canonicalize_slice(slc, size)
|
||||
if step < 1:
|
||||
raise ValueError(f"slice must have a step >= 1 (found: {step})")
|
||||
return cls(start, max((stop - start + step - 1) // step, 0), step)
|
||||
return cls(start, size, step)
|
||||
|
||||
|
||||
def dslice(
|
||||
|
@ -48,6 +48,9 @@ from jax._src.export import shape_poly
|
||||
from jax._src.export import shape_poly_decision
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.state import discharge
|
||||
from jax._src.state import primitives as ref_primitives
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -2062,6 +2065,31 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
polymorphic_shapes=["b, ...", "c, ...", None])
|
||||
|
||||
|
||||
@jtu.parameterized_filterable(
  kwargs=[
    dict(slc=slice(None, None, None)),
    dict(slc=slice(2, 5)),
  ])
def test_stateful(self, slc: slice):
  """Exports a `run_state` function over a ref whose length is symbolic.

  The body applies `ref_addupdate`, `ref_get` and `ref_set` to the slice
  `slc` of the ref. The exported function is then called on a length-32
  array of ones, and the test expects every element selected by `slc` to
  be 3.0 while the rest stay unchanged.
  """
  (w,) = export.symbolic_shape("w", constraints=["w >= 3"])

  def f(x_ref):
    delta = jnp.ones_like(x_ref)[slc]
    ref_primitives.ref_addupdate(x_ref, slc, delta)
    current = ref_primitives.ref_get(x_ref, slc)
    bumped = current + delta
    ref_primitives.ref_set(x_ref, slc, bumped)

  stateful_fn = jax.jit(discharge.run_state(f))
  arg_spec = jax.ShapeDtypeStruct((w,), dtype=_f32)
  exported = export.export(stateful_fn)(arg_spec)

  operand = np.ones((32,), dtype=_f32)
  expected = np.copy(operand)
  expected[slc] = 3.
  self.assertAllClose(exported.call(operand), expected)
|
||||
|
||||
|
||||
# List containing either harnesses, or lists of harnesses
|
||||
_POLY_SHAPE_TEST_HARNESSES = [
|
||||
PolyHarness("add", "",
|
||||
@ -3603,7 +3631,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
not harness.polymorphic_shapes[0].endswith("...") and
|
||||
jtu.test_device_matches(["tpu"])):
|
||||
raise unittest.SkipTest(
|
||||
"Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.")
|
||||
"Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.")
|
||||
|
||||
config_flags = harness.override_jax_config_flags
|
||||
# Update this here rather than in harness object because vmap_random_gamma is derived
|
||||
|
@ -752,7 +752,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
|
||||
|
||||
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
|
||||
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)
|
||||
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
|
||||
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
|
||||
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user