diff --git a/CHANGELOG.md b/CHANGELOG.md index 57e73ac37..aa1b1c152 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,13 +28,15 @@ Remember to align the itemized text with the first line of an item within a list ({jax-issue}`#19231`; note that this may result in user-visible behavior changes) * improved the error messages for inconclusive inequality comparisons - ({jax-issue}`#19235`) + ({jax-issue}`#19235`). * the `core.non_negative_dim` API (introduced recently) was deprecated and `core.max_dim` and `core.min_dim` were introduced ({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions. You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`. * the `shape_poly.is_poly_dim` is deprecated in favor if `export.is_symbolic_dim` ({jax-issue}`#19282`). + * the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated, use + strings for polymorphic shapes specifications ({jax-issue}`#19284`). * Refactored the API for `jax.experimental.export`. Instead of `from jax.experimental.export import export` you should use now `from jax.experimental import export`. The old way of importing will diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index d667708a9..1a760d6b3 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -1150,9 +1150,13 @@ class PolyShape(tuple): """ def __init__(self, *dim_specs): + warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes", + DeprecationWarning, stacklevel=2) tuple.__init__(dim_specs) def __new__(cls, *dim_specs): + warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes", + DeprecationWarning, stacklevel=2) for ds in dim_specs: if not isinstance(ds, (int, str)) and ds != ...: msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string " @@ -1164,7 +1168,7 @@ class PolyShape(tuple): return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" -def symbolic_shape(shape_spec: str | PolyShape | None, +def symbolic_shape(shape_spec: str | None, *, like: Sequence[int | None] | None = None ) -> Sequence[DimSize]: @@ -1188,7 +1192,7 @@ def symbolic_shape(shape_spec: str | PolyShape | None, shape_spec_repr = repr(shape_spec) if shape_spec is None: shape_spec = "..." - elif isinstance(shape_spec, PolyShape): + elif isinstance(shape_spec, PolyShape): # TODO: deprecate shape_spec = str(shape_spec) elif not isinstance(shape_spec, str): raise ValueError("polymorphic shape spec should be None or a string. " diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 5d0ea4a75..79d02d5bd 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -59,7 +59,6 @@ from jax._src.internal_test_util import test_harnesses from jax._src.internal_test_util.test_harnesses import Harness, CustomArg, RandArg, StaticArg from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation -PS = jax2tf.PolyShape _f32 = np.float32 _i32 = np.int32 @@ -430,7 +429,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): """Test conversion of actual arguments to abstract values.""" def check_avals(*, arg_shapes: Sequence[Sequence[int | None]], - polymorphic_shapes: Sequence[str | PS | None], + polymorphic_shapes: Sequence[str | None], expected_avals: Sequence[core.ShapedArray] | None = None, expected_shapeenv: dict[str, int] | None = None, eager_mode: bool = False): @@ -486,27 +485,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): polymorphic_shapes=["(_, 3)"], expected_avals=(shaped_array("2, 3", [2, 3]),)) - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=[PS("_", 3)], - expected_avals=(shaped_array("2, 3", [2, 3]),)) - check_avals( arg_shapes=[(2, 3)], polymorphic_shapes=["..."], expected_avals=(shaped_array("2, 3", [2, 3]),)) - check_avals( - arg_shapes=[(2, 3)], - polymorphic_shapes=[PS(...)], - expected_avals=(shaped_array("2, 3", [2, 3]),)) - # Partially known shapes for the arguments - check_avals( - arg_shapes=[(None, 3)], - polymorphic_shapes=[PS("b", ...)], - expected_avals=(shaped_array("(b, 3)", (2, 3)),)) - check_avals( arg_shapes=[(None, None)], polymorphic_shapes=["h, h"], @@ -526,25 +510,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): # Check cases when the specifications are polynomials check_avals( arg_shapes=[(2, 3)], - polymorphic_shapes=[PS("a + 1", "b + 2")], + polymorphic_shapes=["a + 1, b + 2"], eager_mode=True, expected_shapeenv=dict(a=1, b=1)) check_avals( arg_shapes=[(7, 5)], - polymorphic_shapes=[PS("2 * a + b", "b + 2")], + polymorphic_shapes=["2 * a + b, b + 2"], eager_mode=True, expected_shapeenv=dict(a=2, b=3)) check_avals( arg_shapes=[(7, 11, 4)], - polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + 1")], + polymorphic_shapes=["2 * a + b, b * b + 2, b + 1"], eager_mode=True, expected_shapeenv=dict(a=2, b=3)) check_avals( arg_shapes=[(7, 11, 19, 7)], - polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + c * c", "2 * c + -1")], + polymorphic_shapes=["2 * a + b, b * b + 2, b + c * c, 2 * c + -1"], eager_mode=True, expected_shapeenv=dict(a=2, b=3, c=4))