[shape_poly] Deprecate shape_poly.PolySpec.

This class has limited usefulness, and it seems worth
removing it in favor of using strings for polymorphic
specifications, thus reducing the API surface.
This commit is contained in:
George Necula 2024-01-10 09:05:16 +02:00
parent 761cf8ba7d
commit 6cac99e664
3 changed files with 14 additions and 24 deletions

View File

@ -28,13 +28,15 @@ Remember to align the itemized text with the first line of an item within a list
({jax-issue}`#19231`; note that this may result in user-visible behavior
changes)
* improved the error messages for inconclusive inequality comparisons
({jax-issue}`#19235`)
({jax-issue}`#19235`).
* the `core.non_negative_dim` API (introduced recently)
was deprecated and `core.max_dim` and `core.min_dim` were introduced
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
* the `shape_poly.is_poly_dim` is deprecated in favor if `export.is_symbolic_dim`
({jax-issue}`#19282`).
* the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated, use
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will

View File

@ -1150,9 +1150,13 @@ class PolyShape(tuple):
"""
def __init__(self, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
tuple.__init__(dim_specs)
def __new__(cls, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
for ds in dim_specs:
if not isinstance(ds, (int, str)) and ds != ...:
msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
@ -1164,7 +1168,7 @@ class PolyShape(tuple):
return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"
def symbolic_shape(shape_spec: str | PolyShape | None,
def symbolic_shape(shape_spec: str | None,
*,
like: Sequence[int | None] | None = None
) -> Sequence[DimSize]:
@ -1188,7 +1192,7 @@ def symbolic_shape(shape_spec: str | PolyShape | None,
shape_spec_repr = repr(shape_spec)
if shape_spec is None:
shape_spec = "..."
elif isinstance(shape_spec, PolyShape):
elif isinstance(shape_spec, PolyShape): # TODO: deprecate
shape_spec = str(shape_spec)
elif not isinstance(shape_spec, str):
raise ValueError("polymorphic shape spec should be None or a string. "

View File

@ -59,7 +59,6 @@ from jax._src.internal_test_util import test_harnesses
from jax._src.internal_test_util.test_harnesses import Harness, CustomArg, RandArg, StaticArg
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
PS = jax2tf.PolyShape
_f32 = np.float32
_i32 = np.int32
@ -430,7 +429,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
"""Test conversion of actual arguments to abstract values."""
def check_avals(*, arg_shapes: Sequence[Sequence[int | None]],
polymorphic_shapes: Sequence[str | PS | None],
polymorphic_shapes: Sequence[str | None],
expected_avals: Sequence[core.ShapedArray] | None = None,
expected_shapeenv: dict[str, int] | None = None,
eager_mode: bool = False):
@ -486,27 +485,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["(_, 3)"],
expected_avals=(shaped_array("2, 3", [2, 3]),))
check_avals(
arg_shapes=[(2, 3)],
polymorphic_shapes=[PS("_", 3)],
expected_avals=(shaped_array("2, 3", [2, 3]),))
check_avals(
arg_shapes=[(2, 3)],
polymorphic_shapes=["..."],
expected_avals=(shaped_array("2, 3", [2, 3]),))
check_avals(
arg_shapes=[(2, 3)],
polymorphic_shapes=[PS(...)],
expected_avals=(shaped_array("2, 3", [2, 3]),))
# Partially known shapes for the arguments
check_avals(
arg_shapes=[(None, 3)],
polymorphic_shapes=[PS("b", ...)],
expected_avals=(shaped_array("(b, 3)", (2, 3)),))
check_avals(
arg_shapes=[(None, None)],
polymorphic_shapes=["h, h"],
@ -526,25 +510,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# Check cases when the specifications are polynomials
check_avals(
arg_shapes=[(2, 3)],
polymorphic_shapes=[PS("a + 1", "b + 2")],
polymorphic_shapes=["a + 1, b + 2"],
eager_mode=True,
expected_shapeenv=dict(a=1, b=1))
check_avals(
arg_shapes=[(7, 5)],
polymorphic_shapes=[PS("2 * a + b", "b + 2")],
polymorphic_shapes=["2 * a + b, b + 2"],
eager_mode=True,
expected_shapeenv=dict(a=2, b=3))
check_avals(
arg_shapes=[(7, 11, 4)],
polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + 1")],
polymorphic_shapes=["2 * a + b, b * b + 2, b + 1"],
eager_mode=True,
expected_shapeenv=dict(a=2, b=3))
check_avals(
arg_shapes=[(7, 11, 19, 7)],
polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + c * c", "2 * c + -1")],
polymorphic_shapes=["2 * a + b, b * b + 2, b + c * c, 2 * c + -1"],
eager_mode=True,
expected_shapeenv=dict(a=2, b=3, c=4))