mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Deprecate shape_poly.PolySpec.
This class has limited usefulness, and it seems worth removing it in favor of using strings for polymorphic specifications, thus reducing the API surface.
This commit is contained in:
parent
761cf8ba7d
commit
6cac99e664
@ -28,13 +28,15 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
({jax-issue}`#19231`; note that this may result in user-visible behavior
|
||||
changes)
|
||||
* improved the error messages for inconclusive inequality comparisons
|
||||
({jax-issue}`#19235`)
|
||||
({jax-issue}`#19235`).
|
||||
* the `core.non_negative_dim` API (introduced recently)
|
||||
was deprecated and `core.max_dim` and `core.min_dim` were introduced
|
||||
({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions.
|
||||
You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
|
||||
* the `shape_poly.is_poly_dim` is deprecated in favor if `export.is_symbolic_dim`
|
||||
({jax-issue}`#19282`).
|
||||
* the `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated, use
|
||||
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
|
||||
* Refactored the API for `jax.experimental.export`. Instead of
|
||||
`from jax.experimental.export import export` you should use now
|
||||
`from jax.experimental import export`. The old way of importing will
|
||||
|
@ -1150,9 +1150,13 @@ class PolyShape(tuple):
|
||||
"""
|
||||
|
||||
def __init__(self, *dim_specs):
|
||||
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
tuple.__init__(dim_specs)
|
||||
|
||||
def __new__(cls, *dim_specs):
|
||||
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
for ds in dim_specs:
|
||||
if not isinstance(ds, (int, str)) and ds != ...:
|
||||
msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
|
||||
@ -1164,7 +1168,7 @@ class PolyShape(tuple):
|
||||
return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"
|
||||
|
||||
|
||||
def symbolic_shape(shape_spec: str | PolyShape | None,
|
||||
def symbolic_shape(shape_spec: str | None,
|
||||
*,
|
||||
like: Sequence[int | None] | None = None
|
||||
) -> Sequence[DimSize]:
|
||||
@ -1188,7 +1192,7 @@ def symbolic_shape(shape_spec: str | PolyShape | None,
|
||||
shape_spec_repr = repr(shape_spec)
|
||||
if shape_spec is None:
|
||||
shape_spec = "..."
|
||||
elif isinstance(shape_spec, PolyShape):
|
||||
elif isinstance(shape_spec, PolyShape): # TODO: deprecate
|
||||
shape_spec = str(shape_spec)
|
||||
elif not isinstance(shape_spec, str):
|
||||
raise ValueError("polymorphic shape spec should be None or a string. "
|
||||
|
@ -59,7 +59,6 @@ from jax._src.internal_test_util import test_harnesses
|
||||
from jax._src.internal_test_util.test_harnesses import Harness, CustomArg, RandArg, StaticArg
|
||||
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
|
||||
|
||||
PS = jax2tf.PolyShape
|
||||
_f32 = np.float32
|
||||
_i32 = np.int32
|
||||
|
||||
@ -430,7 +429,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
"""Test conversion of actual arguments to abstract values."""
|
||||
|
||||
def check_avals(*, arg_shapes: Sequence[Sequence[int | None]],
|
||||
polymorphic_shapes: Sequence[str | PS | None],
|
||||
polymorphic_shapes: Sequence[str | None],
|
||||
expected_avals: Sequence[core.ShapedArray] | None = None,
|
||||
expected_shapeenv: dict[str, int] | None = None,
|
||||
eager_mode: bool = False):
|
||||
@ -486,27 +485,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["(_, 3)"],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3)],
|
||||
polymorphic_shapes=[PS("_", 3)],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3)],
|
||||
polymorphic_shapes=["..."],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3)],
|
||||
polymorphic_shapes=[PS(...)],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
# Partially known shapes for the arguments
|
||||
check_avals(
|
||||
arg_shapes=[(None, 3)],
|
||||
polymorphic_shapes=[PS("b", ...)],
|
||||
expected_avals=(shaped_array("(b, 3)", (2, 3)),))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(None, None)],
|
||||
polymorphic_shapes=["h, h"],
|
||||
@ -526,25 +510,25 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# Check cases when the specifications are polynomials
|
||||
check_avals(
|
||||
arg_shapes=[(2, 3)],
|
||||
polymorphic_shapes=[PS("a + 1", "b + 2")],
|
||||
polymorphic_shapes=["a + 1, b + 2"],
|
||||
eager_mode=True,
|
||||
expected_shapeenv=dict(a=1, b=1))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(7, 5)],
|
||||
polymorphic_shapes=[PS("2 * a + b", "b + 2")],
|
||||
polymorphic_shapes=["2 * a + b, b + 2"],
|
||||
eager_mode=True,
|
||||
expected_shapeenv=dict(a=2, b=3))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(7, 11, 4)],
|
||||
polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + 1")],
|
||||
polymorphic_shapes=["2 * a + b, b * b + 2, b + 1"],
|
||||
eager_mode=True,
|
||||
expected_shapeenv=dict(a=2, b=3))
|
||||
|
||||
check_avals(
|
||||
arg_shapes=[(7, 11, 19, 7)],
|
||||
polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + c * c", "2 * c + -1")],
|
||||
polymorphic_shapes=["2 * a + b, b * b + 2, b + c * c, 2 * c + -1"],
|
||||
eager_mode=True,
|
||||
expected_shapeenv=dict(a=2, b=3, c=4))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user