Delete jax.xla_computation; it has been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
This commit is contained in:
parent 985e74f2f3
commit de9b98e0a8

CHANGELOG.md (12 changed lines)
@@ -12,6 +12,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.33

* Deletion:
  * `jax.xla_computation` is deleted. It has been 3 months since its deprecation
    in the 0.4.30 JAX release.
    Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
      `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    * You can also use the `.out_info` property of `jax.stages.Lowered` to get
      output information (like tree structure, shape and dtype).
    * For cross-backend lowering, you can replace
      `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
      `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.

## jaxlib 0.4.33
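For reference, the replacements listed in the changelog entry above look roughly as follows. This is a minimal sketch: `fn` and the example argument are illustrative, not taken from the commit.

import jax
import jax.numpy as jnp

def fn(x):
  return jnp.sin(jnp.cos(x))

# Old (removed): c = jax.xla_computation(fn)(3.0); c.as_hlo_text()
lowered = jax.jit(fn).lower(3.0)
print(lowered.compiler_ir('hlo').as_hlo_text())  # unoptimized HLO text
print(lowered.out_info)                          # output tree structure, shapes, dtypes

# Cross-backend lowering, e.g. targeting TPU from a CPU host
# (old: jax.xla_computation(fn, backend='tpu')(3.0)):
tpu_hlo = jax.jit(fn).trace(3.0).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')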
@@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`)

    jit
    disable_jit
    ensure_compile_time_eval
    xla_computation
    make_jaxpr
    eval_shape
    ShapeDtypeStruct
@@ -127,7 +127,6 @@ from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as _deprecated_xla_computation
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.sharding_impls import make_mesh as make_mesh

@@ -224,20 +223,18 @@ _deprecations = {
        "jax.clear_backends is deprecated.",
        _deprecated_clear_backends
    ),
    # Added Jun 16, 2024
    # Remove after jax 0.4.35 release.
    "xla_computation": (
        "jax.xla_computation is deprecated. Please use the AOT APIs; see "
        "jax.xla_computation is deleted. Please use the AOT APIs; see "
        "https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
        "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
        "CHANGELOG.md for 0.4.30 for more examples.",
        _deprecated_xla_computation
        "CHANGELOG.md for 0.4.30 for more examples.", None
    ),
}

import typing as _typing
if _typing.TYPE_CHECKING:
  from jax._src.api import clear_backends as clear_backends
  from jax._src.api import xla_computation as xla_computation
  from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
  from jax._src.tree_util import tree_flatten as tree_flatten
  from jax._src.tree_util import tree_leaves as tree_leaves
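For context, the `_deprecations` table edited above is consulted by a module-level `__getattr__`. The sketch below illustrates that general pattern; it is not JAX's actual `jax._src.deprecations` machinery. An entry whose second element is `None`, as after this commit, means the attribute is gone entirely, so access raises instead of warning.

import warnings

_deprecations = {
    # name: (message, replacement or None)
    "xla_computation": (
        "jax.xla_computation is deleted. Please use the AOT APIs; see "
        "https://jax.readthedocs.io/en/latest/aot.html.",
        None,  # None: nothing left to return, so attribute access raises.
    ),
}

def __getattr__(name):
  if name in _deprecations:
    message, fn = _deprecations[name]
    if fn is None:  # fully removed attribute
      raise AttributeError(message)
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return fn
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")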
jax/_src/api.py (245 changed lines)
@@ -46,7 +46,6 @@ from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import basearray
from jax._src import distributed
@@ -60,7 +59,7 @@ from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
from jax._src.api_util import (
    flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
    argnums_partial_except, flatten_axes, donation_vector,
    flatten_axes, donation_vector,
    rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
    shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
    result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
@@ -73,13 +72,11 @@ from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps,
                           split_list)
from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list
from jax._src import util

from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@@ -337,244 +334,6 @@ def disable_jit(disable: bool = True):
    yield


def xla_computation(fun: Callable,
                    static_argnums: int | Iterable[int] = (),
                    axis_env: Sequence[tuple[AxisName, int]] | None = None,
                    in_parts=None, out_parts=None,
                    backend: str | None = None,
                    tuple_args: bool = False,
                    instantiate_const_outputs: bool | None = None,
                    return_shape: bool = False,
                    donate_argnums: int | Iterable[int] = ()) -> Callable:
  """Creates a function that produces its XLA computation given example args.

  .. warning::

    This function is deprecated as of JAX v0.4.30, and will be removed in a future
    JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
    example, ``jax.xla_computation(fn)(*args)`` can be replaced with
    ``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
    See the `JAX 0.4.30 Change log`_ for more examples.

  Args:
    fun: Function from which to form XLA computations.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`. See the examples below.
    in_parts: Optional, how each argument to ``fun`` should be partitioned or
      replicated. This is used to specify partitioned XLA computations, see
      ``sharded_jit`` for more info.
    out_parts: Optional, how each output of ``fun`` should be partitioned or
      replicated. This is used to specify partitioned XLA computations, see
      ``sharded_jit`` for more info.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
      XLA computation will have a single tuple argument that is unpacked into
      the specified function arguments. If ``None``, tupling will be enabled when
      there are more than 100 arguments, since some platforms have limits on
      argument arity.
    instantiate_const_outputs: Deprecated argument, does nothing.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the XLA
      computation and the second element is a pytree with the same structure as
      the output of ``fun`` and where the leaves are objects with ``shape`` and
      ``dtype`` attributes representing the corresponding types of the output
      leaves.
    donate_argnums: Specify which arguments are "donated" to the computation.
      It is safe to donate arguments if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation; JAX will raise
      an error if you try to.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a built XLA Computation (see xla_client.py), from which representations of
    the unoptimized XLA HLO computation can be extracted using methods like
    ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
    ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
    wrapped function returns a pair where the first element is the XLA
    Computation and the second element is a pytree representing the structure,
    shapes, dtypes, and named shapes of the output of ``fun``.

  Concrete example arguments are not always necessary. For those arguments not
  indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
  attributes is acceptable (excepting namedtuples, which are treated as Python
  containers).

  For example:

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> c = jax.xla_computation(f)(3.) # doctest: +SKIP
  >>> print(c.as_hlo_text()) # doctest: +SKIP
  HloModule xla_computation_f.6
  <BLANKLINE>
  ENTRY xla_computation_f.6 {
    constant.2 = pred[] constant(false)
    parameter.1 = f32[] parameter(0)
    cosine.3 = f32[] cosine(parameter.1)
    sine.4 = f32[] sine(cosine.3)
    ROOT tuple.5 = (f32[]) tuple(sine.4)
  }
  <BLANKLINE>
  <BLANKLINE>


  Alternatively, the assignment to ``c`` above could be written:

  >>> import types
  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
  >>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP


  Here's an example that involves a parallel collective and axis name:

  >>> def f(x): return x - jax.lax.psum(x, 'i')
  >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP
  >>> print(c.as_hlo_text()) # doctest: +SKIP
  HloModule jaxpr_computation.9
  primitive_computation.3 {
    parameter.4 = s32[] parameter(0)
    parameter.5 = s32[] parameter(1)
    ROOT add.6 = s32[] add(parameter.4, parameter.5)
  }
  ENTRY jaxpr_computation.9 {
    tuple.1 = () tuple()
    parameter.2 = s32[] parameter(0)
    all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
    ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
  }
  <BLANKLINE>
  <BLANKLINE>

  Notice the ``replica_groups`` that were generated. Here's an example that
  generates more interesting ``replica_groups``:

  >>> from jax import lax
  >>> def g(x):
  ...   rowsum = lax.psum(x, 'i')
  ...   colsum = lax.psum(x, 'j')
  ...   allsum = lax.psum(x, ('i', 'j'))
  ...   return rowsum, colsum, allsum
  ...
  >>> axis_env = [('i', 4), ('j', 2)]
  >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) # doctest: +SKIP
  >>> print(c.as_hlo_text()) # doctest: +SKIP
  HloModule jaxpr_computation__1.19
  [removed uninteresting text here]
  ENTRY jaxpr_computation__1.19 {
    tuple.1 = () tuple()
    parameter.2 = f32[] parameter(0)
    all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
    all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
    all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
    ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
  }

  .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
"""
|
||||
if instantiate_const_outputs is not None:
|
||||
raise ValueError(
|
||||
"instantiate_const_outputs has been deprecated. Please use the ahead of"
|
||||
" time APIs. You can read more here:"
|
||||
" https://jax.readthedocs.io/en/latest/aot.html")
|
||||
if in_parts is not None:
|
||||
raise ValueError(
|
||||
"in_parts has been deprecated. Please use the ahead of time APIs. You"
|
||||
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
|
||||
if out_parts is not None:
|
||||
raise ValueError(
|
||||
"out_parts has been deprecated. Please use the ahead of time APIs. You"
|
||||
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
|
||||
|
||||
check_callable(fun)
|
||||
static_argnums = _ensure_index_tuple(static_argnums)
|
||||
donate_argnums = _ensure_index_tuple(donate_argnums)
|
||||
donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
|
||||
|
||||
fun_name = getattr(fun, "__name__", "unknown")
|
||||
|
||||
platform = backend if backend is not None else xb.get_backend().platform
|
||||
|
||||
def make_axis_env(nreps):
|
||||
if axis_env is None:
|
||||
return sharding_impls.AxisEnv(nreps, (), ())
|
||||
else:
|
||||
nreps = nreps * math.prod(size for name, size in axis_env)
|
||||
names, sizes = unzip2(axis_env)
|
||||
return sharding_impls.AxisEnv(nreps, names, sizes)
|
||||
|
||||
  @wraps(fun)
  @api_boundary
  def computation_maker(*args, **kwargs):
    if max(static_argnums + donate_argnums, default=-1) >= len(args):
      raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but "
                       f"was called with only {len(args)} positional arguments.")

    f = lu.wrap_init(fun)
    f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
    args_flat, in_tree = tree_flatten((dyn_args, kwargs))
    if donate_argnums:
      donated_invars = donation_vector(donate_argnums, (), in_tree)
    else:
      donated_invars = (False,) * len(args_flat)

    jaxtree_fun, out_tree = flatten_fun(f, in_tree)
    avals = map(shaped_abstractify, args_flat)
    with ExitStack() as stack:
      for axis_name, size in axis_env or []:
        stack.enter_context(core.extend_axis_env(axis_name, size, None))
      jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
      jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
      if axis_env:
        jaxpr = core.remove_named_axis_effects(
            jaxpr, {axis_name for axis_name, _ in axis_env}
        )
      axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
      ordered_effects = list(
          effects.ordered_effects.filter_in(jaxpr.effects))
      lowering_result = mlir.lower_jaxpr_to_module(
          f"xla_computation_{fun_name}",
          core.ClosedJaxpr(jaxpr, consts),
          ordered_effects=ordered_effects,
          backend_or_name=backend,
          platforms=[platform],
          axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
          name_stack=source_info_util.new_name_stack(
              wrap_name(fun_name, "xla_computation")),
          donated_args=donated_invars,
          arg_shardings=None,
          result_shardings=None,
          lowering_parameters=mlir.LoweringParameters())

    m = mlir.module_to_bytecode(lowering_result.module)
    built = xc._xla.mlir.mlir_module_to_xla_computation(
        m, use_tuple_args=tuple_args, return_tuple=True)
    out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
    out_shape = tree_unflatten(out_tree(), out_shapes_flat)
    for out_aval in out_avals:
      if not isinstance(out_aval, ShapedArray):
        raise RuntimeError("As we want to propagate the weak_type, we need "
                           "to get a ShapedArray, otherwise this "
                           "information is lost")

    if return_shape:
      return built, out_shape
    else:
      return built

  return computation_maker

def grad(fun: Callable, argnums: int | Sequence[int] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
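The deleted `return_shape` branch above built a pytree of `ShapeDtypeStruct` leaves describing the outputs. A minimal sketch of getting the same information through the public APIs, with illustrative names `fn` and `x`:

import jax
import jax.numpy as jnp
import numpy as np

def fn(x):
  return x + 1, jnp.zeros(2, jnp.float32)

x = np.int32(1)

# Shape/dtype pytree, analogous to xla_computation(fn, return_shape=True)(x)[1]:
out_shape = jax.eval_shape(fn, x)

# Or take it from a lowering, which also gives access to the HLO:
lowered = jax.jit(fn).lower(x)
print(lowered.out_info)                          # pytree of shape/dtype information
print(lowered.compiler_ir('hlo').as_hlo_text())  # unoptimized HLO text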
@@ -50,7 +50,6 @@ from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge

@@ -60,7 +59,6 @@ from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
@@ -2904,74 +2902,6 @@ class APITest(jtu.JaxTestCase):
                        r"sub-dtype of np.floating\), but got complex.*"),
        lambda: dfn(3. + 1j))

  def test_xla_computation(self):
    # these tests basically check the examples in the xla_computation docstring

    def e(x):
      return jnp.sin(jnp.cos(x))
    c = api.xla_computation(e)(2.)
    self.assertIn('cosine', c.as_hlo_text())
    self.assertIn('sine', c.as_hlo_text())

    def f(x):
      return x - lax.psum(x, 'i')
    axis_env = [('i', 4)]
    c = api.xla_computation(f, axis_env=axis_env)(2)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())

    def g(x):
      rowsum = lax.psum(x, 'i')
      colsum = lax.psum(x, 'j')
      allsum = lax.psum(x, ('i', 'j'))
      return rowsum, colsum, allsum
    axis_env = [('i', 4), ('j', 2)]
    c = api.xla_computation(g, axis_env=axis_env)(5.)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())

    def h(x):
      rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
      colsum = lax.psum(x, 'j')
      return rowsum, colsum
    axis_env = [('i', 4), ('j', 2)]
    c = api.xla_computation(h, axis_env=axis_env)(5.)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())

  def test_xla_computation_args(self):
    def foo(x, y, z):
      return x + y + z

    c = api.xla_computation(foo)(1., 2., 3.)
    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.program_shape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xla_client.PrimitiveType.TUPLE)

  def test_xla_computation_duck_typing(self):
    def foo(x, y, z):
      return x + y + z

    x = jax.ShapeDtypeStruct((), np.float32)
    y = jax.ShapeDtypeStruct((), np.float32)
    z = jax.ShapeDtypeStruct((), np.float32)

    c = api.xla_computation(foo)(x, y, z)
    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.program_shape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xla_client.PrimitiveType.TUPLE)

  def test_compiler_ir(self):
    # TODO(phawkins): merge these tests with the `xla_computation` tests.
    def e(x):
@@ -2983,72 +2913,6 @@ class APITest(jtu.JaxTestCase):
    self.assertIn("stablehlo.cosine", stablehlo)
    self.assertIn("stablehlo.sine", stablehlo)

  def test_staging_out_multi_replica(self):
    def f(x):
      return api.pmap(jnp.mean)(x)
    xla_comp = api.xla_computation(f)
    xla_comp(jnp.arange(8)).as_hlo_text()  # doesn't crash

  def test_xla_computation_instantiate_constant_outputs(self):
    def f():
      return jnp.zeros((3, 4))

    xla_comp = api.xla_computation(f)()
    out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
    self.assertEqual(out_shape.dimensions(), (3, 4))

  def test_xla_computation_static_argnums(self):
    def f(x, y):
      return x + y

    xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
    hlo_text = xla_comp.as_hlo_text()
    self.assertIn("constant(3)", hlo_text)
    # The static arguments should be removed from the function being compiled,
    # thus the function should have only a single argument.
    self.assertIn("parameter(0)", hlo_text)
    self.assertNotIn("parameter(1)", hlo_text)

  def test_xla_computation_return_shape(self):
    _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
                                        return_shape=True)(np.int32(1))
    expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
                api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
    self.assertEqual(shape_tree, expected)

  def test_xla_computation_psum_constant(self):
    f = lambda: jax.lax.psum(1, "i")
    api.xla_computation(f, axis_env=[("i", 2)])()  # doesn't crash

  @jtu.ignore_warning(message="Some donated buffers were not usable")
  def test_xla_computation_donate_argnums(self):
    api.xla_computation(lambda x: None, donate_argnums=(0,))(3)  # doesn't crash

  def test_xla_computation_lower_fun_axis_env(self):
    axis_name = 'i'
    def fn(x):
      y = lax.all_gather(
          x, axis_name=axis_name)
      return y * lax.axis_index(axis_name).astype(jnp.float32)

    input_x = jnp.ones((5,6,4), dtype=jnp.float32)
    axis_env = [(axis_name, jax.local_device_count())]
    _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)

  @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
  def test_xla_computation_axis_env(self):
    is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
    xla_computation = api.xla_computation if is_accelerated else jax.xla_computation

    def fn(x):
      z = x * jax.lax.axis_index('i').astype(jnp.float32)
      def inner_fn(carry, a):
        return carry + a, ()
      return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)

    x = jnp.ones((5, 6, 4), dtype=jnp.float32)
    _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)

  def test_concurrent_device_get_and_put(self):
    def f(x):
      for _ in range(100):
@@ -3678,7 +3542,7 @@ class APITest(jtu.JaxTestCase):
      return x + y + y

    x = np.array([1, 2], dtype=np.float32)
    hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
    hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n')
    hlo_lines = {s.strip() for s in hlo_lines}
    self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
    self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
@@ -3805,11 +3669,6 @@ class APITest(jtu.JaxTestCase):
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      g(1)

  def test_xla_computation_zeros_doesnt_device_put(self):
    with jtu.count_device_put() as count:
      api.xla_computation(lambda: jnp.zeros(3))()
    self.assertEqual(count[0], 0)

  def test_join_concrete_arrays_with_omnistaging(self):
    # https://github.com/google/jax/issues/4622
    x = jnp.array([1., 2., 3.])
@@ -5532,13 +5391,12 @@ class RematTest(jtu.JaxTestCase):
        x, _ = g(x)
      return x

    c = api.xla_computation(f)(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())
    self.assertNotIn('opt-barrier', c.as_hlo_text())
    text = jax.jit(f).lower(2.).as_text('hlo')
    self.assertNotIn('while', text)
    self.assertNotIn('conditional', text)
    self.assertNotIn('opt-barrier', text)

    c = api.xla_computation(grad(f))(2.)
    text = c.as_hlo_text()
    text = jax.jit(grad(f)).lower(2.).as_text('hlo')
    self.assertTrue('while' in text or 'conditional' in text
                    or 'opt-barrier' in text)
@@ -5557,13 +5415,13 @@ class RematTest(jtu.JaxTestCase):
        x, _ = g(x)
      return x

    c = api.xla_computation(f)(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())
    text = jax.jit(f).lower(2.).as_text('hlo')
    self.assertNotIn('while', text)
    self.assertNotIn('conditional', text)

    c = api.xla_computation(grad(f))(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())
    text = jax.jit(grad(f)).lower(2.).as_text('hlo')
    self.assertNotIn('while', text)
    self.assertNotIn('conditional', text)

  @parameterized.named_parameters(
      {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,
@@ -6679,7 +6537,7 @@ class JaxprTest(jtu.JaxTestCase):
    self.assertLen(jaxpr.jaxpr.eqns, 0)

  def test_convert_element_type_literal_constant_folding(self):
    # this convert_elemnt_type is nontrivial, but because it's on a scalar we
    # this convert_element_type is nontrivial, but because it's on a scalar we
    # constant-fold it
    cet = partial(lax.convert_element_type, new_dtype='float16')
    jaxpr = api.make_jaxpr(lambda: cet(3.))()
@@ -10966,25 +10824,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase):

class NamedCallTest(jtu.JaxTestCase):

  @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
  def test_default_name(self):
    is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
    xla_computation = api.xla_computation if is_accelerated else jax.xla_computation

    @api.named_call
    def my_test_function(x):
      return x**2

    @jax.jit
    def f(x):
      return my_test_function(x)

    c = xla_computation(f)(2)
    print_opts = xla_client._xla.HloPrintOptions.short_parsable()
    print_opts.print_metadata = True
    hlo_text = c.as_hlo_module().to_string(print_opts)
    self.assertIn("my_test_function", hlo_text)

  def test_non_jaxtype_arg(self):
    # For the test to fail without the invalid JaxType filter we need to pass
    # in a valid JaxType that forces the invalid Jaxtype to be raised to an
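The deleted `NamedCallTest.test_default_name` above checked that HLO op metadata records the named-call function name. A hedged sketch of the same check through the AOT path, assuming `jax.named_call` and the `HloPrintOptions` helpers used by the old test are still available; this is not code from the commit.

import jax
from jax._src.lib import xla_client

@jax.named_call
def my_test_function(x):
  return x ** 2

@jax.jit
def f(x):
  return my_test_function(x)

hlo = f.lower(2).compiler_ir('hlo')  # XlaComputation-like object for the jitted f
print_opts = xla_client._xla.HloPrintOptions.short_parsable()
print_opts.print_metadata = True     # keep op metadata (names, source info)
hlo_text = hlo.as_hlo_module().to_string(print_opts)
assert "my_test_function" in hlo_text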