mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove the unaccelerate_deprecation utility
This commit is contained in:
parent
459b83cf4a
commit
f887b66d5d
@ -346,6 +346,14 @@ def xla_computation(fun: Callable,
|
||||
donate_argnums: int | Iterable[int] = ()) -> Callable:
|
||||
"""Creates a function that produces its XLA computation given example args.
|
||||
|
||||
.. warning::
|
||||
|
||||
This function is deprecated as of JAX v0.4.30, and will be removed in a future
|
||||
JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
|
||||
example, ``jax.xla_computation(fn)(*args)`` can be replaced with
|
||||
``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
|
||||
See the `JAX 0.4.30 Change log`_ for more examples.
|
||||
|
||||
Args:
|
||||
fun: Function from which to form XLA computations.
|
||||
static_argnums: See the :py:func:`jax.jit` docstring.
|
||||
@ -404,7 +412,7 @@ def xla_computation(fun: Callable,
|
||||
>>> import jax
|
||||
>>>
|
||||
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
|
||||
>>> c = jax.xla_computation(f)(3.)
|
||||
>>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
|
||||
>>> print(c.as_hlo_text()) # doctest: +SKIP
|
||||
HloModule xla_computation_f.6
|
||||
<BLANKLINE>
|
||||
@ -423,13 +431,13 @@ def xla_computation(fun: Callable,
|
||||
|
||||
>>> import types
|
||||
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
|
||||
>>> c = jax.xla_computation(f)(scalar)
|
||||
>>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP
|
||||
|
||||
|
||||
Here's an example that involves a parallel collective and axis name:
|
||||
|
||||
>>> def f(x): return x - jax.lax.psum(x, 'i')
|
||||
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
|
||||
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
|
||||
>>> print(c.as_hlo_text()) # doctest: +SKIP
|
||||
HloModule jaxpr_computation.9
|
||||
primitive_computation.3 {
|
||||
@ -457,7 +465,7 @@ def xla_computation(fun: Callable,
|
||||
... return rowsum, colsum, allsum
|
||||
...
|
||||
>>> axis_env = [('i', 4), ('j', 2)]
|
||||
>>> c = xla_computation(g, axis_env=axis_env)(5.)
|
||||
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
|
||||
>>> print(c.as_hlo_text()) # doctest: +SKIP
|
||||
HloModule jaxpr_computation__1.19
|
||||
[removed uninteresting text here]
|
||||
@ -469,6 +477,8 @@ def xla_computation(fun: Callable,
|
||||
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
|
||||
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
|
||||
}
|
||||
|
||||
.. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
|
||||
"""
|
||||
if instantiate_const_outputs is not None:
|
||||
raise ValueError(
|
||||
|
@ -66,6 +66,13 @@ def accelerate_getattr_deprecation(module: ModuleType, name: str) -> None:
|
||||
message, _ = module._deprecations[name]
|
||||
module._deprecations[name] = (message, None)
|
||||
|
||||
def is_accelerated_attribute(module: ModuleType, name: str) -> bool:
|
||||
"""Returns true if given name is accelerated.
|
||||
|
||||
Raises an error if name is not a deprecated attribute in module.
|
||||
"""
|
||||
return module._deprecations[name][1] is None
|
||||
|
||||
# The following mechanism is a separate one, for registering and
|
||||
# accelerating deprecations that are not imports (for example, deprecations
|
||||
# of a function argument).
|
||||
|
@ -188,18 +188,6 @@ def check_eq(xs, ys, err_msg=''):
|
||||
tree_all(tree_map(assert_close, xs, ys))
|
||||
|
||||
|
||||
# TODO(yashkatariya): Make this context manager check for deprecation message
|
||||
# in OSS.
|
||||
@contextmanager
|
||||
def unaccelerate_getattr_deprecation(module, name):
|
||||
message, prev_attr = module._deprecations[name]
|
||||
module._deprecations[name] = (message, getattr(module, f"_deprecated_{name}"))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
module._deprecations[name] = (message, prev_attr)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _capture_output(fp: TextIO) -> Generator[Callable[[], str], None, None]:
|
||||
"""Context manager to capture all output written to a given file object.
|
||||
|
@ -54,16 +54,16 @@ filterwarnings = [
|
||||
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
|
||||
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
|
||||
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
|
||||
"default:jax.xla_computation is deprecated. Please use the AOT APIs.*:DeprecationWarning",
|
||||
|
||||
# TODO(jakevdp): remove when array_api_tests stabilize
|
||||
# start array_api_tests-related warnings
|
||||
"default:.*not machine-readable.*:UserWarning",
|
||||
"default:Special cases found for .* but none were parsed.*:UserWarning",
|
||||
"default:.*is not JSON-serializable. Using the repr instead.",
|
||||
# end array_api_tests-related warnings
|
||||
# This is a transitive warning coming from TensorFlow dependencies.
|
||||
|
||||
# These are transitive warnings coming from TensorFlow dependencies.
|
||||
# TODO(slebedev): Remove once we bump the minimum TensorFlow version.
|
||||
"default:The key path API is deprecated .*",
|
||||
"default:jax.xla_computation is deprecated.*:DeprecationWarning",
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
|
@ -50,6 +50,7 @@ from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -3020,8 +3021,11 @@ class APITest(jtu.JaxTestCase):
|
||||
axis_env = [(axis_name, jax.local_device_count())]
|
||||
_ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
|
||||
|
||||
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
|
||||
def test_xla_computation_axis_env(self):
|
||||
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
|
||||
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
|
||||
|
||||
def fn(x):
|
||||
z = x * jax.lax.axis_index('i').astype(jnp.float32)
|
||||
def inner_fn(carry, a):
|
||||
@ -3029,7 +3033,7 @@ class APITest(jtu.JaxTestCase):
|
||||
return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
|
||||
|
||||
x = jnp.ones((5, 6, 4), dtype=jnp.float32)
|
||||
_ = jax.xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
|
||||
_ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
|
||||
|
||||
def test_concurrent_device_get_and_put(self):
|
||||
def f(x):
|
||||
@ -10760,8 +10764,11 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
|
||||
|
||||
class NamedCallTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.unaccelerate_getattr_deprecation(jax, 'xla_computation')
|
||||
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
|
||||
def test_default_name(self):
|
||||
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
|
||||
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
|
||||
|
||||
@api.named_call
|
||||
def my_test_function(x):
|
||||
return x**2
|
||||
@ -10770,7 +10777,7 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return my_test_function(x)
|
||||
|
||||
c = jax.xla_computation(f)(2)
|
||||
c = xla_computation(f)(2)
|
||||
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
|
||||
print_opts.print_metadata = True
|
||||
hlo_text = c.as_hlo_module().to_string(print_opts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user