mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
[shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to express `max(d, 0)`. This is needed in several places internally to express index computations involving clamping (for numpy indexing), or striding and dilation (which have conditional semantics). It seemed that this special case was sufficient, and we expressed `max(a, b)` as `a + non_negative(b - a)` and `min(a, b)` as `a - non_negative(a - b)`. One drawback was that `non_negative` can be a surprising construct when it appears in error messages. Also, users need `max` and `min` computations with dimensions. It is clearer if we use `max` and `min` directly instead of rewriting these to use `non_negative`. The drawback is that we now have to duplicate some internal logic for `max` and `min`, but overall I feel this is worth it for the better error messages we get.
This commit is contained in:
parent
f33f0e4337
commit
6b7b3a3902
@ -20,18 +20,22 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
devices.
|
||||
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
|
||||
and `descending` arguments.
|
||||
* Several changes to the handling of shape polymorphism (for
|
||||
* Several changes to the handling of shape polymorphism (used in
|
||||
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`): cleaner
|
||||
pretty-printing of symbolic expressions ({jax-issue}`#19227`); simplified
|
||||
and faster equality comparisons, where we consider two symbolic dimensions
|
||||
to be equal if the normalized form of their difference reduces to 0
|
||||
({jax-issue}`#19231`; note that this may result in user-visible behavior
|
||||
changes); improved the error messages for inconclusive inequality comparisons
|
||||
({jax-issue}`#19235`).
|
||||
({jax-issue}`#19235`); the `core.non_negative_dim` API (introduced recently)
|
||||
was deprecated and `core.max_dim` and `core.min_dim` were introduced
|
||||
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
|
||||
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
|
||||
* Refactored the API for `jax.experimental.export`. Instead of
|
||||
`from jax.experimental.export import export` you should now use
|
||||
`from jax.experimental import export`. The old way of importing will
|
||||
continue to work for a deprecation period of 3 months.
|
||||
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
|
||||
|
@ -2051,7 +2051,7 @@ def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
|
||||
"""
|
||||
if definitely_equal(dilation, 1): # fast path
|
||||
return d
|
||||
return non_negative_dim(1 + dilation * (d - 1))
|
||||
return max_dim(1 + dilation * (d - 1), 0)
|
||||
|
||||
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
|
||||
"""max(0, (d - window_size) // window_stride + 1)
|
||||
@ -2060,26 +2060,40 @@ def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimS
|
||||
We assume window_size >= 1 and window_stride >= 1.
|
||||
"""
|
||||
# If d < window_size then (d - window_size) // window_stride < 0
|
||||
return non_negative_dim((d - window_size) // window_stride + 1)
|
||||
return max_dim((d - window_size) // window_stride + 1, 0)
|
||||
|
||||
# TODO(necula): Deprecated Jan 2024, to be removed.
|
||||
def non_negative_dim(d: DimSize) -> DimSize:
|
||||
"""max(d, 0)."""
|
||||
if is_constant_dim(d):
|
||||
return max(0, d)
|
||||
assert is_symbolic_dim(d)
|
||||
try:
|
||||
d_ge_0 = (d >= 0)
|
||||
return d if d_ge_0 else 0
|
||||
except InconclusiveDimensionOperation:
|
||||
return d.non_negative() # type: ignore
|
||||
return max_dim(d, 0)
|
||||
|
||||
def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
"""Like min(d1, d2) but for both constant and symbolic dimensions."""
|
||||
return d1 - non_negative_dim(d1 - d2)
|
||||
d1_is_constant = is_constant_dim(d1)
|
||||
if d1_is_constant and is_constant_dim(d2):
|
||||
return min(d1, d2)
|
||||
try:
|
||||
d2_ge_d1 = (d2 >= d1)
|
||||
return d1 if d2_ge_d1 else d2
|
||||
except InconclusiveDimensionOperation:
|
||||
if d1_is_constant:
|
||||
return d2.rmin(d1) # type: ignore[union-attr]
|
||||
else:
|
||||
return d1.min(d2) # type: ignore[union-attr]
|
||||
|
||||
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
"""Like max(d1, d2) but for both constant and symbolic dimensions."""
|
||||
return d1 + non_negative_dim(d2 - d1)
|
||||
d1_is_constant = is_constant_dim(d1)
|
||||
if d1_is_constant and is_constant_dim(d2):
|
||||
return max(d1, d2)
|
||||
try:
|
||||
d1_ge_d2 = (d1 >= d2)
|
||||
return d1 if d1_ge_d2 else d2
|
||||
except InconclusiveDimensionOperation:
|
||||
if d1_is_constant:
|
||||
return d2.rmax(d1) # type: ignore[union-attr]
|
||||
else:
|
||||
return d1.max(d2) # type: ignore[union-attr]
|
||||
|
||||
def dimension_as_value(d: DimSize):
|
||||
"""Turns a dimension size into a JAX array.
|
||||
|
@ -116,7 +116,7 @@ from jax._src.core import (
|
||||
new_sublevel as new_sublevel,
|
||||
no_axis_name as no_axis_name,
|
||||
no_effects as no_effects,
|
||||
non_negative_dim as non_negative_dim,
|
||||
non_negative_dim as _deprecated_non_negative_dim,
|
||||
outfeed_primitives as outfeed_primitives,
|
||||
pp_aval as pp_aval,
|
||||
pp_eqn as pp_eqn,
|
||||
@ -265,6 +265,10 @@ _deprecations = {
|
||||
"symbolic_equal_dim": (
|
||||
"jax.core.symbolic_equal_dim is deprecated. Use ==.", _deprecated_definitely_equal,
|
||||
),
|
||||
# Added Jan 8, 2024
|
||||
"non_negative_dim": (
|
||||
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
@ -279,6 +283,7 @@ if typing.TYPE_CHECKING:
|
||||
collections = _src_core.collections
|
||||
dimension_as_value = _deprecated_dimension_as_value
|
||||
definitely_equal = _deprecated_definitely_equal
|
||||
non_negative_dim = _deprecated_non_negative_dim
|
||||
dtypes = _src_core.dtypes
|
||||
lu = _src_core.lu
|
||||
map = _src_core.map
|
||||
|
@ -117,7 +117,10 @@ class _DimAtom:
|
||||
#
|
||||
FLOORDIV = "floordiv"
|
||||
MOD = "mod"
|
||||
NON_NEGATIVE = "non_negative" # The max of the operand and 0
|
||||
MAX = "max"
|
||||
MIN = "min"
|
||||
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
|
||||
# max but kept here for backwards compatibility.
|
||||
|
||||
def __init__(self, *operands: _DimExpr,
|
||||
var: str | None = None,
|
||||
@ -231,9 +234,13 @@ class _DimAtom:
|
||||
else:
|
||||
return (-np.inf, np.inf)
|
||||
|
||||
elif self.operation == _DimAtom.NON_NEGATIVE:
|
||||
(b_l, b_h), = opnd_bounds
|
||||
return (max(0, b_l), max(0, b_h))
|
||||
elif self.operation == _DimAtom.MAX:
|
||||
(a_l, a_h), (b_l, b_h) = opnd_bounds
|
||||
return (max(a_l, b_l), max(a_h, b_h))
|
||||
|
||||
elif self.operation == _DimAtom.MIN:
|
||||
(a_l, a_h), (b_l, b_h) = opnd_bounds
|
||||
return (min(a_l, b_l), min(a_h, b_h))
|
||||
|
||||
else:
|
||||
assert False
|
||||
@ -253,15 +260,24 @@ class _DimAtom:
|
||||
return divmod(*operand_values)[0] # type: ignore
|
||||
elif self.operation == _DimAtom.MOD:
|
||||
return divmod(*operand_values)[1] # type: ignore
|
||||
elif self.operation == _DimAtom.NON_NEGATIVE:
|
||||
operand = operand_values[0]
|
||||
if core.is_constant_dim(operand):
|
||||
return max(operand, 0)
|
||||
if core.is_symbolic_dim(operand):
|
||||
return core.non_negative_dim(operand)
|
||||
elif self.operation == _DimAtom.MAX:
|
||||
op1, op2 = operand_values
|
||||
if core.is_constant_dim(op1) and core.is_constant_dim(op2):
|
||||
return max(op1, op2)
|
||||
if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
|
||||
return core.max_dim(op1, op2)
|
||||
# In the context of `evaluate` dimension variables may be mapped to
|
||||
# JAX Tracers.
|
||||
return lax.max(operand, 0)
|
||||
return lax.max(op1, op2)
|
||||
elif self.operation == _DimAtom.MIN:
|
||||
op1, op2 = operand_values
|
||||
if core.is_constant_dim(op1) and core.is_constant_dim(op2):
|
||||
return min(op1, op2)
|
||||
if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
|
||||
return core.min_dim(op1, op2)
|
||||
# In the context of `evaluate` dimension variables may be mapped to
|
||||
# JAX Tracers.
|
||||
return lax.min(op1, op2)
|
||||
else:
|
||||
assert False, self.operation
|
||||
|
||||
@ -489,7 +505,12 @@ class _DimExpr():
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1)
|
||||
if operation == _DimAtom.NON_NEGATIVE: # For parsing
|
||||
return _DimExpr.from_monomial(_DimMon.from_operation(_DimAtom.MAX,
|
||||
*operands,
|
||||
0), 1)
|
||||
return _DimExpr.from_monomial(_DimMon.from_operation(operation,
|
||||
*operands), 1)
|
||||
|
||||
def to_monomial(self) -> _DimMon | None:
|
||||
"""Extract the single monomial from a symbolic expression.
|
||||
@ -566,26 +587,42 @@ class _DimExpr():
|
||||
return True
|
||||
if ub < 0:
|
||||
return False
|
||||
# Attempt to handle non_negative. For the decomposition:
|
||||
# e = factor * non_negative(operand)^exp * rest_monomial + rest_expr
|
||||
# use the rule:
|
||||
# e >= 0 IF factor * operand^exp * rest_monomial + rest_expr >= 0 AND
|
||||
# (rest_expr >= 0 OR
|
||||
# exp is odd AND factor * rest_monomial >= 0 OR
|
||||
# exp is even AND factor * rest_monomial <= 0)
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.NON_NEGATIVE):
|
||||
# e = factor * non_negative(operands)^exp * rest_monomial + rest_expr
|
||||
e2 = dec.rest_expr + dec.factor * (
|
||||
dec.operands[0] ** dec.exp) * dec.rest_monomial
|
||||
if not definitely_geq_0(e2):
|
||||
continue
|
||||
if definitely_geq_0(dec.rest_expr):
|
||||
return True
|
||||
if dec.exp % 2 == 1:
|
||||
if definitely_geq_0(dec.factor * dec.rest_monomial):
|
||||
# Attempt to handle max. For the decomposition
|
||||
# e = factor * max(op1, op2) + rest_expr
|
||||
# use the rule
|
||||
# e >= 0 IF (factor > 0 AND
|
||||
# (factor * op1 + rest_expr >= 0 OR
|
||||
# factor * op2 + rest_expr >= 0))
|
||||
# OR
|
||||
# (factor < 0 AND
|
||||
# (factor * op1 + rest_expr >= 0 AND
|
||||
# factor * op2 + rest_expr >= 0))
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.MAX,
|
||||
with_exp=1, with_rest_monomial=1):
|
||||
op1, op2 = dec.operands
|
||||
if dec.factor > 0:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
else:
|
||||
if definitely_geq_0(- dec.factor * dec.rest_monomial):
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
|
||||
# Attempt to handle min. For the decomposition
|
||||
# e = factor * min(op1, op2) + rest_expr
|
||||
# use the same rule as for
|
||||
# e = max(factor * op1, factor * op2) + rest_expr
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.MIN,
|
||||
with_exp=1, with_rest_monomial=1):
|
||||
op1, op2 = dec.operands
|
||||
if dec.factor > 0:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
else:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
|
||||
# Attempt to handle floordiv >= 0
|
||||
@ -837,8 +874,17 @@ class _DimExpr():
|
||||
for mon, coeff in self.monomials()]
|
||||
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
def non_negative(self) -> _DimExpr:
|
||||
return _DimExpr.from_operation(_DimAtom.NON_NEGATIVE, self)
|
||||
def max(self, other: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_operation(_DimAtom.MAX, self, other)
|
||||
|
||||
def rmax(self, other: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_operation(_DimAtom.MAX, other, self)
|
||||
|
||||
def min(self, other: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_operation(_DimAtom.MIN, self, other)
|
||||
|
||||
def rmin(self, other: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_operation(_DimAtom.MIN, other, self)
|
||||
|
||||
@staticmethod
|
||||
def get_aval(dim: _DimExpr):
|
||||
@ -1285,11 +1331,9 @@ class _Parser:
|
||||
|
||||
def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
if tok.exact_type == tokenize.NAME:
|
||||
if tok.string == _DimAtom.MOD:
|
||||
return self.atom_binary_op(_DimAtom.MOD, self.next_tok())
|
||||
if tok.string == _DimAtom.FLOORDIV:
|
||||
return self.atom_binary_op(_DimAtom.FLOORDIV, self.next_tok())
|
||||
if tok.string == _DimAtom.NON_NEGATIVE:
|
||||
if tok.string in (_DimAtom.MOD, _DimAtom.FLOORDIV, _DimAtom.MAX, _DimAtom.MIN):
|
||||
return self.atom_binary_op(tok.string, self.next_tok())
|
||||
if tok.string == _DimAtom.NON_NEGATIVE: # We still parse this for backwards compatibility
|
||||
return self.atom_unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
|
||||
return _DimExpr.from_var(tok.string), self.next_tok()
|
||||
number_sign = 1
|
||||
|
@ -775,7 +775,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def output_shape(b):
|
||||
return (b + b, b - b, b * b,
|
||||
(b + 13) // b, (b + 13) % b,
|
||||
core.non_negative_dim(b - 5))
|
||||
core.max_dim(b - 5, 0))
|
||||
def f(x): # x: f32[b]
|
||||
b = x.shape[0]
|
||||
return jnp.ones(output_shape(b), dtype=x.dtype)
|
||||
|
@ -140,11 +140,20 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
("a + -1", a - 1),
|
||||
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
|
||||
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
|
||||
("non_negative(a - 2)", "build_inside"),
|
||||
# Keep for backwards compatibility. We ought to be able to parse
|
||||
# non_negative
|
||||
("non_negative(a - 2)", core.max_dim(a - 2, 0)),
|
||||
("max(a, b)", "build_inside"),
|
||||
("min(a, b)", "build_inside"),
|
||||
]])
|
||||
def test_parse_dim(self, dim_spec, dim_poly):
|
||||
if dim_spec == "non_negative(a - 2)":
|
||||
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
|
||||
elif dim_spec == "max(a, b)":
|
||||
dim_poly = core.max_dim(DimExprTest.a, DimExprTest.b)
|
||||
elif dim_spec == "min(a, b)":
|
||||
dim_poly = core.min_dim(DimExprTest.a, DimExprTest.b)
|
||||
|
||||
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(dim_spec))
|
||||
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(str(dim_poly)))
|
||||
|
||||
@ -397,6 +406,13 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
def test_min_dim(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
|
||||
self.assertEqual(core.min_dim(a, b).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.min_dim(2, b).bounds(), (1, 2))
|
||||
self.assertEqual(core.min_dim(a, -2), -2)
|
||||
self.assertEqual(core.min_dim(a - 5, 1).bounds(), (-4, 1))
|
||||
self.assertEqual(core.min_dim(15 - a, 10).bounds(), (-np.inf, 10))
|
||||
self.assertEqual(core.min_dim(15 - a, 20).bounds(), (-np.inf, 14))
|
||||
|
||||
self.assertEqual(a, core.min_dim(a, a + 2))
|
||||
self.assertEqual(a - 2, core.min_dim(a, a - 2))
|
||||
self.assertGreaterEqual(a, core.min_dim(a, b))
|
||||
@ -404,9 +420,20 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertGreaterEqual(b, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
|
||||
|
||||
self.sampled_assertion(core.min_dim(a, 5),
|
||||
core.min_dim, a, 5)
|
||||
self.sampled_assertion(core.min_dim(5, a),
|
||||
core.min_dim, 5, a)
|
||||
def test_max_dim(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
|
||||
self.assertEqual(core.max_dim(a, b).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.max_dim(2, b).bounds(), (2, np.inf))
|
||||
self.assertEqual(core.max_dim(a, 2).bounds(), (2, np.inf))
|
||||
self.assertEqual(core.max_dim(a - 5, 1).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.max_dim(15 - a, 0).bounds(), (0, 14))
|
||||
self.assertEqual((core.max_dim(15 - a, 0) // 3).bounds(), (0, 4))
|
||||
|
||||
self.assertEqual(a + 2, core.max_dim(a, a + 2))
|
||||
self.assertEqual(a , core.max_dim(a, a - 2))
|
||||
self.assertGreaterEqual(core.max_dim(a, b), a)
|
||||
@ -415,6 +442,10 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertGreaterEqual(core.max_dim(a, b) + c - 1, b)
|
||||
|
||||
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
|
||||
self.sampled_assertion(core.max_dim(a, 5),
|
||||
core.max_dim, a, 5)
|
||||
self.sampled_assertion(core.max_dim(5, a),
|
||||
core.max_dim, 5, a)
|
||||
|
||||
def test_clamp_dim(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
@ -583,7 +614,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.sampled_assertion(0, core.dilate_dim, 0, 3)
|
||||
self.sampled_assertion(a, core.dilate_dim, a, 1)
|
||||
self.sampled_assertion(2 * a - 1, core.dilate_dim, a, 2)
|
||||
self.sampled_assertion(core.non_negative_dim(2 * a - 3),
|
||||
self.sampled_assertion(core.max_dim(2 * a - 3, 0),
|
||||
core.dilate_dim, a - 1, 2)
|
||||
|
||||
def test_stride_dim(self):
|
||||
@ -600,7 +631,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.sampled_assertion(a - 1, core.stride_dim, a, 2, 1)
|
||||
self.sampled_assertion(a + 1, core.stride_dim, a * stride + 2, 2, stride)
|
||||
self.sampled_assertion((a - 1) // 2 + 1, core.stride_dim, a, 1, 2)
|
||||
self.sampled_assertion(core.non_negative_dim((a - 4) // 2 + 1),
|
||||
self.sampled_assertion(core.max_dim((a - 4) // 2 + 1, 0),
|
||||
core.stride_dim, a, 4, 2)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user