Delete jax.xla_computation since it's been 3 months since it was deprecated.

PiperOrigin-RevId: 673938336
This commit is contained in:
Yash Katariya 2024-09-12 11:47:03 -07:00 committed by jax authors
parent 985e74f2f3
commit de9b98e0a8
5 changed files with 30 additions and 424 deletions

View File

@ -12,6 +12,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
## jax 0.4.33
* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
in 0.4.30 JAX release.
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
output information (like tree structure, shape and dtype).
* For cross-backend lowering, you can replace
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
## jaxlib 0.4.33

View File

@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`)
jit
disable_jit
ensure_compile_time_eval
xla_computation
make_jaxpr
eval_shape
ShapeDtypeStruct

View File

@ -127,7 +127,6 @@ from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as _deprecated_xla_computation
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.sharding_impls import make_mesh as make_mesh
@ -224,20 +223,18 @@ _deprecations = {
"jax.clear_backends is deprecated.",
_deprecated_clear_backends
),
# Added Jun 16, 2024
# Remove after jax 0.4.35 release.
"xla_computation": (
"jax.xla_computation is deprecated. Please use the AOT APIs; see "
"jax.xla_computation is deleted. Please use the AOT APIs; see "
"https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
"xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
"CHANGELOG.md for 0.4.30 for more examples.",
_deprecated_xla_computation
"CHANGELOG.md for 0.4.30 for more examples.", None
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.api import clear_backends as clear_backends
from jax._src.api import xla_computation as xla_computation
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves

View File

@ -46,7 +46,6 @@ from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import basearray
from jax._src import distributed
@ -60,7 +59,7 @@ from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
@ -73,13 +72,11 @@ from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps,
split_list)
from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@ -337,244 +334,6 @@ def disable_jit(disable: bool = True):
yield
def xla_computation(fun: Callable,
static_argnums: int | Iterable[int] = (),
axis_env: Sequence[tuple[AxisName, int]] | None = None,
in_parts=None, out_parts=None,
backend: str | None = None,
tuple_args: bool = False,
instantiate_const_outputs: bool | None = None,
return_shape: bool = False,
donate_argnums: int | Iterable[int] = ()) -> Callable:
"""Creates a function that produces its XLA computation given example args.
.. warning::
This function is deprecated as of JAX v0.4.30, and will be removed in a future
JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
example, ``jax.xla_computation(fn)(*args)`` can be replaced with
``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
See the `JAX 0.4.30 Change log`_ for more examples.
Args:
fun: Function from which to form XLA computations.
static_argnums: See the :py:func:`jax.jit` docstring.
axis_env: Optional, a sequence of pairs where the first element is an axis
name and the second element is a positive integer representing the size of
the mapped axis with that name. This parameter is useful when lowering
functions that involve parallel communication collectives, and it
specifies the axis name/size environment that would be set up by
applications of :py:func:`jax.pmap`. See the examples below.
in_parts: Optional, how each argument to ``fun`` should be partitioned or
replicated. This is used to specify partitioned XLA computations, see
``sharded_jit`` for more info.
out_parts: Optional, how each output of ``fun`` should be partitioned or
replicated. This is used to specify partitioned XLA computations, see
``sharded_jit`` for more info.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
XLA computation will have a single tuple argument that is unpacked into
the specified function arguments. If `None`, tupling will be enabled when
there are more than 100 arguments, since some platforms have limits on
argument arity.
instantiate_const_outputs: Deprecated argument, does nothing.
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation, JAX will raise
an error if you try to.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
a built XLA Computation (see xla_client.py), from which representations of
the unoptimized XLA HLO computation can be extracted using methods like
``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
wrapped function returns a pair where the first element is the XLA
Computation and the second element is a pytree representing the structure,
shapes, dtypes, and named shapes of the output of ``fun``.
Concrete example arguments are not always necessary. For those arguments not
indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
attributes is acceptable (excepting namedtuples, which are treated as Python
containers).
For example:
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule xla_computation_f.6
<BLANKLINE>
ENTRY xla_computation_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = f32[] parameter(0)
cosine.3 = f32[] cosine(parameter.1)
sine.4 = f32[] sine(cosine.3)
ROOT tuple.5 = (f32[]) tuple(sine.4)
}
<BLANKLINE>
<BLANKLINE>
Alternatively, the assignment to ``c`` above could be written:
>>> import types
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP
Here's an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule jaxpr_computation.9
primitive_computation.3 {
parameter.4 = s32[] parameter(0)
parameter.5 = s32[] parameter(1)
ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
tuple.1 = () tuple()
parameter.2 = s32[] parameter(0)
all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}
<BLANKLINE>
<BLANKLINE>
Notice the ``replica_groups`` that were generated. Here's an example that
generates more interesting ``replica_groups``:
>>> from jax import lax
>>> def g(x):
... rowsum = lax.psum(x, 'i')
... colsum = lax.psum(x, 'j')
... allsum = lax.psum(x, ('i', 'j'))
... return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
>>> print(c.as_hlo_text()) # doctest: +SKIP
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
tuple.1 = () tuple()
parameter.2 = f32[] parameter(0)
all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
.. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
"""
if instantiate_const_outputs is not None:
raise ValueError(
"instantiate_const_outputs has been deprecated. Please use the ahead of"
" time APIs. You can read more here:"
" https://jax.readthedocs.io/en/latest/aot.html")
if in_parts is not None:
raise ValueError(
"in_parts has been deprecated. Please use the ahead of time APIs. You"
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
if out_parts is not None:
raise ValueError(
"out_parts has been deprecated. Please use the ahead of time APIs. You"
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
check_callable(fun)
static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
fun_name = getattr(fun, "__name__", "unknown")
platform = backend if backend is not None else xb.get_backend().platform
def make_axis_env(nreps):
if axis_env is None:
return sharding_impls.AxisEnv(nreps, (), ())
else:
nreps = nreps * math.prod(size for name, size in axis_env)
names, sizes = unzip2(axis_env)
return sharding_impls.AxisEnv(nreps, names, sizes)
@wraps(fun)
@api_boundary
def computation_maker(*args, **kwargs):
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun)
f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
if donate_argnums:
donated_invars = donation_vector(donate_argnums, (), in_tree)
else:
donated_invars = (False,) * len(args_flat)
jaxtree_fun, out_tree = flatten_fun(f, in_tree)
avals = map(shaped_abstractify, args_flat)
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
if axis_env:
jaxpr = core.remove_named_axis_effects(
jaxpr, {axis_name for axis_name, _ in axis_env}
)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
ordered_effects = list(
effects.ordered_effects.filter_in(jaxpr.effects))
lowering_result = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
ordered_effects=ordered_effects,
backend_or_name=backend,
platforms=[platform],
axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
name_stack=source_info_util.new_name_stack(
wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,
arg_shardings=None,
result_shardings=None,
lowering_parameters=mlir.LoweringParameters())
m = mlir.module_to_bytecode(lowering_result.module)
built = xc._xla.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, ShapedArray):
raise RuntimeError("As we want to propagate the weak_type, we need "
"to get a ShapedArray, otherwise this "
"information is lost")
if return_shape:
return built, out_shape
else:
return built
return computation_maker
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,

View File

@ -50,7 +50,6 @@ from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -60,7 +59,6 @@ from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
@ -2904,74 +2902,6 @@ class APITest(jtu.JaxTestCase):
r"sub-dtype of np.floating\), but got complex.*"),
lambda: dfn(3. + 1j))
def test_xla_computation(self):
# these tests basically check the examples in the xla_computation docstring
def e(x):
return jnp.sin(jnp.cos(x))
c = api.xla_computation(e)(2.)
self.assertIn('cosine', c.as_hlo_text())
self.assertIn('sine', c.as_hlo_text())
def f(x):
return x - lax.psum(x, 'i')
axis_env = [('i', 4)]
c = api.xla_computation(f, axis_env=axis_env)(2)
self.assertIn('all-reduce', c.as_hlo_text())
self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())
def g(x):
rowsum = lax.psum(x, 'i')
colsum = lax.psum(x, 'j')
allsum = lax.psum(x, ('i', 'j'))
return rowsum, colsum, allsum
axis_env = [('i', 4), ('j', 2)]
c = api.xla_computation(g, axis_env=axis_env)(5.)
self.assertIn('all-reduce', c.as_hlo_text())
self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())
def h(x):
rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
colsum = lax.psum(x, 'j')
return rowsum, colsum
axis_env = [('i', 4), ('j', 2)]
c = api.xla_computation(h, axis_env=axis_env)(5.)
self.assertIn('all-reduce', c.as_hlo_text())
self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
def test_xla_computation_args(self):
def foo(x, y, z):
return x + y + z
c = api.xla_computation(foo)(1., 2., 3.)
self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
param_shapes = c.program_shape().parameter_shapes()
self.assertEqual(len(param_shapes), 1)
self.assertEqual(param_shapes[0].xla_element_type(),
xla_client.PrimitiveType.TUPLE)
def test_xla_computation_duck_typing(self):
def foo(x, y, z):
return x + y + z
x = jax.ShapeDtypeStruct((), np.float32)
y = jax.ShapeDtypeStruct((), np.float32)
z = jax.ShapeDtypeStruct((), np.float32)
c = api.xla_computation(foo)(x, y, z)
self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
param_shapes = c.program_shape().parameter_shapes()
self.assertEqual(len(param_shapes), 1)
self.assertEqual(param_shapes[0].xla_element_type(),
xla_client.PrimitiveType.TUPLE)
def test_compiler_ir(self):
# TODO(phawkins): merge these tests with the `xla_computation` tests.
def e(x):
@ -2983,72 +2913,6 @@ class APITest(jtu.JaxTestCase):
self.assertIn("stablehlo.cosine", stablehlo)
self.assertIn("stablehlo.sine", stablehlo)
def test_staging_out_multi_replica(self):
def f(x):
return api.pmap(jnp.mean)(x)
xla_comp = api.xla_computation(f)
xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash
def test_xla_computation_instantiate_constant_outputs(self):
def f():
return jnp.zeros((3, 4))
xla_comp = api.xla_computation(f)()
out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
self.assertEqual(out_shape.dimensions(), (3, 4))
def test_xla_computation_static_argnums(self):
def f(x, y):
return x + y
xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
hlo_text = xla_comp.as_hlo_text()
self.assertIn("constant(3)", hlo_text)
# The static arguments should be removed from the function being compiled,
# thus the function should have only a single argument.
self.assertIn("parameter(0)", hlo_text)
self.assertNotIn("parameter(1)", hlo_text)
def test_xla_computation_return_shape(self):
_, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
return_shape=True)(np.int32(1))
expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
self.assertEqual(shape_tree, expected)
def test_xla_computation_psum_constant(self):
f = lambda: jax.lax.psum(1, "i")
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
@jtu.ignore_warning(message="Some donated buffers were not usable")
def test_xla_computation_donate_argnums(self):
api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash
def test_xla_computation_lower_fun_axis_env(self):
axis_name = 'i'
def fn(x):
y = lax.all_gather(
x, axis_name=axis_name)
return y * lax.axis_index(axis_name).astype(jnp.float32)
input_x = jnp.ones((5,6,4), dtype=jnp.float32)
axis_env = [(axis_name, jax.local_device_count())]
_ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
def test_xla_computation_axis_env(self):
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
def fn(x):
z = x * jax.lax.axis_index('i').astype(jnp.float32)
def inner_fn(carry, a):
return carry + a, ()
return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
x = jnp.ones((5, 6, 4), dtype=jnp.float32)
_ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
def test_concurrent_device_get_and_put(self):
def f(x):
for _ in range(100):
@ -3678,7 +3542,7 @@ class APITest(jtu.JaxTestCase):
return x + y + y
x = np.array([1, 2], dtype=np.float32)
hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n')
hlo_lines = {s.strip() for s in hlo_lines}
self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
@ -3805,11 +3669,6 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
g(1)
def test_xla_computation_zeros_doesnt_device_put(self):
with jtu.count_device_put() as count:
api.xla_computation(lambda: jnp.zeros(3))()
self.assertEqual(count[0], 0)
def test_join_concrete_arrays_with_omnistaging(self):
# https://github.com/google/jax/issues/4622
x = jnp.array([1., 2., 3.])
@ -5532,13 +5391,12 @@ class RematTest(jtu.JaxTestCase):
x, _ = g(x)
return x
c = api.xla_computation(f)(2.)
self.assertNotIn('while', c.as_hlo_text())
self.assertNotIn('conditional', c.as_hlo_text())
self.assertNotIn('opt-barrier', c.as_hlo_text())
text = jax.jit(f).lower(2.).as_text('hlo')
self.assertNotIn('while', text)
self.assertNotIn('conditional', text)
self.assertNotIn('opt-barrier', text)
c = api.xla_computation(grad(f))(2.)
text = c.as_hlo_text()
text = jax.jit(grad(f)).lower(2.).as_text('hlo')
self.assertTrue('while' in text or 'conditional' in text
or 'opt-barrier' in text)
@ -5557,13 +5415,13 @@ class RematTest(jtu.JaxTestCase):
x, _ = g(x)
return x
c = api.xla_computation(f)(2.)
self.assertNotIn('while', c.as_hlo_text())
self.assertNotIn('conditional', c.as_hlo_text())
text = jax.jit(f).lower(2.).as_text('hlo')
self.assertNotIn('while', text)
self.assertNotIn('conditional', text)
c = api.xla_computation(grad(f))(2.)
self.assertNotIn('while', c.as_hlo_text())
self.assertNotIn('conditional', c.as_hlo_text())
text = jax.jit(grad(f)).lower(2.).as_text('hlo')
self.assertNotIn('while', text)
self.assertNotIn('conditional', text)
@parameterized.named_parameters(
{"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,
@ -6679,7 +6537,7 @@ class JaxprTest(jtu.JaxTestCase):
self.assertLen(jaxpr.jaxpr.eqns, 0)
def test_convert_element_type_literal_constant_folding(self):
# this convert_elemnt_type is nontrivial, but because it's on a scalar we
# this convert_element_type is nontrivial, but because it's on a scalar we
# constant-fold it
cet = partial(lax.convert_element_type, new_dtype='float16')
jaxpr = api.make_jaxpr(lambda: cet(3.))()
@ -10966,25 +10824,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase):
class NamedCallTest(jtu.JaxTestCase):
@jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
def test_default_name(self):
is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
@api.named_call
def my_test_function(x):
return x**2
@jax.jit
def f(x):
return my_test_function(x)
c = xla_computation(f)(2)
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True
hlo_text = c.as_hlo_module().to_string(print_opts)
self.assertIn("my_test_function", hlo_text)
def test_non_jaxtype_arg(self):
# For the test to fail without the invalid JaxType filter we need to pass
# in a valid JaxType that forces the invalid Jaxtype to be raised to an