Remove the unaccelerate_deprecation utility

This commit is contained in:
Jake VanderPlas 2024-07-23 05:07:49 -07:00
parent 459b83cf4a
commit f887b66d5d
5 changed files with 36 additions and 24 deletions

View File

@ -346,6 +346,14 @@ def xla_computation(fun: Callable,
donate_argnums: int | Iterable[int] = ()) -> Callable:
"""Creates a function that produces its XLA computation given example args.
.. warning::
This function is deprecated as of JAX v0.4.30, and will be removed in a future
JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
example, ``jax.xla_computation(fn)(*args)`` can be replaced with
``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
See the `JAX 0.4.30 Change log`_ for more examples.
Args:
fun: Function from which to form XLA computations.
static_argnums: See the :py:func:`jax.jit` docstring.
@ -404,7 +412,7 @@ def xla_computation(fun: Callable,
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule xla_computation_f.6
<BLANKLINE>
@ -423,13 +431,13 @@ def xla_computation(fun: Callable,
>>> import types
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)
>>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP
Here's an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule jaxpr_computation.9
primitive_computation.3 {
@ -457,7 +465,7 @@ def xla_computation(fun: Callable,
... return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = xla_computation(g, axis_env=axis_env)(5.)
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
@ -469,6 +477,8 @@ def xla_computation(fun: Callable,
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
.. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
"""
if instantiate_const_outputs is not None:
raise ValueError(

View File

@ -66,6 +66,13 @@ def accelerate_getattr_deprecation(module: ModuleType, name: str) -> None:
message, _ = module._deprecations[name]
module._deprecations[name] = (message, None)
def is_accelerated_attribute(module: ModuleType, name: str) -> bool:
"""Returns true if given name is accelerated.
Raises an error if name is not a deprecated attribute in module.
"""
return module._deprecations[name][1] is None
# The following mechanism is a separate one, for registering and
# accelerating deprecations that are not imports (for example, deprecations
# of a function argument).

View File

@ -188,18 +188,6 @@ def check_eq(xs, ys, err_msg=''):
tree_all(tree_map(assert_close, xs, ys))
# TODO(yashkatariya): Make this context manager check for deprecation message
# in OSS.
@contextmanager
def unaccelerate_getattr_deprecation(module, name):
message, prev_attr = module._deprecations[name]
module._deprecations[name] = (message, getattr(module, f"_deprecated_{name}"))
try:
yield
finally:
module._deprecations[name] = (message, prev_attr)
@contextmanager
def _capture_output(fp: TextIO) -> Generator[Callable[[], str], None, None]:
"""Context manager to capture all output written to a given file object.

View File

@ -54,16 +54,16 @@ filterwarnings = [
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
"default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning",
# TODO(jakevdp): remove when array_api_tests stabilize
# start array_api_tests-related warnings
"default:.*not machine-readable.*:UserWarning",
"default:Special cases found for .* but none were parsed.*:UserWarning",
"default:.*is not JSON-serializable. Using the repr instead.",
# end array_api_tests-related warnings
# This is a transitive warning coming from TensorFlow dependencies.
# These are transitive warnings coming from TensorFlow dependencies.
# TODO(slebedev): Remove once we bump the minimum TensorFlow version.
"default:The key path API is deprecated .*",
"default:jax.xla_computation is deprecated.*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",

View File

@ -50,6 +50,7 @@ from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -3020,8 +3021,11 @@ class APITest(jtu.JaxTestCase):
axis_env = [(axis_name, jax.local_device_count())]
_ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
def test_xla_computation_axis_env(self):
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
def fn(x):
z = x * jax.lax.axis_index('i').astype(jnp.float32)
def inner_fn(carry, a):
@ -3029,7 +3033,7 @@ class APITest(jtu.JaxTestCase):
return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
x = jnp.ones((5, 6, 4), dtype=jnp.float32)
_ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
_ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
def test_concurrent_device_get_and_put(self):
def f(x):
@ -10760,8 +10764,11 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
class NamedCallTest(jtu.JaxTestCase):
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
def test_default_name(self):
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
@api.named_call
def my_test_function(x):
return x**2
@ -10770,7 +10777,7 @@ class NamedCallTest(jtu.JaxTestCase):
def f(x):
return my_test_function(x)
c = jax.xla_computation(f)(2)
c = xla_computation(f)(2)
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
hlo_text = c.as_hlo_module().to_string(print_opts)