Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00

Delete jax.xla_computation since it's been 3 months since it was deprecated.

PiperOrigin-RevId: 673938336

This commit is contained in:
  parent 985e74f2f3
  commit de9b98e0a8
CHANGELOG.md (12 changes)

@@ -12,6 +12,18 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

 ## jax 0.4.33

+* Deletion:
+  * `jax.xla_computation` is deleted. It's been 3 months since its deprecation
+    in the 0.4.30 JAX release.
+    Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
+    * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
+      `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
+    * You can also use the `.out_info` property of `jax.stages.Lowered` to get the
+      output information (like tree structure, shape and dtype).
+    * For cross-backend lowering, you can replace
+      `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
+      `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
+
 ## jaxlib 0.4.33

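As a worked sketch of the migration this changelog entry describes (the function `fn` and argument `x` below are illustrative, not taken from the commit), the old call pattern maps roughly onto the AOT APIs as follows:

import jax
import jax.numpy as jnp

def fn(x):
  return jnp.sin(jnp.cos(x))

x = jnp.float32(3.0)

# Old (deleted in this commit):
#   c = jax.xla_computation(fn)(x)
#   print(c.as_hlo_text())

# New: lower through jax.jit, then ask for the HLO representation.
lowered = jax.jit(fn).lower(x)        # a jax.stages.Lowered object
print(lowered.as_text('hlo'))         # HLO text, analogous to as_hlo_text()
hlo = lowered.compiler_ir('hlo')      # unoptimized HLO computation object

# Output tree structure, shapes and dtypes (roughly what return_shape=True
# used to report):
print(lowered.out_info)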

@@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`)

     jit
     disable_jit
     ensure_compile_time_eval
-    xla_computation
     make_jaxpr
     eval_shape
     ShapeDtypeStruct

@@ -127,7 +127,6 @@ from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
 from jax._src.api import value_and_grad as value_and_grad
 from jax._src.api import vjp as vjp
 from jax._src.api import vmap as vmap
-from jax._src.api import xla_computation as _deprecated_xla_computation
 from jax._src.sharding_impls import NamedSharding as NamedSharding
 from jax._src.sharding_impls import make_mesh as make_mesh

@@ -224,20 +223,18 @@ _deprecations = {
         "jax.clear_backends is deprecated.",
         _deprecated_clear_backends
     ),
-    # Added Jun 16, 2024
+    # Remove after jax 0.4.35 release.
     "xla_computation": (
-        "jax.xla_computation is deprecated. Please use the AOT APIs; see "
+        "jax.xla_computation is deleted. Please use the AOT APIs; see "
         "https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
         "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
-        "CHANGELOG.md for 0.4.30 for more examples.",
-        _deprecated_xla_computation
+        "CHANGELOG.md for 0.4.30 for more examples.", None
     ),
 }

 import typing as _typing
 if _typing.TYPE_CHECKING:
   from jax._src.api import clear_backends as clear_backends
-  from jax._src.api import xla_computation as xla_computation
   from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
   from jax._src.tree_util import tree_flatten as tree_flatten
   from jax._src.tree_util import tree_leaves as tree_leaves

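For context on the `_deprecations` table edited above: the registry maps an attribute name to a (message, value) pair, and attribute access for those names is intercepted at the module level; changing the value to `None`, as this commit does for `xla_computation`, turns the soft deprecation warning into a hard error. A minimal, self-contained sketch of that general pattern (an illustration only, not JAX's actual implementation):

import warnings

_deprecations = {
    "xla_computation": (
        "jax.xla_computation is deleted. Please use the AOT APIs; see "
        "https://jax.readthedocs.io/en/latest/aot.html.",
        None,  # no fallback object: access now fails instead of warning
    ),
}

def __getattr__(name):
  # Module-level __getattr__ (PEP 562): called for names not found by
  # normal lookup, which is how deprecated attributes are intercepted.
  if name in _deprecations:
    message, value = _deprecations[name]
    if value is None:
      raise AttributeError(message)
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return value
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")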
jax/_src/api.py (245 changes)

@@ -46,7 +46,6 @@ from jax._src import api_util
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
-from jax._src import effects
 from jax._src import array
 from jax._src import basearray
 from jax._src import distributed

@@ -60,7 +59,7 @@ from jax._src import xla_bridge as xb
 from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
 from jax._src.api_util import (
     flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
-    argnums_partial_except, flatten_axes, donation_vector,
+    flatten_axes, donation_vector,
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
     result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)

@@ -73,13 +72,11 @@ from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
-from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps,
-                           split_list)
+from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list
 from jax._src import util

 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
-from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla

@@ -337,244 +334,6 @@ def disable_jit(disable: bool = True):
     yield


-def xla_computation(fun: Callable,
-                    static_argnums: int | Iterable[int] = (),
-                    axis_env: Sequence[tuple[AxisName, int]] | None = None,
-                    in_parts=None, out_parts=None,
-                    backend: str | None = None,
-                    tuple_args: bool = False,
-                    instantiate_const_outputs: bool | None = None,
-                    return_shape: bool = False,
-                    donate_argnums: int | Iterable[int] = ()) -> Callable:
-  """Creates a function that produces its XLA computation given example args.
-
-  .. warning::
-
-    This function is deprecated as of JAX v0.4.30, and will be removed in a future
-    JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for
-    example, ``jax.xla_computation(fn)(*args)`` can be replaced with
-    ``jax.jit(fn).lower(*args).compiler_ir('hlo')``.
-    See the `JAX 0.4.30 Change log`_ for more examples.
-
-  Args:
-    fun: Function from which to form XLA computations.
-    static_argnums: See the :py:func:`jax.jit` docstring.
-    axis_env: Optional, a sequence of pairs where the first element is an axis
-      name and the second element is a positive integer representing the size of
-      the mapped axis with that name. This parameter is useful when lowering
-      functions that involve parallel communication collectives, and it
-      specifies the axis name/size environment that would be set up by
-      applications of :py:func:`jax.pmap`. See the examples below.
-    in_parts: Optional, how each argument to ``fun`` should be partitioned or
-      replicated. This is used to specify partitioned XLA computations, see
-      ``sharded_jit`` for more info.
-    out_parts: Optional, how each output of ``fun`` should be partitioned or
-      replicated. This is used to specify partitioned XLA computations, see
-      ``sharded_jit`` for more info.
-    backend: This is an experimental feature and the API is likely to change.
-      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
-      ``'tpu'``.
-    tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
-      XLA computation will have a single tuple argument that is unpacked into
-      the specified function arguments. If `None`, tupling will be enabled when
-      there are more than 100 arguments, since some platforms have limits on
-      argument arity.
-    instantiate_const_outputs: Deprecated argument, does nothing.
-    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
-      wrapped function returns a pair where the first element is the XLA
-      computation and the second element is a pytree with the same structure as
-      the output of ``fun`` and where the leaves are objects with ``shape`` and
-      ``dtype`` attributes representing the corresponding types of the output
-      leaves.
-    donate_argnums: Specify which arguments are "donated" to the computation.
-      It is safe to donate arguments if you no longer need them once the
-      computation has finished. In some cases XLA can make use of donated
-      buffers to reduce the amount of memory needed to perform a computation,
-      for example recycling one of your input buffers to store a result. You
-      should not reuse buffers that you donate to a computation, JAX will raise
-      an error if you try to.
-
-  Returns:
-    A wrapped version of ``fun`` that when applied to example arguments returns
-    a built XLA Computation (see xla_client.py), from which representations of
-    the unoptimized XLA HLO computation can be extracted using methods like
-    ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
-    ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
-    wrapped function returns a pair where the first element is the XLA
-    Computation and the second element is a pytree representing the structure,
-    shapes, dtypes, and named shapes of the output of ``fun``.
-
-  Concrete example arguments are not always necessary. For those arguments not
-  indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
-  attributes is acceptable (excepting namedtuples, which are treated as Python
-  containers).
-
-  For example:
-
-  >>> import jax
-  >>>
-  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
-  >>> c = jax.xla_computation(f)(3.)  # doctest: +SKIP
-  >>> print(c.as_hlo_text())  # doctest: +SKIP
-  HloModule xla_computation_f.6
-  <BLANKLINE>
-  ENTRY xla_computation_f.6 {
-    constant.2 = pred[] constant(false)
-    parameter.1 = f32[] parameter(0)
-    cosine.3 = f32[] cosine(parameter.1)
-    sine.4 = f32[] sine(cosine.3)
-    ROOT tuple.5 = (f32[]) tuple(sine.4)
-  }
-  <BLANKLINE>
-  <BLANKLINE>
-
-
-  Alternatively, the assignment to ``c`` above could be written:
-
-  >>> import types
-  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
-  >>> c = jax.xla_computation(f)(scalar)  # doctest: +SKIP
-
-
-  Here's an example that involves a parallel collective and axis name:
-
-  >>> def f(x): return x - jax.lax.psum(x, 'i')
-  >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)  # doctest: +SKIP
-  >>> print(c.as_hlo_text())  # doctest: +SKIP
-  HloModule jaxpr_computation.9
-  primitive_computation.3 {
-    parameter.4 = s32[] parameter(0)
-    parameter.5 = s32[] parameter(1)
-    ROOT add.6 = s32[] add(parameter.4, parameter.5)
-  }
-  ENTRY jaxpr_computation.9 {
-    tuple.1 = () tuple()
-    parameter.2 = s32[] parameter(0)
-    all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
-    ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
-  }
-  <BLANKLINE>
-  <BLANKLINE>
-
-  Notice the ``replica_groups`` that were generated. Here's an example that
-  generates more interesting ``replica_groups``:
-
-  >>> from jax import lax
-  >>> def g(x):
-  ...   rowsum = lax.psum(x, 'i')
-  ...   colsum = lax.psum(x, 'j')
-  ...   allsum = lax.psum(x, ('i', 'j'))
-  ...   return rowsum, colsum, allsum
-  ...
-  >>> axis_env = [('i', 4), ('j', 2)]
-  >>> c = jax.xla_computation(g, axis_env=axis_env)(5.)  # doctest: +SKIP
-  >>> print(c.as_hlo_text())  # doctest: +SKIP
-  HloModule jaxpr_computation__1.19
-  [removed uninteresting text here]
-  ENTRY jaxpr_computation__1.19 {
-    tuple.1 = () tuple()
-    parameter.2 = f32[] parameter(0)
-    all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
-    all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
-    all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
-    ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
-  }
-
-  .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024
-  """
-  if instantiate_const_outputs is not None:
-    raise ValueError(
-        "instantiate_const_outputs has been deprecated. Please use the ahead of"
-        " time APIs. You can read more here:"
-        " https://jax.readthedocs.io/en/latest/aot.html")
-  if in_parts is not None:
-    raise ValueError(
-        "in_parts has been deprecated. Please use the ahead of time APIs. You"
-        " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
-  if out_parts is not None:
-    raise ValueError(
-        "out_parts has been deprecated. Please use the ahead of time APIs. You"
-        " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
-
-  check_callable(fun)
-  static_argnums = _ensure_index_tuple(static_argnums)
-  donate_argnums = _ensure_index_tuple(donate_argnums)
-  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)
-
-  fun_name = getattr(fun, "__name__", "unknown")
-
-  platform = backend if backend is not None else xb.get_backend().platform
-
-  def make_axis_env(nreps):
-    if axis_env is None:
-      return sharding_impls.AxisEnv(nreps, (), ())
-    else:
-      nreps = nreps * math.prod(size for name, size in axis_env)
-      names, sizes = unzip2(axis_env)
-      return sharding_impls.AxisEnv(nreps, names, sizes)
-
-  @wraps(fun)
-  @api_boundary
-  def computation_maker(*args, **kwargs):
-    if max(static_argnums + donate_argnums, default=-1) >= len(args):
-      raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but "
-                       f"was called with only {len(args)} positional arguments.")
-
-    f = lu.wrap_init(fun)
-    f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
-    args_flat, in_tree = tree_flatten((dyn_args, kwargs))
-    if donate_argnums:
-      donated_invars = donation_vector(donate_argnums, (), in_tree)
-    else:
-      donated_invars = (False,) * len(args_flat)
-
-    jaxtree_fun, out_tree = flatten_fun(f, in_tree)
-    avals = map(shaped_abstractify, args_flat)
-    with ExitStack() as stack:
-      for axis_name, size in axis_env or []:
-        stack.enter_context(core.extend_axis_env(axis_name, size, None))
-      jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
-      jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
-      if axis_env:
-        jaxpr = core.remove_named_axis_effects(
-            jaxpr, {axis_name for axis_name, _ in axis_env}
-        )
-      axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
-      ordered_effects = list(
-          effects.ordered_effects.filter_in(jaxpr.effects))
-      lowering_result = mlir.lower_jaxpr_to_module(
-          f"xla_computation_{fun_name}",
-          core.ClosedJaxpr(jaxpr, consts),
-          ordered_effects=ordered_effects,
-          backend_or_name=backend,
-          platforms=[platform],
-          axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
-          name_stack=source_info_util.new_name_stack(
-              wrap_name(fun_name, "xla_computation")),
-          donated_args=donated_invars,
-          arg_shardings=None,
-          result_shardings=None,
-          lowering_parameters=mlir.LoweringParameters())
-
-    m = mlir.module_to_bytecode(lowering_result.module)
-    built = xc._xla.mlir.mlir_module_to_xla_computation(
-        m, use_tuple_args=tuple_args, return_tuple=True)
-    out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
-    out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
-    out_shape = tree_unflatten(out_tree(), out_shapes_flat)
-    for out_aval in out_avals:
-      if not isinstance(out_aval, ShapedArray):
-        raise RuntimeError("As we want to propagate the weak_type, we need "
-                           "to get a ShapedArray, otherwise this "
-                           "information is lost")
-
-    if return_shape:
-      return built, out_shape
-    else:
-      return built
-
-  return computation_maker
-
 def grad(fun: Callable, argnums: int | Sequence[int] = 0,
          has_aux: bool = False, holomorphic: bool = False,
          allow_int: bool = False,

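The deleted function's `backend=` argument corresponds to the cross-backend recipe quoted in the changelog entry above. A minimal sketch of that path (`fn` and `x` are illustrative; 'tpu' is simply the platform named in the changelog example):

import jax
import jax.numpy as jnp

def fn(x):
  return jnp.sin(jnp.cos(x))

x = jnp.float32(3.0)

# Old (deleted): jax.xla_computation(fn, backend='tpu')(x)
# New: trace first, then lower for an explicit target platform.
traced = jax.jit(fn).trace(x)
lowered = traced.lower(lowering_platforms=('tpu',))
print(lowered.compiler_ir('hlo'))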

@@ -50,7 +50,6 @@ from jax._src import array
 from jax._src import config
 from jax._src import core
 from jax._src import custom_derivatives
-from jax._src import deprecations
 from jax._src import linear_util as lu
 from jax._src import test_util as jtu
 from jax._src import xla_bridge

@@ -60,7 +59,6 @@ from jax._src.ad_checkpoint import saved_residuals
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.compilation_cache import is_persistent_cache_enabled
-from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
 import jax._src.util as jax_util
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint

@@ -2904,74 +2902,6 @@ class APITest(jtu.JaxTestCase):
             r"sub-dtype of np.floating\), but got complex.*"),
         lambda: dfn(3. + 1j))

-  def test_xla_computation(self):
-    # these tests basically check the examples in the xla_computation docstring
-
-    def e(x):
-      return jnp.sin(jnp.cos(x))
-    c = api.xla_computation(e)(2.)
-    self.assertIn('cosine', c.as_hlo_text())
-    self.assertIn('sine', c.as_hlo_text())
-
-    def f(x):
-      return x - lax.psum(x, 'i')
-    axis_env = [('i', 4)]
-    c = api.xla_computation(f, axis_env=axis_env)(2)
-    self.assertIn('all-reduce', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())
-
-    def g(x):
-      rowsum = lax.psum(x, 'i')
-      colsum = lax.psum(x, 'j')
-      allsum = lax.psum(x, ('i', 'j'))
-      return rowsum, colsum, allsum
-    axis_env = [('i', 4), ('j', 2)]
-    c = api.xla_computation(g, axis_env=axis_env)(5.)
-    self.assertIn('all-reduce', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())
-
-    def h(x):
-      rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
-      colsum = lax.psum(x, 'j')
-      return rowsum, colsum
-    axis_env = [('i', 4), ('j', 2)]
-    c = api.xla_computation(h, axis_env=axis_env)(5.)
-    self.assertIn('all-reduce', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
-    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
-
-  def test_xla_computation_args(self):
-    def foo(x, y, z):
-      return x + y + z
-
-    c = api.xla_computation(foo)(1., 2., 3.)
-    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
-
-    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
-    param_shapes = c.program_shape().parameter_shapes()
-    self.assertEqual(len(param_shapes), 1)
-    self.assertEqual(param_shapes[0].xla_element_type(),
-                     xla_client.PrimitiveType.TUPLE)
-
-  def test_xla_computation_duck_typing(self):
-    def foo(x, y, z):
-      return x + y + z
-
-    x = jax.ShapeDtypeStruct((), np.float32)
-    y = jax.ShapeDtypeStruct((), np.float32)
-    z = jax.ShapeDtypeStruct((), np.float32)
-
-    c = api.xla_computation(foo)(x, y, z)
-    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
-
-    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
-    param_shapes = c.program_shape().parameter_shapes()
-    self.assertEqual(len(param_shapes), 1)
-    self.assertEqual(param_shapes[0].xla_element_type(),
-                     xla_client.PrimitiveType.TUPLE)
-
   def test_compiler_ir(self):
     # TODO(phawkins): merge these tests with the `xla_computation` tests.
     def e(x):

@@ -2983,72 +2913,6 @@ class APITest(jtu.JaxTestCase):
     self.assertIn("stablehlo.cosine", stablehlo)
     self.assertIn("stablehlo.sine", stablehlo)

-  def test_staging_out_multi_replica(self):
-    def f(x):
-      return api.pmap(jnp.mean)(x)
-    xla_comp = api.xla_computation(f)
-    xla_comp(jnp.arange(8)).as_hlo_text()  # doesn't crash
-
-  def test_xla_computation_instantiate_constant_outputs(self):
-    def f():
-      return jnp.zeros((3, 4))
-
-    xla_comp = api.xla_computation(f)()
-    out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
-    self.assertEqual(out_shape.dimensions(), (3, 4))
-
-  def test_xla_computation_static_argnums(self):
-    def f(x, y):
-      return x + y
-
-    xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
-    hlo_text = xla_comp.as_hlo_text()
-    self.assertIn("constant(3)", hlo_text)
-    # The static arguments should be removed from the function being compiled,
-    # thus the function should have only a single argument.
-    self.assertIn("parameter(0)", hlo_text)
-    self.assertNotIn("parameter(1)", hlo_text)
-
-  def test_xla_computation_return_shape(self):
-    _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
-                                        return_shape=True)(np.int32(1))
-    expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
-                api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
-    self.assertEqual(shape_tree, expected)
-
-  def test_xla_computation_psum_constant(self):
-    f = lambda: jax.lax.psum(1, "i")
-    api.xla_computation(f, axis_env=[("i", 2)])()  # doesn't crash
-
-  @jtu.ignore_warning(message="Some donated buffers were not usable")
-  def test_xla_computation_donate_argnums(self):
-    api.xla_computation(lambda x: None, donate_argnums=(0,))(3)  # doesn't crash
-
-  def test_xla_computation_lower_fun_axis_env(self):
-    axis_name = 'i'
-    def fn(x):
-      y = lax.all_gather(
-          x, axis_name=axis_name)
-      return y * lax.axis_index(axis_name).astype(jnp.float32)
-
-    input_x = jnp.ones((5,6,4), dtype=jnp.float32)
-    axis_env = [(axis_name, jax.local_device_count())]
-    _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
-
-  @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
-  def test_xla_computation_axis_env(self):
-    is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
-    xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
-
-    def fn(x):
-      z = x * jax.lax.axis_index('i').astype(jnp.float32)
-      def inner_fn(carry, a):
-        return carry + a, ()
-      return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z)
-
-    x = jnp.ones((5, 6, 4), dtype=jnp.float32)
-    _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x)
-
   def test_concurrent_device_get_and_put(self):
     def f(x):
       for _ in range(100):

@@ -3678,7 +3542,7 @@ class APITest(jtu.JaxTestCase):
       return x + y + y

     x = np.array([1, 2], dtype=np.float32)
-    hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
+    hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n')
     hlo_lines = {s.strip() for s in hlo_lines}
     self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
     self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)

@@ -3805,11 +3669,6 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
       g(1)

-  def test_xla_computation_zeros_doesnt_device_put(self):
-    with jtu.count_device_put() as count:
-      api.xla_computation(lambda: jnp.zeros(3))()
-    self.assertEqual(count[0], 0)
-
   def test_join_concrete_arrays_with_omnistaging(self):
     # https://github.com/google/jax/issues/4622
     x = jnp.array([1., 2., 3.])

@@ -5532,13 +5391,12 @@ class RematTest(jtu.JaxTestCase):
       x, _ = g(x)
       return x

-    c = api.xla_computation(f)(2.)
-    self.assertNotIn('while', c.as_hlo_text())
-    self.assertNotIn('conditional', c.as_hlo_text())
-    self.assertNotIn('opt-barrier', c.as_hlo_text())
+    text = jax.jit(f).lower(2.).as_text('hlo')
+    self.assertNotIn('while', text)
+    self.assertNotIn('conditional', text)
+    self.assertNotIn('opt-barrier', text)

-    c = api.xla_computation(grad(f))(2.)
-    text = c.as_hlo_text()
+    text = jax.jit(grad(f)).lower(2.).as_text('hlo')
     self.assertTrue('while' in text or 'conditional' in text
                     or 'opt-barrier' in text)

@@ -5557,13 +5415,13 @@ class RematTest(jtu.JaxTestCase):
       x, _ = g(x)
      return x

-    c = api.xla_computation(f)(2.)
-    self.assertNotIn('while', c.as_hlo_text())
-    self.assertNotIn('conditional', c.as_hlo_text())
+    text = jax.jit(f).lower(2.).as_text('hlo')
+    self.assertNotIn('while', text)
+    self.assertNotIn('conditional', text)

-    c = api.xla_computation(grad(f))(2.)
-    self.assertNotIn('while', c.as_hlo_text())
-    self.assertNotIn('conditional', c.as_hlo_text())
+    text = jax.jit(grad(f)).lower(2.).as_text('hlo')
+    self.assertNotIn('while', text)
+    self.assertNotIn('conditional', text)

   @parameterized.named_parameters(
       {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,

@@ -6679,7 +6537,7 @@ class JaxprTest(jtu.JaxTestCase):
     self.assertLen(jaxpr.jaxpr.eqns, 0)

   def test_convert_element_type_literal_constant_folding(self):
-    # this convert_elemnt_type is nontrivial, but because it's on a scalar we
+    # this convert_element_type is nontrivial, but because it's on a scalar we
     # constant-fold it
     cet = partial(lax.convert_element_type, new_dtype='float16')
     jaxpr = api.make_jaxpr(lambda: cet(3.))()

@@ -10966,25 +10824,6 @@ class BufferDonationTest(jtu.BufferDonationTestCase):

 class NamedCallTest(jtu.JaxTestCase):

-  @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated')
-  def test_default_name(self):
-    is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation')
-    xla_computation = api.xla_computation if is_accelerated else jax.xla_computation
-
-    @api.named_call
-    def my_test_function(x):
-      return x**2
-
-    @jax.jit
-    def f(x):
-      return my_test_function(x)
-
-    c = xla_computation(f)(2)
-    print_opts = xla_client._xla.HloPrintOptions.short_parsable()
-    print_opts.print_metadata = True
-    hlo_text = c.as_hlo_module().to_string(print_opts)
-    self.assertIn("my_test_function", hlo_text)
-
   def test_non_jaxtype_arg(self):
     # For the test to fail without the invalid JaxType filter we need to pass
     # in a valid JaxType that forces the invalid Jaxtype to be raised to an
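The deleted duck-typing test above exercised `jax.xla_computation` with shape/dtype placeholders rather than concrete arrays; the same idiom carries over to the AOT path through `jax.ShapeDtypeStruct`. A small sketch (names are illustrative):

import numpy as np
import jax

def foo(x, y, z):
  return x + y + z

# Abstract arguments: only shape and dtype are needed to lower.
x = jax.ShapeDtypeStruct((), np.float32)
y = jax.ShapeDtypeStruct((), np.float32)
z = jax.ShapeDtypeStruct((), np.float32)

lowered = jax.jit(foo).lower(x, y, z)
print(lowered.as_text('hlo'))   # HLO text for the lowered computation
print(lowered.out_info)         # shape/dtype of the output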