diff --git a/CHANGELOG.md b/CHANGELOG.md index b5758d107..258fad49b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. use `uses_global_constants`. * the `lowering_platforms` kwarg for {func}`jax.export.export`: use `platforms` instead. + * The kwargs `symbolic_scope` and `symbolic_constraints` from + {func}`jax.export.symbolic_args_specs` have been removed. They were + deprecated in June 2024. Use `scope` and `constraints` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 15f99533d..cb9a99564 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1198,12 +1198,6 @@ def is_symbolic_dim(p: DimSize) -> bool: """ return isinstance(p, _DimExpr) -def is_poly_dim(p: DimSize) -> bool: - # TODO: deprecated January 2024, remove June 2024. - warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim", - DeprecationWarning, stacklevel=2) - return is_symbolic_dim(p) - dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): @@ -1413,8 +1407,6 @@ def symbolic_args_specs( shapes_specs, # prefix pytree of strings constraints: Sequence[str] = (), scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 - symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. @@ -1435,25 +1427,10 @@ def symbolic_args_specs( arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. - symbolic_constraints: DEPRECATED, use `constraints`. - symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes replaced with symbolic dimensions as specified by `shapes_specs`. """ - if symbolic_constraints: - warnings.warn("symbolic_constraints is deprecated, use constraints", - DeprecationWarning, stacklevel=2) - if constraints: - raise ValueError("Cannot use both symbolic_constraints and constraints") - constraints = symbolic_constraints - if symbolic_scope is not None: - warnings.warn("symbolic_scope is deprecated, use scope", - DeprecationWarning, stacklevel=2) - if scope is not None: - raise ValueError("Cannot use both symbolic_scope and scope") - scope = symbolic_scope - polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args)