1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

[shape_poly] Replace non_negative_dim with max_dim and min_dim.

Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
This commit is contained in:
George Necula 2023-12-13 10:14:27 +01:00
parent f33f0e4337
commit 6b7b3a3902
6 changed files with 154 additions and 56 deletions

@ -20,18 +20,22 @@ Remember to align the itemized text with the first line of an item within a list
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Several changes to the handling of shape polymorphism (for
* Several changes to the handling of shape polymorphism (used in
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`): cleaner
pretty-printing of symbolic expressions ({jax-issue}`#19227`); simplified
and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
changes); improved the error messages for inconclusive inequality comparisons
({jax-issue}`#19235`).
({jax-issue}`#19235`); the `core.non_negative_dim` API (introduced recently)
was deprecated and `core.max_dim` and `core.min_dim` were introduced
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).

@ -2051,7 +2051,7 @@ def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
"""
if definitely_equal(dilation, 1): # fast path
return d
return non_negative_dim(1 + dilation * (d - 1))
return max_dim(1 + dilation * (d - 1), 0)
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
"""max(0, (d - window_size) // window_stride + 1)
@ -2060,26 +2060,40 @@ def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimS
We assume window_size >= 1 and window_stride >= 1.
"""
# If d < window_size then (d - window_size) // window_stride < 0
return non_negative_dim((d - window_size) // window_stride + 1)
return max_dim((d - window_size) // window_stride + 1, 0)
# TODO(necula): Deprecated Jan 2024, to be removed.
def non_negative_dim(d: DimSize) -> DimSize:
"""max(d, 0)."""
if is_constant_dim(d):
return max(0, d)
assert is_symbolic_dim(d)
try:
d_ge_0 = (d >= 0)
return d if d_ge_0 else 0
except InconclusiveDimensionOperation:
return d.non_negative() # type: ignore
return max_dim(d, 0)
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like min(d1, d2) but for both constant and symbolic dimensions."""
return d1 - non_negative_dim(d1 - d2)
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return min(d1, d2)
try:
d2_ge_d1 = (d2 >= d1)
return d1 if d2_ge_d1 else d2
except InconclusiveDimensionOperation:
if d1_is_constant:
return d2.rmin(d1) # type: ignore[union-attr]
else:
return d1.min(d2) # type: ignore[union-attr]
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like max(d1, d2) but for both constant and symbolic dimensions."""
return d1 + non_negative_dim(d2 - d1)
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return max(d1, d2)
try:
d1_ge_d2 = (d1 >= d2)
return d1 if d1_ge_d2 else d2
except InconclusiveDimensionOperation:
if d1_is_constant:
return d2.rmax(d1) # type: ignore[union-attr]
else:
return d1.max(d2) # type: ignore[union-attr]
def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX array.

@ -116,7 +116,7 @@ from jax._src.core import (
new_sublevel as new_sublevel,
no_axis_name as no_axis_name,
no_effects as no_effects,
non_negative_dim as non_negative_dim,
non_negative_dim as _deprecated_non_negative_dim,
outfeed_primitives as outfeed_primitives,
pp_aval as pp_aval,
pp_eqn as pp_eqn,
@ -265,6 +265,10 @@ _deprecations = {
"symbolic_equal_dim": (
"jax.core.symbolic_equal_dim is deprecated. Use ==.", _deprecated_definitely_equal,
),
# Added Jan 8, 2024
"non_negative_dim": (
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim,
),
}
import typing
@ -279,6 +283,7 @@ if typing.TYPE_CHECKING:
collections = _src_core.collections
dimension_as_value = _deprecated_dimension_as_value
definitely_equal = _deprecated_definitely_equal
non_negative_dim = _deprecated_non_negative_dim
dtypes = _src_core.dtypes
lu = _src_core.lu
map = _src_core.map

@ -117,7 +117,10 @@ class _DimAtom:
#
FLOORDIV = "floordiv"
MOD = "mod"
NON_NEGATIVE = "non_negative" # The max of the operand and 0
MAX = "max"
MIN = "min"
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
# max but kept here for backwards compatibility.
def __init__(self, *operands: _DimExpr,
var: str | None = None,
@ -231,9 +234,13 @@ class _DimAtom:
else:
return (-np.inf, np.inf)
elif self.operation == _DimAtom.NON_NEGATIVE:
(b_l, b_h), = opnd_bounds
return (max(0, b_l), max(0, b_h))
elif self.operation == _DimAtom.MAX:
(a_l, a_h), (b_l, b_h) = opnd_bounds
return (max(a_l, b_l), max(a_h, b_h))
elif self.operation == _DimAtom.MIN:
(a_l, a_h), (b_l, b_h) = opnd_bounds
return (min(a_l, b_l), min(a_h, b_h))
else:
assert False
@ -253,15 +260,24 @@ class _DimAtom:
return divmod(*operand_values)[0] # type: ignore
elif self.operation == _DimAtom.MOD:
return divmod(*operand_values)[1] # type: ignore
elif self.operation == _DimAtom.NON_NEGATIVE:
operand = operand_values[0]
if core.is_constant_dim(operand):
return max(operand, 0)
if core.is_symbolic_dim(operand):
return core.non_negative_dim(operand)
elif self.operation == _DimAtom.MAX:
op1, op2 = operand_values
if core.is_constant_dim(op1) and core.is_constant_dim(op2):
return max(op1, op2)
if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
return core.max_dim(op1, op2)
# In the context of `evaluate` dimension variables may be mapped to
# JAX Tracers.
return lax.max(operand, 0)
return lax.max(op1, op2)
elif self.operation == _DimAtom.MIN:
op1, op2 = operand_values
if core.is_constant_dim(op1) and core.is_constant_dim(op2):
return min(op1, op2)
if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
return core.min_dim(op1, op2)
# In the context of `evaluate` dimension variables may be mapped to
# JAX Tracers.
return lax.min(op1, op2)
else:
assert False, self.operation
@ -489,7 +505,12 @@ class _DimExpr():
@classmethod
def from_operation(cls, operation: str, *operands: DimSize) -> _DimExpr:
return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1)
if operation == _DimAtom.NON_NEGATIVE: # For parsing
return _DimExpr.from_monomial(_DimMon.from_operation(_DimAtom.MAX,
*operands,
0), 1)
return _DimExpr.from_monomial(_DimMon.from_operation(operation,
*operands), 1)
def to_monomial(self) -> _DimMon | None:
"""Extract the single monomial from a symbolic expression.
@ -566,26 +587,42 @@ class _DimExpr():
return True
if ub < 0:
return False
# Attempt to handle non_negative. For the decomposition:
# e = factor * non_negative(operand)^exp * rest_monomial + rest_expr
# use the rule:
# e >= 0 IF factor * operand^exp * rest_monomial + rest_expr >= 0 AND
# (rest_expr >= 0 OR
# exp is odd AND factor * rest_monomial >= 0 OR
# exp is even AND factor * rest_monomial <= 0)
for dec in _decompose_expr(self_minus_other, _DimAtom.NON_NEGATIVE):
# e = factor * non_negative(operands)^exp * rest_monomial + rest_expr
e2 = dec.rest_expr + dec.factor * (
dec.operands[0] ** dec.exp) * dec.rest_monomial
if not definitely_geq_0(e2):
continue
if definitely_geq_0(dec.rest_expr):
return True
if dec.exp % 2 == 1:
if definitely_geq_0(dec.factor * dec.rest_monomial):
# Attempt to handle max. For the decomposition
# e = factor * max(op1, op2) + rest_expr
# use the rule
# e >= 0 IF (factor > 0 AND
# (factor * op1 + rest_expr >= 0 OR
# factor * op2 + rest_expr >= 0))
# OR
# (factor < 0 AND
# (factor * op1 + rest_expr >= 0 AND
# factor * op2 + rest_expr >= 0))
for dec in _decompose_expr(self_minus_other, _DimAtom.MAX,
with_exp=1, with_rest_monomial=1):
op1, op2 = dec.operands
if dec.factor > 0:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
else:
if definitely_geq_0(- dec.factor * dec.rest_monomial):
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
# Attempt to handle min. For the decomposition
# e = factor * min(op1, op2) + rest_expr
# use the same rule as for
# e = max(factor * op1, factor * op2) + rest_expr
for dec in _decompose_expr(self_minus_other, _DimAtom.MIN,
with_exp=1, with_rest_monomial=1):
op1, op2 = dec.operands
if dec.factor > 0:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
else:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
# Attempt to handle floordiv >= 0
@ -837,8 +874,17 @@ class _DimExpr():
for mon, coeff in self.monomials()]
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
def non_negative(self) -> _DimExpr:
return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self)
def max(self, other: DimSize) -> _DimExpr:
return _DimExpr.from_operation(_DimAtom.MAX, self, other)
def rmax(self, other: DimSize) -> _DimExpr:
return _DimExpr.from_operation(_DimAtom.MAX, other, self)
def min(self, other: DimSize) -> _DimExpr:
return _DimExpr.from_operation(_DimAtom.MIN, self, other)
def rmin(self, other: DimSize) -> _DimExpr:
return _DimExpr.from_operation(_DimAtom.MIN, other, self)
@staticmethod
def get_aval(dim: _DimExpr):
@ -1285,11 +1331,9 @@ class _Parser:
def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
if tok.exact_type == tokenize.NAME:
if tok.string == _DimAtom.MOD:
return self.atom_binary_op(_DimAtom.MOD, self.next_tok())
if tok.string == _DimAtom.FLOORDIV:
return self.atom_binary_op(_DimAtom.FLOORDIV, self.next_tok())
if tok.string == _DimAtom.NON_NEGATIVE:
if tok.string in (_DimAtom.MOD, _DimAtom.FLOORDIV, _DimAtom.MAX, _DimAtom.MIN):
return self.atom_binary_op(tok.string, self.next_tok())
if tok.string == _DimAtom.NON_NEGATIVE: # We still parse this for backwards compatibility
return self.atom_unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
return _DimExpr.from_var(tok.string), self.next_tok()
number_sign = 1

