From f63b94574a8b55106da3d9ce32cdddb8f5ee6635 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 13 Jun 2024 09:31:43 -0700 Subject: [PATCH] Deprecate internal pretty-printing APIs, jax.core.pp_* --- CHANGELOG.md | 3 +++ jax/core.py | 37 +++++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e4c0aeed..5cc510647 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ Remember to align the itemized text with the first line of an item within a list * jax now depends on jaxlib directly. This change was enabled by the CUDA plugin switch: there are no longer multiple jaxlib variants. You can install a CPU-only jax with `pip install jax`, no extras required. +* Deprecations + * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed + in a future release. ## jaxlib 0.4.30 diff --git a/jax/core.py b/jax/core.py index edc31778f..c23d37123 100644 --- a/jax/core.py +++ b/jax/core.py @@ -118,18 +118,6 @@ from jax._src.core import ( no_effects as no_effects, non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, - pp_aval as pp_aval, - pp_eqn as pp_eqn, - pp_eqn_rules as pp_eqn_rules, - pp_eqns as pp_eqns, - pp_jaxpr as pp_jaxpr, - pp_jaxpr_eqn_range as pp_jaxpr_eqn_range, - pp_jaxpr_skeleton as pp_jaxpr_skeleton, - pp_jaxprs as pp_jaxprs, - pp_kv_pair as pp_kv_pair, - pp_kv_pairs as pp_kv_pairs, - pp_var as pp_var, - pp_vars as pp_vars, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, process_env_traces_call as process_env_traces_call, @@ -162,6 +150,19 @@ from jax._src.core import ( from jax._src import core as _src_core _deprecations = { + # Added 2024-06-12 + "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), + "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), + "pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules), + "pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns), + "pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr), + "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range), + "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton), + "pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs), + "pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair), + "pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs), + "pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var), + "pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars), # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", @@ -196,6 +197,18 @@ if typing.TYPE_CHECKING: dimension_as_value = _deprecated_dimension_as_value definitely_equal = _deprecated_definitely_equal non_negative_dim = _deprecated_non_negative_dim + pp_aval = _src_core.pp_aval + pp_eqn = _src_core.pp_eqn + pp_eqn_rules = _src_core.pp_eqn_rules + pp_eqns = _src_core.pp_eqns + pp_jaxpr = _src_core.pp_jaxpr + pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range + pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton + pp_jaxprs = _src_core.pp_jaxprs + pp_kv_pair = _src_core.pp_kv_pair + pp_kv_pairs = _src_core.pp_kv_pairs + pp_var = _src_core.pp_var + pp_vars = _src_core.pp_vars symbolic_equal_dim = _deprecated_definitely_equal else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr