Deprecate internal pretty-printing APIs, jax.core.pp_*

This commit is contained in:
Jake VanderPlas 2024-06-13 09:31:43 -07:00
parent 2679ece82d
commit f63b94574a
2 changed files with 28 additions and 12 deletions

View File

@ -16,6 +16,9 @@ Remember to align the itemized text with the first line of an item within a list
* jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required.
* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
## jaxlib 0.4.30

View File

@ -118,18 +118,6 @@ from jax._src.core import (
no_effects as no_effects,
non_negative_dim as _deprecated_non_negative_dim,
outfeed_primitives as outfeed_primitives,
pp_aval as pp_aval,
pp_eqn as pp_eqn,
pp_eqn_rules as pp_eqn_rules,
pp_eqns as pp_eqns,
pp_jaxpr as pp_jaxpr,
pp_jaxpr_eqn_range as pp_jaxpr_eqn_range,
pp_jaxpr_skeleton as pp_jaxpr_skeleton,
pp_jaxprs as pp_jaxprs,
pp_kv_pair as pp_kv_pair,
pp_kv_pairs as pp_kv_pairs,
pp_var as pp_var,
pp_vars as pp_vars,
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
primitive_uses_outfeed as primitive_uses_outfeed,
process_env_traces_call as process_env_traces_call,
@ -162,6 +150,19 @@ from jax._src.core import (
from jax._src import core as _src_core
_deprecations = {
# Added 2024-06-12
"pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval),
"pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn),
"pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules),
"pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns),
"pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr),
"pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range),
"pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton),
"pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs),
"pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair),
"pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs),
"pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var),
"pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars),
# Finalized 2024-05-13; remove after 2024-08-13
"DimSize": (
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
@ -196,6 +197,18 @@ if typing.TYPE_CHECKING:
dimension_as_value = _deprecated_dimension_as_value
definitely_equal = _deprecated_definitely_equal
non_negative_dim = _deprecated_non_negative_dim
pp_aval = _src_core.pp_aval
pp_eqn = _src_core.pp_eqn
pp_eqn_rules = _src_core.pp_eqn_rules
pp_eqns = _src_core.pp_eqns
pp_jaxpr = _src_core.pp_jaxpr
pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range
pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton
pp_jaxprs = _src_core.pp_jaxprs
pp_kv_pair = _src_core.pp_kv_pair
pp_kv_pairs = _src_core.pp_kv_pairs
pp_var = _src_core.pp_var
pp_vars = _src_core.pp_vars
symbolic_equal_dim = _deprecated_definitely_equal
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr