mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
This commit is contained in:
parent
e5102957b0
commit
5fe5206b6a
@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
use `uses_global_constants`.
|
||||
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
|
||||
`platforms` instead.
|
||||
* The kwargs `symbolic_scope` and `symbolic_constraints` from
|
||||
{func}`jax.export.symbolic_args_specs` have been removed. They were
|
||||
deprecated in June 2024. Use `scope` and `constraints` instead.
|
||||
* Hashing of tracers, which has been deprecated since version 0.4.30, now
|
||||
results in a `TypeError`.
|
||||
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
|
||||
|
@ -1198,12 +1198,6 @@ def is_symbolic_dim(p: DimSize) -> bool:
|
||||
"""
|
||||
return isinstance(p, _DimExpr)
|
||||
|
||||
def is_poly_dim(p: DimSize) -> bool:
|
||||
# TODO: deprecated January 2024, remove June 2024.
|
||||
warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return is_symbolic_dim(p)
|
||||
|
||||
dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
|
||||
|
||||
def _einsum_contract_path(*operands, **kwargs):
|
||||
@ -1413,8 +1407,6 @@ def symbolic_args_specs(
|
||||
shapes_specs, # prefix pytree of strings
|
||||
constraints: Sequence[str] = (),
|
||||
scope: SymbolicScope | None = None,
|
||||
symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24
|
||||
symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24
|
||||
):
|
||||
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
|
||||
|
||||
@ -1435,25 +1427,10 @@ def symbolic_args_specs(
|
||||
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||||
constraints: as for :func:`jax.export.symbolic_shape`.
|
||||
scope: as for :func:`jax.export.symbolic_shape`.
|
||||
symbolic_constraints: DEPRECATED, use `constraints`.
|
||||
symbolic_scope: DEPRECATED, use `scope`.
|
||||
|
||||
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
|
||||
replaced with symbolic dimensions as specified by `shapes_specs`.
|
||||
"""
|
||||
if symbolic_constraints:
|
||||
warnings.warn("symbolic_constraints is deprecated, use constraints",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if constraints:
|
||||
raise ValueError("Cannot use both symbolic_constraints and constraints")
|
||||
constraints = symbolic_constraints
|
||||
if symbolic_scope is not None:
|
||||
warnings.warn("symbolic_scope is deprecated, use scope",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if scope is not None:
|
||||
raise ValueError("Cannot use both symbolic_scope and scope")
|
||||
scope = symbolic_scope
|
||||
|
||||
polymorphic_shapes = shapes_specs
|
||||
args_flat, args_tree = tree_util.tree_flatten(args)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user