diff --git a/CHANGELOG.md b/CHANGELOG.md index df2b5813c..1c29ae7dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.33 +* Deletion: + * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation + in 0.4.30 JAX release. + Please use the AOT APIs to get the same functionality as `jax.xla_computation`. + * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with + `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`. + * You can also use `.out_info` property of `jax.stages.Lowered` to get the + output information (like tree structure, shape and dtype). + * For cross-backend lowering, you can replace + `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with + `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + ## jaxlib 0.4.33 diff --git a/docs/jax.rst b/docs/jax.rst index b2c4ba607..a8781d31a 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`) jit disable_jit ensure_compile_time_eval - xla_computation make_jaxpr eval_shape ShapeDtypeStruct diff --git a/jax/__init__.py b/jax/__init__.py index 7e958b21c..e2e302adb 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -127,7 +127,6 @@ from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap -from jax._src.api import xla_computation as _deprecated_xla_computation from jax._src.sharding_impls import NamedSharding as NamedSharding from jax._src.sharding_impls import make_mesh as make_mesh @@ -224,20 +223,18 @@ _deprecations = { "jax.clear_backends is deprecated.", _deprecated_clear_backends ), - # Added Jun 16, 2024 + # Remove after jax 0.4.35 release. "xla_computation": ( - "jax.xla_computation is deprecated. Please use the AOT APIs; see " + "jax.xla_computation is deleted. Please use the AOT APIs; see " "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See " - "CHANGELOG.md for 0.4.30 for more examples.", - _deprecated_xla_computation + "CHANGELOG.md for 0.4.30 for more examples.", None ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src.api import clear_backends as clear_backends - from jax._src.api import xla_computation as xla_computation from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/api.py b/jax/_src/api.py index 935995ec5..8ed03e8e3 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -46,7 +46,6 @@ from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch -from jax._src import effects from jax._src import array from jax._src import basearray from jax._src import distributed @@ -60,7 +59,7 @@ from jax._src import xla_bridge as xb from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, - argnums_partial_except, flatten_axes, donation_vector, + flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, fun_sourceinfo) @@ -73,13 +72,11 @@ from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps, - split_list) +from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -337,244 +334,6 @@ def disable_jit(disable: bool = True): yield -def xla_computation(fun: Callable, - static_argnums: int | Iterable[int] = (), - axis_env: Sequence[tuple[AxisName, int]] | None = None, - in_parts=None, out_parts=None, - backend: str | None = None, - tuple_args: bool = False, - instantiate_const_outputs: bool | None = None, - return_shape: bool = False, - donate_argnums: int | Iterable[int] = ()) -> Callable: - """Creates a function that produces its XLA computation given example args. - - .. warning:: - - This function is deprecated as of JAX v0.4.30, and will be removed in a future - JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for - example, ``jax.xla_computation(fn)(*args)`` can be replaced with - ``jax.jit(fn).lower(*args).compiler_ir('hlo')``. - See the `JAX 0.4.30 Change log`_ for more examples. - - Args: - fun: Function from which to form XLA computations. - static_argnums: See the :py:func:`jax.jit` docstring. - axis_env: Optional, a sequence of pairs where the first element is an axis - name and the second element is a positive integer representing the size of - the mapped axis with that name. This parameter is useful when lowering - functions that involve parallel communication collectives, and it - specifies the axis name/size environment that would be set up by - applications of :py:func:`jax.pmap`. See the examples below. - in_parts: Optional, how each argument to ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - out_parts: Optional, how each output of ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting - XLA computation will have a single tuple argument that is unpacked into - the specified function arguments. If `None`, tupling will be enabled when - there are more than 100 arguments, since some platforms have limits on - argument arity. - instantiate_const_outputs: Deprecated argument, does nothing. - return_shape: Optional boolean, defaults to ``False``. If ``True``, the - wrapped function returns a pair where the first element is the XLA - computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape`` and - ``dtype`` attributes representing the corresponding types of the output - leaves. - donate_argnums: Specify which arguments are "donated" to the computation. - It is safe to donate arguments if you no longer need them once the - computation has finished. In some cases XLA can make use of donated - buffers to reduce the amount of memory needed to perform a computation, - for example recycling one of your input buffers to store a result. You - should not reuse buffers that you donate to a computation, JAX will raise - an error if you try to. - - Returns: - A wrapped version of ``fun`` that when applied to example arguments returns - a built XLA Computation (see xla_client.py), from which representations of - the unoptimized XLA HLO computation can be extracted using methods like - ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and - ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the - wrapped function returns a pair where the first element is the XLA - Computation and the second element is a pytree representing the structure, - shapes, dtypes, and named shapes of the output of ``fun``. - - Concrete example arguments are not always necessary. For those arguments not - indicated by ``static_argnums``, any object with ``shape`` and ``dtype`` - attributes is acceptable (excepting namedtuples, which are treated as Python - containers). - - For example: - - >>> import jax - >>> - >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) - >>> c = jax.xla_computation(f)(3.) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule xla_computation_f.6 - - ENTRY xla_computation_f.6 { - constant.2 = pred[] constant(false) - parameter.1 = f32[] parameter(0) - cosine.3 = f32[] cosine(parameter.1) - sine.4 = f32[] sine(cosine.3) - ROOT tuple.5 = (f32[]) tuple(sine.4) - } - - - - - Alternatively, the assignment to ``c`` above could be written: - - >>> import types - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) - >>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP - - - Here's an example that involves a parallel collective and axis name: - - >>> def f(x): return x - jax.lax.psum(x, 'i') - >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation.9 - primitive_computation.3 { - parameter.4 = s32[] parameter(0) - parameter.5 = s32[] parameter(1) - ROOT add.6 = s32[] add(parameter.4, parameter.5) - } - ENTRY jaxpr_computation.9 { - tuple.1 = () tuple() - parameter.2 = s32[] parameter(0) - all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3 - ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7) - } - - - - Notice the ``replica_groups`` that were generated. Here's an example that - generates more interesting ``replica_groups``: - - >>> from jax import lax - >>> def g(x): - ... rowsum = lax.psum(x, 'i') - ... colsum = lax.psum(x, 'j') - ... allsum = lax.psum(x, ('i', 'j')) - ... return rowsum, colsum, allsum - ... - >>> axis_env = [('i', 4), ('j', 2)] - >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation__1.19 - [removed uninteresting text here] - ENTRY jaxpr_computation__1.19 { - tuple.1 = () tuple() - parameter.2 = f32[] parameter(0) - all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3 - all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8 - all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13 - ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17) - } - - .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024 - """ - if instantiate_const_outputs is not None: - raise ValueError( - "instantiate_const_outputs has been deprecated. Please use the ahead of" - " time APIs. You can read more here:" - " https://jax.readthedocs.io/en/latest/aot.html") - if in_parts is not None: - raise ValueError( - "in_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - if out_parts is not None: - raise ValueError( - "out_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - - check_callable(fun) - static_argnums = _ensure_index_tuple(static_argnums) - donate_argnums = _ensure_index_tuple(donate_argnums) - donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums) - - fun_name = getattr(fun, "__name__", "unknown") - - platform = backend if backend is not None else xb.get_backend().platform - - def make_axis_env(nreps): - if axis_env is None: - return sharding_impls.AxisEnv(nreps, (), ()) - else: - nreps = nreps * math.prod(size for name, size in axis_env) - names, sizes = unzip2(axis_env) - return sharding_impls.AxisEnv(nreps, names, sizes) - - @wraps(fun) - @api_boundary - def computation_maker(*args, **kwargs): - if max(static_argnums + donate_argnums, default=-1) >= len(args): - raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but " - f"was called with only {len(args)} positional arguments.") - - f = lu.wrap_init(fun) - f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False) - args_flat, in_tree = tree_flatten((dyn_args, kwargs)) - if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), in_tree) - else: - donated_invars = (False,) * len(args_flat) - - jaxtree_fun, out_tree = flatten_fun(f, in_tree) - avals = map(shaped_abstractify, args_flat) - with ExitStack() as stack: - for axis_name, size in axis_env or []: - stack.enter_context(core.extend_axis_env(axis_name, size, None)) - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) - if axis_env: - jaxpr = core.remove_named_axis_effects( - jaxpr, {axis_name for axis_name, _ in axis_env} - ) - axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr)) - ordered_effects = list( - effects.ordered_effects.filter_in(jaxpr.effects)) - lowering_result = mlir.lower_jaxpr_to_module( - f"xla_computation_{fun_name}", - core.ClosedJaxpr(jaxpr, consts), - ordered_effects=ordered_effects, - backend_or_name=backend, - platforms=[platform], - axis_context=sharding_impls.ReplicaAxisContext(axis_env_), - name_stack=source_info_util.new_name_stack( - wrap_name(fun_name, "xla_computation")), - donated_args=donated_invars, - arg_shardings=None, - result_shardings=None, - lowering_parameters=mlir.LoweringParameters()) - - m = mlir.module_to_bytecode(lowering_result.module) - built = xc._xla.mlir.mlir_module_to_xla_computation( - m, use_tuple_args=tuple_args, return_tuple=True) - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shape = tree_unflatten(out_tree(), out_shapes_flat) - for out_aval in out_avals: - if not isinstance(out_aval, ShapedArray): - raise RuntimeError("As we want to propagate the weak_type, we need " - "to get a ShapedArray, otherwise this " - "information is lost") - - if return_shape: - return built, out_shape - else: - return built - - return computation_maker - def grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, diff --git a/tests/api_test.py b/tests/api_test.py index 1a119846b..fabdb9ffe 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -50,7 +50,6 @@ from jax._src import array from jax._src import config from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -60,7 +59,6 @@ from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_client from jax._src.lib import xla_extension import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint @@ -2904,74 +2902,6 @@ class APITest(jtu.JaxTestCase): r"sub-dtype of np.floating\), but got complex.*"), lambda: dfn(3. + 1j)) - def test_xla_computation(self): - # these tests basically check the examples in the xla_computation docstring - - def e(x): - return jnp.sin(jnp.cos(x)) - c = api.xla_computation(e)(2.) - self.assertIn('cosine', c.as_hlo_text()) - self.assertIn('sine', c.as_hlo_text()) - - def f(x): - return x - lax.psum(x, 'i') - axis_env = [('i', 4)] - c = api.xla_computation(f, axis_env=axis_env)(2) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text()) - - def g(x): - rowsum = lax.psum(x, 'i') - colsum = lax.psum(x, 'j') - allsum = lax.psum(x, ('i', 'j')) - return rowsum, colsum, allsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(g, axis_env=axis_env)(5.) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text()) - - def h(x): - rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]) - colsum = lax.psum(x, 'j') - return rowsum, colsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(h, axis_env=axis_env)(5.) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - - def test_xla_computation_args(self): - def foo(x, y, z): - return x + y + z - - c = api.xla_computation(foo)(1., 2., 3.) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - - def test_xla_computation_duck_typing(self): - def foo(x, y, z): - return x + y + z - - x = jax.ShapeDtypeStruct((), np.float32) - y = jax.ShapeDtypeStruct((), np.float32) - z = jax.ShapeDtypeStruct((), np.float32) - - c = api.xla_computation(foo)(x, y, z) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - def test_compiler_ir(self): # TODO(phawkins): merge these tests with the `xla_computation` tests. def e(x): @@ -2983,72 +2913,6 @@ class APITest(jtu.JaxTestCase): self.assertIn("stablehlo.cosine", stablehlo) self.assertIn("stablehlo.sine", stablehlo) - def test_staging_out_multi_replica(self): - def f(x): - return api.pmap(jnp.mean)(x) - xla_comp = api.xla_computation(f) - xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash - - def test_xla_computation_instantiate_constant_outputs(self): - def f(): - return jnp.zeros((3, 4)) - - xla_comp = api.xla_computation(f)() - out_shape, = xla_comp.program_shape().result_shape().tuple_shapes() - self.assertEqual(out_shape.dimensions(), (3, 4)) - - def test_xla_computation_static_argnums(self): - def f(x, y): - return x + y - - xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3) - hlo_text = xla_comp.as_hlo_text() - self.assertIn("constant(3)", hlo_text) - # The static arguments should be removed from the function being compiled, - # thus the function should have only a single argument. - self.assertIn("parameter(0)", hlo_text) - self.assertNotIn("parameter(1)", hlo_text) - - def test_xla_computation_return_shape(self): - _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)), - return_shape=True)(np.int32(1)) - expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32), - api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)) - self.assertEqual(shape_tree, expected) - - def test_xla_computation_psum_constant(self): - f = lambda: jax.lax.psum(1, "i") - api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash - - @jtu.ignore_warning(message="Some donated buffers were not usable") - def test_xla_computation_donate_argnums(self): - api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash - - def test_xla_computation_lower_fun_axis_env(self): - axis_name = 'i' - def fn(x): - y = lax.all_gather( - x, axis_name=axis_name) - return y * lax.axis_index(axis_name).astype(jnp.float32) - - input_x = jnp.ones((5,6,4), dtype=jnp.float32) - axis_env = [(axis_name, jax.local_device_count())] - _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x) - - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_xla_computation_axis_env(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - def fn(x): - z = x * jax.lax.axis_index('i').astype(jnp.float32) - def inner_fn(carry, a): - return carry + a, () - return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z) - - x = jnp.ones((5, 6, 4), dtype=jnp.float32) - _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x) - def test_concurrent_device_get_and_put(self): def f(x): for _ in range(100): @@ -3678,7 +3542,7 @@ class APITest(jtu.JaxTestCase): return x + y + y x = np.array([1, 2], dtype=np.float32) - hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n') + hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n') hlo_lines = {s.strip() for s in hlo_lines} self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines) self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines) @@ -3805,11 +3669,6 @@ class APITest(jtu.JaxTestCase): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_xla_computation_zeros_doesnt_device_put(self): - with jtu.count_device_put() as count: - api.xla_computation(lambda: jnp.zeros(3))() - self.assertEqual(count[0], 0) - def test_join_concrete_arrays_with_omnistaging(self): # https://github.com/google/jax/issues/4622 x = jnp.array([1., 2., 3.]) @@ -5532,13 +5391,12 @@ class RematTest(jtu.JaxTestCase): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) - self.assertNotIn('opt-barrier', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) + self.assertNotIn('opt-barrier', text) - c = api.xla_computation(grad(f))(2.) - text = c.as_hlo_text() + text = jax.jit(grad(f)).lower(2.).as_text('hlo') self.assertTrue('while' in text or 'conditional' in text or 'opt-barrier' in text) @@ -5557,13 +5415,13 @@ class RematTest(jtu.JaxTestCase): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) - c = api.xla_computation(grad(f))(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(grad(f)).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) @parameterized.named_parameters( {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat, @@ -6679,7 +6537,7 @@ class JaxprTest(jtu.JaxTestCase): self.assertLen(jaxpr.jaxpr.eqns, 0) def test_convert_element_type_literal_constant_folding(self): - # this convert_elemnt_type is nontrivial, but because it's on a scalar we + # this convert_element_type is nontrivial, but because it's on a scalar we # constant-fold it cet = partial(lax.convert_element_type, new_dtype='float16') jaxpr = api.make_jaxpr(lambda: cet(3.))() @@ -10966,25 +10824,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase): class NamedCallTest(jtu.JaxTestCase): - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_default_name(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - @api.named_call - def my_test_function(x): - return x**2 - - @jax.jit - def f(x): - return my_test_function(x) - - c = xla_computation(f)(2) - print_opts = xla_client._xla.HloPrintOptions.short_parsable() - print_opts.print_metadata = True - hlo_text = c.as_hlo_module().to_string(print_opts) - self.assertIn("my_test_function", hlo_text) - def test_non_jaxtype_arg(self): # For the test to fail without the invalid JaxType filter we need to pass # in a valid JaxType that forces the invalid Jaxtype to be raised to an