@ -775,7 +775,7 @@ class JaxExportTest(jtu.JaxTestCase):
def output_shape(b):
return (b + b, b - b, b * b,
(b + 13) // b, (b + 13) % b,
core.non_negative_dim(b - 5))
core.max_dim(b - 5, 0))
def f(x): # x: f32[b]
b = x.shape[0]
return jnp.ones(output_shape(b), dtype=x.dtype)

@ -140,11 +140,20 @@ class DimExprTest(jtu.JaxTestCase):
("a + -1", a - 1),
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
("non_negative(a - 2)", "build_inside"),
# Keep for backwards compatibility. We ought to be able to parse
# non_negative
("non_negative(a - 2)", core.max_dim(a - 2, 0)),
("max(a, b)", "build_inside"),
("min(a, b)", "build_inside"),
]])
def test_parse_dim(self, dim_spec, dim_poly):
if dim_spec == "non_negative(a - 2)":
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
elif dim_spec == "max(a, b)":
dim_poly = core.max_dim(DimExprTest.a, DimExprTest.b)
elif dim_spec == "min(a, b)":
dim_poly = core.min_dim(DimExprTest.a, DimExprTest.b)
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(dim_spec))
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(str(dim_poly)))
@ -397,6 +406,13 @@ class DimExprTest(jtu.JaxTestCase):
def test_min_dim(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
self.assertEqual(core.min_dim(a, b).bounds(), (1, np.inf))
self.assertEqual(core.min_dim(2, b).bounds(), (1, 2))
self.assertEqual(core.min_dim(a, -2), -2)
self.assertEqual(core.min_dim(a - 5, 1).bounds(), (-4, 1))
self.assertEqual(core.min_dim(15 - a, 10).bounds(), (-np.inf, 10))
self.assertEqual(core.min_dim(15 - a, 20).bounds(), (-np.inf, 14))
self.assertEqual(a, core.min_dim(a, a + 2))
self.assertEqual(a - 2, core.min_dim(a, a - 2))
self.assertGreaterEqual(a, core.min_dim(a, b))
@ -404,9 +420,20 @@ class DimExprTest(jtu.JaxTestCase):
self.assertGreaterEqual(b, core.min_dim(a, b))
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
self.sampled_assertion(core.min_dim(a, 5),
core.min_dim, a, 5)
self.sampled_assertion(core.min_dim(5, a),
core.min_dim, 5, a)
def test_max_dim(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
self.assertEqual(core.max_dim(a, b).bounds(), (1, np.inf))
self.assertEqual(core.max_dim(2, b).bounds(), (2, np.inf))
self.assertEqual(core.max_dim(a, 2).bounds(), (2, np.inf))
self.assertEqual(core.max_dim(a - 5, 1).bounds(), (1, np.inf))
self.assertEqual(core.max_dim(15 - a, 0).bounds(), (0, 14))
self.assertEqual((core.max_dim(15 - a, 0) // 3).bounds(), (0, 4))
self.assertEqual(a + 2, core.max_dim(a, a + 2))
self.assertEqual(a , core.max_dim(a, a - 2))
self.assertGreaterEqual(core.max_dim(a, b), a)
@ -415,6 +442,10 @@ class DimExprTest(jtu.JaxTestCase):
self.assertGreaterEqual(core.max_dim(a, b) + c - 1, b)
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
self.sampled_assertion(core.max_dim(a, 5),
core.max_dim, a, 5)
self.sampled_assertion(core.max_dim(5, a),
core.max_dim, 5, a)
def test_clamp_dim(self):
a, b = shape_poly.symbolic_shape("a, b")
@ -583,7 +614,7 @@ class DimExprTest(jtu.JaxTestCase):
self.sampled_assertion(0, core.dilate_dim, 0, 3)
self.sampled_assertion(a, core.dilate_dim, a, 1)
self.sampled_assertion(2 * a - 1, core.dilate_dim, a, 2)
self.sampled_assertion(core.non_negative_dim(2 * a - 3),
self.sampled_assertion(core.max_dim(2 * a - 3, 0),
core.dilate_dim, a - 1, 2)
def test_stride_dim(self):
@ -600,7 +631,7 @@ class DimExprTest(jtu.JaxTestCase):
self.sampled_assertion(a - 1, core.stride_dim, a, 2, 1)
self.sampled_assertion(a + 1, core.stride_dim, a * stride + 2, 2, stride)
self.sampled_assertion((a - 1) // 2 + 1, core.stride_dim, a, 1, 2)
self.sampled_assertion(core.non_negative_dim((a - 4) // 2 + 1),
self.sampled_assertion(core.max_dim((a - 4) // 2 + 1, 0),
core.stride_dim, a, 4, 2)