Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Update minimum jaxlib version to 0.1.69.

This commit is contained in:
parent 6aa20d8f8f
commit 3ddcec27f2
@@ -15,12 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
   * Support for Python 3.6 has been dropped, per the
     [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
     Please upgrade to a supported Python version.
+  * The minimum jaxlib version is now 0.1.69.
+  * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
+    removed.
 
 * Bug fixes:
   * Tightened the checks for lax.argmin and lax.argmax to ensure they are
     not used with invalid `axis` value, or with an empty reduction dimension.
     ({jax-issue}`#7196`)
 
 
 ## jaxlib 0.1.70 (unreleased)
 * Breaking changes:
   * Support for Python 3.6 has been dropped, per the
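
A quick illustration of the tightened argmin/argmax checks mentioned in the changelog above (a sketch, not part of this diff; it assumes a recent jax install):

import jax.numpy as jnp
from jax import lax

x = jnp.array([[1., 3., 2.],
               [4., 0., 5.]])
# Valid: the axis is in range and the reduced dimension is non-empty.
print(lax.argmax(x, axis=1, index_dtype=jnp.int32))   # [1 2]
# Now rejected: an out-of-range axis such as lax.argmax(x, axis=2, index_dtype=jnp.int32),
# or an empty reduction dimension such as lax.argmax(jnp.zeros((0, 3)), axis=0, index_dtype=jnp.int32),
# raises an error instead of returning garbage indices.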
@@ -1,8 +1,6 @@
 flake8
 flatbuffers==1.12
-# For now, we pin the numpy version here
-# TODO(jakevdp): unpin maximum version when minimum jaxlib supports newer numpy
-numpy>=1.17,<1.21
+numpy>=1.17
 mypy==0.902
 pillow>=8.3.1
 pytest-benchmark
@@ -13,6 +13,5 @@ pytest-xdist
 # Packages used for notebook execution
 matplotlib
 scikit-learn
-# TODO(jakevdp) remove numpy pinning when minimum jaxlib supports newer numpy.
-numpy<1.21
+numpy
 .[cpu] # Install jax from the current directory; jaxlib from pypi.
@@ -197,12 +197,6 @@ def _infer_argnums_and_argnames(
   return argnums, argnames
 
 
-# Static kwargs require jaxlib 0.1.66 or newer. (Technically the Python jit
-# implementation works with any jaxlib, but we want consistency between C++ and
-# Python implementations.)
-# TODO(phawkins): remove when jaxlib 0.1.66 is the minimum.
-_ALLOW_STATIC_ARGNAMES = lib._xla_extension_version >= 14
-
 def jit(
     fun: F,
     *,
@@ -285,10 +279,6 @@ def jit(
     [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
    -0.85743 -0.78232 0.76827 0.59566 ]
   """
-  if not _ALLOW_STATIC_ARGNAMES:
-    if static_argnames is not None:
-      raise ValueError("static_argnames requires jaxlib 0.1.66 or newer")
-    static_argnames = ()
   if FLAGS.experimental_cpp_jit:
     return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
                     donate_argnums, inline)
@@ -319,17 +309,10 @@ def _python_jit(
   def f_jitted(*args, **kwargs):
     if config.jax_disable_jit:
       return fun(*args, **kwargs)
-    if _ALLOW_STATIC_ARGNAMES:
-      if max(donate_argnums, default=-1) >= len(args):
-        raise ValueError(
-            f"jitted function has donate_argnums={donate_argnums} but "
-            f"was called with only {len(args)} positional arguments.")
-    else:
-      if max(static_argnums + donate_argnums, default=-1) >= len(args):
-        raise ValueError(
-            f"jitted function has static_argnums={static_argnums}, "
-            f"donate_argnums={donate_argnums} but "
-            f"was called with only {len(args)} positional arguments.")
+    if max(donate_argnums, default=-1) >= len(args):
+      raise ValueError(
+          f"jitted function has donate_argnums={donate_argnums} but "
+          f"was called with only {len(args)} positional arguments.")
 
     f = lu.wrap_init(fun)
     f, args = argnums_partial_except(f, static_argnums, args,
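
With the jaxlib floor raised, `static_argnames` works on every supported jaxlib, which is why the `_ALLOW_STATIC_ARGNAMES` gating above can be deleted. A minimal usage sketch (not part of this diff):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("axis",))
def normalize(x, axis):
  # `axis` is a static (hashable) argument; a new value triggers a retrace.
  return x / jnp.sum(x, axis=axis, keepdims=True)

print(normalize(jnp.ones((2, 3)), axis=1))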
@@ -362,8 +345,7 @@ class _FastpathData(NamedTuple):
   lazy_exprs: Iterable[Any]
   kept_var_bitvec: Iterable[bool]
 
-if lib._xla_extension_version >= 16:
-  _cpp_jit_cache = jax_jit.CompiledFunctionCache()
+_cpp_jit_cache = jax_jit.CompiledFunctionCache()
 
 def _cpp_jit(
     fun: F,
@@ -398,17 +380,10 @@ def _cpp_jit(
     # An alternative would be for cache_miss to accept from C++ the arguments
     # (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
     # work/code that is redundant between C++ and Python. We can try that later.
-    if _ALLOW_STATIC_ARGNAMES:
-      if max(donate_argnums, default=-1) >= len(args):
-        raise ValueError(
-            f"jitted function has donate_argnums={donate_argnums} but "
-            f"was called with only {len(args)} positional arguments.")
-    else:
-      if max(static_argnums + donate_argnums, default=-1) >= len(args):
-        raise ValueError(
-            f"jitted function has static_argnums={static_argnums}, "
-            f"donate_argnums={donate_argnums} but "
-            f"was called with only {len(args)} positional arguments.")
+    if max(donate_argnums, default=-1) >= len(args):
+      raise ValueError(
+          f"jitted function has donate_argnums={donate_argnums} but "
+          f"was called with only {len(args)} positional arguments.")
     f = lu.wrap_init(fun)
     f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
     f, kwargs = argnames_partial_except(f, static_argnames, kwargs)
@@ -452,14 +427,10 @@ def _cpp_jit(
         aval, sticky_device = result_handler.args
         avals.append(aval)
       assert len(avals) == len(out_flat)
-      if xla._ALLOW_ARG_PRUNING:
-        kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
-        fastpath_data = _FastpathData(xla_executable, out_pytree_def,
-                                      sticky_device, avals, lazy_exprs,
-                                      kept_var_bitvec)
-      else:
-        fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals,
-                         lazy_exprs)
+      kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
+      fastpath_data = _FastpathData(xla_executable, out_pytree_def,
+                                    sticky_device, avals, lazy_exprs,
+                                    kept_var_bitvec)
     else:
       fastpath_data = None
 
@@ -477,21 +448,12 @@ def _cpp_jit(
 
     return _BackendAndDeviceInfo(default_device, committed_to_device)
 
-  if lib._xla_extension_version < 14:
-    cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums)
-    f_jitted = wraps(fun)(cpp_jitted_f)
-  elif lib._xla_extension_version < 16:
-    cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
-                               static_argnums=static_argnums,
-                               static_argnames=static_argnames)
-    f_jitted = wraps(fun)(cpp_jitted_f)
-  else:
-    cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
-                               static_argnums=static_argnums,
-                               static_argnames=static_argnames,
-                               donate_argnums=donate_argnums,
-                               cache=_cpp_jit_cache)
-    f_jitted = wraps(fun)(cpp_jitted_f)
+  cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
+                             static_argnums=static_argnums,
+                             static_argnames=static_argnames,
+                             donate_argnums=donate_argnums,
+                             cache=_cpp_jit_cache)
+  f_jitted = wraps(fun)(cpp_jitted_f)
 
   return f_jitted
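
`donate_argnums` is now handed straight to the single C++ dispatch path above. A usage sketch (not part of this diff; buffer donation is only honored on accelerator backends and is a no-op elsewhere):

import jax
import jax.numpy as jnp

add_one = jax.jit(lambda x: x + 1.0, donate_argnums=(0,))
x = jnp.arange(4.0)
y = add_one(x)  # x's buffer may be reused for y; x should not be used afterwards
print(y)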
@@ -627,7 +589,7 @@ def xla_computation(fun: Callable,
   Alternatively, the assignment to ``c`` above could be written:
 
   >>> import types
-  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
+  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
   >>> c = jax.xla_computation(f)(scalar)
 
 
@@ -2063,7 +2025,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
   >>> import types
   >>>
   >>> f = lambda x, y: 0.5 * x - 0.5 * y
-  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
+  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
   >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
   >>> f_transpose(1.0)
   (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
@@ -2441,7 +2403,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
   >>> class MyArgArray(object):
   ...   def __init__(self, shape, dtype):
   ...     self.shape = shape
-  ...     self.dtype = dtype
+  ...     self.dtype = jnp.dtype(dtype)
   ...
   >>> A = MyArgArray((2000, 3000), jnp.float32)
   >>> x = MyArgArray((3000, 1000), jnp.float32)
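
The docstring edits above all make the same point: duck-typed shape/dtype arguments should carry a real dtype object rather than a bare scalar type. A runnable sketch (not part of this diff):

import jax
import jax.numpy as jnp

class ShapeDtypeOnly:
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = jnp.dtype(dtype)   # a true dtype object, not the scalar type itself

A = ShapeDtypeOnly((2000, 3000), jnp.float32)
x = ShapeDtypeOnly((3000, 1000), jnp.float32)
out = jax.eval_shape(jnp.dot, A, x)   # no FLOPs are executed
print(out.shape, out.dtype)           # (2000, 1000) float32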
@@ -15,7 +15,6 @@
 from jax import core
 from jax import numpy as jnp
 from jax.interpreters import xla
-import jax.lib
 from jax.lib import xla_client
 from jax.lib import xla_bridge
 
@@ -44,28 +43,21 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
   return xla_client._xla.buffer_to_dlpack_managed_tensor(
       x.device_buffer, take_ownership=take_ownership)
 
-def from_dlpack(dlpack, backend=None):
+def from_dlpack(dlpack):
   """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
 
   The returned `DeviceArray` shares memory with `dlpack`.
 
   Args:
     dlpack: a DLPack tensor, on either CPU or GPU.
-    backend: deprecated, do not use.
   """
-  if jax.lib._xla_extension_version >= 25:
-    cpu_backend = xla_bridge.get_backend("cpu")
-    try:
-      gpu_backend = xla_bridge.get_backend("gpu")
-    except RuntimeError:
-      gpu_backend = None
-    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend)
-  else:
-    # TODO(phawkins): drop the backend argument after deleting this case.
-    backend = backend or xla_bridge.get_backend()
-    client = getattr(backend, "client", backend)
-    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
+  cpu_backend = xla_bridge.get_backend("cpu")
+  try:
+    gpu_backend = xla_bridge.get_backend("gpu")
+  except RuntimeError:
+    gpu_backend = None
+  buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend)
 
   xla_shape = buf.xla_shape()
   assert not xla_shape.is_tuple()
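
With the legacy branch gone, from_dlpack takes only the DLPack tensor and picks the CPU or GPU backend itself. A round-trip sketch (not part of this diff, using the API as of this change):

import jax
import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(8.0)
capsule = jax.dlpack.to_dlpack(x, take_ownership=False)
y = jax.dlpack.from_dlpack(capsule)   # note: no backend= argument anymore
print(bool(jnp.all(x == y)))          # True; y shares memory with the DLPack tensor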
@@ -2443,10 +2443,7 @@ acos_p = standard_unop(_float | _complex, 'acos',
 ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
 
 def atan_translation_rule(x):
-  if jax.lib._xla_extension_version < 26 and dtypes.issubdtype(_dtype(x), np.complexfloating):
-    return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x)))
-  else:
-    return atan2(x, _const(x, 1))
+  return atan2(x, _const(x, 1))
 
 atan_p = standard_unop(_float | _complex, 'atan',
                        translation_rule=xla.lower_fun(atan_translation_rule,
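
The simplified rule lowers atan as atan2(x, 1) on every supported jaxlib. A quick numerical check (a sketch, not part of this diff):

import numpy as np
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 9)
print(np.allclose(jnp.arctan(x), np.arctan(np.asarray(x))))   # True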
@@ -6149,14 +6146,10 @@ ad.primitive_transposes[select_and_gather_add_p] = \
 batching.primitive_batchers[select_and_gather_add_p] = \
   _select_and_gather_add_batching_rule
 # TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
-if jax.lib._xla_extension_version >= 15:
-  xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
-    _select_and_gather_add_translation_using_variadic_reducewindow
-  xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
-    _select_and_gather_add_translation_using_variadic_reducewindow
-else:
-  xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
-    _select_and_gather_add_translation, max_bits=32)
+xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
+  _select_and_gather_add_translation_using_variadic_reducewindow
+xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
+  _select_and_gather_add_translation_using_variadic_reducewindow
 
 
 def _sort_abstract_eval(*args, **kwargs):
@@ -437,15 +437,10 @@ def block_diag(*arrs):
   return acc
 
 
-# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and
-# remove this wrapper.
 @_wraps(scipy.linalg.eigh_tridiagonal)
+@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
 def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                      select_range=None, tol=None):
-  return _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol)
-
-@partial(jit, static_argnums=(2, 3, 4))
-def _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol):
   if not eigvals_only:
     raise NotImplementedError("Calculation of eigenvectors is not implemented")
 
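
The wrapper above now jits eigh_tridiagonal directly with static_argnames. A usage sketch (not part of this diff; it assumes the function is exported as jax.scipy.linalg.eigh_tridiagonal, and only eigenvalues are supported):

import numpy as np
from jax.scipy.linalg import eigh_tridiagonal

d = np.array([2.0, 2.0, 2.0, 2.0])   # main diagonal
e = np.array([-1.0, -1.0, -1.0])     # off-diagonal
w = eigh_tridiagonal(d, e, eigvals_only=True)

dense = np.diag(d) + np.diag(e, 1) + np.diag(e, -1)
print(np.allclose(w, np.linalg.eigvalsh(dense), atol=1e-4))   # True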
@@ -90,10 +90,6 @@ def user_context(c):
   except Exception as e:
     if c is None or has_user_context(e):
       raise
-    # TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the
-    # minimum.
-    if not hasattr(c, 'as_python_traceback'):
-      raise
     filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
     if filtered_tb:
       msg = traceback_util.format_exception_only(e)
@@ -208,15 +208,8 @@ def api_boundary(fun):
         # There seems to be no way to alter the currently raised exception's
         # traceback, except via the C API. The currently raised exception
         # is part of the interpreter's thread state: value `e` is a copy.
-        if hasattr(xla_extension, 'replace_thread_exc_traceback'):
-          xla_extension.replace_thread_exc_traceback(filtered_tb)
-          raise
-        else:
-          # TODO(phawkins): remove this case when jaxlib 0.1.66 is the
-          # minimum.
-
-          # Fallback case for older jaxlibs; includes the current frame.
-          raise e.with_traceback(filtered_tb)
+        xla_extension.replace_thread_exc_traceback(filtered_tb)
+        raise
       finally:
         del filtered_tb
         del unfiltered
@@ -23,8 +23,8 @@ product, sparse matrix/matrix product) for two common sparse representations
 
 These routines have reference implementations defined via XLA scatter/gather
 operations that will work on any backend, although they are not particularly
-performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0
-or newer, each operation is computed efficiently via cusparse.
+performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is
+computed efficiently via cusparse.
 
 Further down are some examples of potential high-level wrappers for sparse objects.
 (API should be considered unstable and subject to change).
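
For intuition about the "reference implementations via scatter/gather" mentioned above, here is an illustrative COO matrix-vector product built from the same ingredients (a hypothetical helper, not this module's API):

import jax.numpy as jnp
from jax.ops import segment_sum

def coo_matvec(data, row, col, v, num_rows):
  # gather v at the column indices, scale by the stored values,
  # then scatter-add the products into their destination rows
  return segment_sum(data * v[col], row, num_segments=num_rows)

data = jnp.array([1.0, 2.0, 3.0])       # nonzeros of a 3x3 matrix
row = jnp.array([0, 0, 2])
col = jnp.array([1, 2, 0])
v = jnp.array([1.0, 10.0, 100.0])
print(coo_matvec(data, row, col, v, num_rows=3))   # [210.   0.   3.]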
@@ -38,7 +38,6 @@ from .._src.util import (partial, partialmethod, cache, prod, unzip2,
                          extend_name_stack, wrap_name, safe_zip, safe_map)
 from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
-from ..lib import _xla_extension_version
 from . import partial_eval as pe
 from . import ad
 from . import masking
@@ -778,18 +777,8 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
   return tuple(out_donated_args)
 
 
-# Pruning unused JIT arguments require jaxlib 0.1.66 or newer.
-# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum.
-_ALLOW_ARG_PRUNING = _xla_extension_version >= 20
-
-
 def _prune_unused_inputs(
     jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
-  if not _ALLOW_ARG_PRUNING:
-    kept_const_idx = range(len(jaxpr.constvars))
-    kept_var_idx = range(len(jaxpr.invars))
-    return jaxpr, set(kept_const_idx), set(kept_var_idx)
-
   used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
   # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
   # applications that do not produce used outputs. Must handle side-effecting
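
What _prune_unused_inputs acts on: an argument that is never read still appears as an input of the jaxpr, and the pruning pass keeps it out of the compiled call. A small illustration (not part of this diff):

import jax
import jax.numpy as jnp

def f(x, unused):
  return x * 2.0

print(jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(1000)))
# The printed jaxpr still lists both inputs even though the second is never
# read; the argument-pruning pass is what drops it from the executable.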
@@ -1111,11 +1100,9 @@ class _DeviceArray(DeviceArray): # type: ignore
   """A DeviceArray is an ndarray backed by a single device memory buffer."""
   # We don't subclass ndarray because that would open up a host of issues,
   # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
-  # TODO(phawkins): make __weakref__ an unconditional slot when jaxlib 0.1.66
-  # is the minimum version.
   __slots__ = [
-      "aval", "device_buffer", "_npy_value", "_device"
-  ] + ([] if device_array_supports_weakrefs() else ["__weakref__"])
+      "aval", "device_buffer", "_npy_value", "_device", "__weakref__"
+  ]
   __array_priority__ = 100
 
   # DeviceArray has methods that are dynamically populated in lax_numpy.py,
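
The __slots__ change above works because listing "__weakref__" as a slot is what re-enables weak references on a slotted class. Plain-Python sketch (not jax code):

import weakref

class Slotted:
  __slots__ = ["value", "__weakref__"]
  def __init__(self, value):
    self.value = value

obj = Slotted(3)
ref = weakref.ref(obj)   # would raise TypeError without "__weakref__" in __slots__
print(ref().value)       # 3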
@@ -67,9 +67,8 @@ def _check_jaxlib_version():
 
 _check_jaxlib_version()
 
-if version >= (0, 1, 68):
-  from jaxlib import cpu_feature_guard
-  cpu_feature_guard.check_cpu_features()
+from jaxlib import cpu_feature_guard
+cpu_feature_guard.check_cpu_features()
 
 from jaxlib import xla_client
 from jaxlib import lapack
@@ -155,40 +155,23 @@ def register_backend_factory(name, factory, *, priority=0):
   _backend_factories[name] = (factory, priority)
 
 
-if jax.lib._xla_extension_version >= 23:
-  register_backend_factory('interpreter', xla_client.make_interpreter_client,
-                           priority=-100)
-  if jax.lib._xla_extension_version >= 27:
-    if FLAGS.jax_cpu_backend_variant == 'stream_executor':
-      register_backend_factory('cpu',
-                               partial(xla_client.make_cpu_client, use_tfrt=False),
-                               priority=0)
-    else:
-      assert FLAGS.jax_cpu_backend_variant == 'tfrt'
-      register_backend_factory('cpu',
-                               partial(xla_client.make_cpu_client, use_tfrt=True),
-                               priority=0)
-  else:
-    register_backend_factory('cpu',
-                             partial(xla_client.make_cpu_client, use_tfrt=False),
-                             priority=0)
-  register_backend_factory('tpu_driver', _make_tpu_driver_client,
-                           priority=100)
-  register_backend_factory('gpu', xla_client.make_gpu_client,
-                           priority=200)
-  register_backend_factory('tpu', xla_client.make_tpu_client,
-                           priority=300)
-else:
-  register_backend_factory('interpreter',
-                           xla_client._interpreter_backend_factory,
-                           priority=-100)
-  register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
-  register_backend_factory('tpu_driver', _make_tpu_driver_client,
-                           priority=100)
-  register_backend_factory('gpu', xla_client._gpu_backend_factory,
-                           priority=200)
-  register_backend_factory('tpu', xla_client._tpu_backend_factory,
-                           priority=300)
+register_backend_factory('interpreter', xla_client.make_interpreter_client,
+                         priority=-100)
+if FLAGS.jax_cpu_backend_variant == 'stream_executor':
+  register_backend_factory('cpu',
+                           partial(xla_client.make_cpu_client, use_tfrt=False),
+                           priority=0)
+else:
+  assert FLAGS.jax_cpu_backend_variant == 'tfrt'
+  register_backend_factory('cpu',
+                           partial(xla_client.make_cpu_client, use_tfrt=True),
+                           priority=0)
+register_backend_factory('tpu_driver', _make_tpu_driver_client,
+                         priority=100)
+register_backend_factory('gpu', xla_client.make_gpu_client,
+                         priority=200)
+register_backend_factory('tpu', xla_client.make_tpu_client,
+                         priority=300)
 
 _default_backend = None
 _backends = None
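
The factories above are keyed by a priority, and backend selection tries the highest-priority factory that actually initializes. A stripped-down sketch of that pattern (hypothetical names, not the xla_bridge internals):

_factories = {}

def register(name, factory, priority):
  _factories[name] = (factory, priority)

def pick_default_backend():
  # walk candidates from highest to lowest priority, keep the first that works
  for name, (factory, _) in sorted(_factories.items(),
                                   key=lambda kv: kv[1][1], reverse=True):
    try:
      return name, factory()
    except RuntimeError:
      continue
  raise RuntimeError("no backend could be initialized")

def no_gpu():
  raise RuntimeError("no GPU present")

register('cpu', lambda: "CpuClient", priority=0)
register('gpu', no_gpu, priority=200)
print(pick_default_backend())   # ('cpu', 'CpuClient')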
@@ -258,11 +241,6 @@ def get_backend(platform=None):
 def get_device_backend(device=None):
   """Returns the Backend associated with `device`, or the default Backend."""
   if device is not None:
-    # TODO(phawkins): remove this workaround after jaxlib 0.1.68 becomes the
-    # minimum and it is safe to call `.client` on a tpu_driver TpuDevice.
-    if tpu_driver_client and isinstance(
-        device, tpu_driver_client._tpu_client.TpuDevice):
-      return get_backend('tpu_driver')
     return device.client
   return get_backend()
 
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 __version__ = "0.2.17"
-_minimum_jaxlib_version = "0.1.65"
+_minimum_jaxlib_version = "0.1.69"
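
The bumped constant above is compared as a tuple of integers elsewhere in the tree (for example, the jax.lib.version checks in the tests below). A minimal sketch of that comparison (hypothetical helper):

def parse_version(v):
  return tuple(int(part) for part in v.split("."))

print(parse_version("0.1.69"))                             # (0, 1, 69)
print(parse_version("0.1.65") >= parse_version("0.1.69"))  # False -> old jaxlib rejected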
@@ -50,7 +50,6 @@ from jax import test_util as jtu
 from jax import tree_util
 from jax import linear_util as lu
 import jax._src.util
-from jax._src.api import _ALLOW_STATIC_ARGNAMES
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -309,15 +308,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       f(2)
     assert len(effects) == 3
 
-  # TODO(phawkins): delete this test after jaxlib 0.1.66 is the minimum.
-  @unittest.skipIf(_ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
-  def test_static_argnum_errors_on_keyword_arguments(self):
-    f = self.jit(lambda x: x, static_argnums=0)
-    msg = ("jitted function has static_argnums=(0,), donate_argnums=() but was "
-           "called with only 0 positional arguments.")
-    with self.assertRaisesRegex(ValueError, re.escape(msg)):
-      f(x=4)
-
   def test_static_argnum_on_method(self):
 
     class A:
@@ -587,7 +577,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     assert argnums == ()
     assert argnames == ('foo', 'bar')
 
-  @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
   def test_jit_with_static_argnames(self):
 
     def f(x):
@@ -602,19 +591,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     assert f_names('foo') == 1
     assert f_names(x='foo') == 1
 
-  @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
   def test_new_static_argnum_on_keyword_arguments(self):
     f = self.jit(lambda x: x, static_argnums=0)
     y = f(x=4)
     assert y == 4
 
-  @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
   def test_new_static_argnum_with_default_arguments(self):
     f = self.jit(lambda x=4: x, static_argnums=0)
     y = f()
     assert y == 4
 
-  @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
   def test_jit_with_mismatched_static_argnames(self):
     x_is_tracer, y_is_tracer = False, False
     def f(x, y):
@@ -646,7 +632,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
 
   # TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
   # applications.
-  @unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_num_args={}".format(num_args),
        "num_args": num_args}
@@ -2697,8 +2682,6 @@ class APITest(jtu.JaxTestCase):
     self.assertIn('Precision.HIGH', str(jaxpr))
     self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
 
-  @unittest.skipIf(jax.lib._xla_extension_version <= 17,
-                   "Test requires jaxlib 0.1.66")
   def test_dot_precision_forces_retrace(self):
     num_traces = 0
 
@@ -2738,9 +2721,6 @@ class APITest(jtu.JaxTestCase):
     finally:
       FLAGS.jax_default_matmul_precision = precision
 
 
-  @unittest.skipIf(jax.lib._xla_extension_version <= 17,
-                   "Test requires jaxlib 0.1.66")
   def test_rank_promotion_forces_retrace(self):
     num_traces = 0
 
@@ -5508,8 +5488,6 @@ class NamedCallTest(jtu.JaxTestCase):
       for jit_type in [None, "python", "cpp"]
       if not (jit_type is None and func == 'identity')))
   def test_integer_overflow(self, jit_type, func):
-    if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65):
-      self.skipTest("int32 overflow detection not yet implemented in CPP JIT.")
     funcdict = {
       'identity': lambda x: x,
       'asarray': jnp.asarray,
@@ -80,10 +80,6 @@ class DLPackTest(jtu.JaxTestCase):
     np = rng(shape, dtype)
     if gpu and jax.default_backend() == "cpu":
       raise unittest.SkipTest("Skipping GPU test case on CPU")
-    if (not gpu and jax.default_backend() == "gpu" and
-        jax.lib._xla_extension_version < 25):
-      raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib "
-                              "0.1.68 or newer")
     device = jax.devices("gpu" if gpu else "cpu")[0]
     x = jax.device_put(np, device)
     dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
@@ -20,7 +20,6 @@ import jax.test_util as jtu
 import numpy as np
 import random
 import tempfile
-import unittest
 from unittest import SkipTest
 
 from jax.config import config
@@ -29,7 +28,6 @@ FLAGS = config.FLAGS
 
 class CompilationCacheTest(jtu.JaxTestCase):
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_compile_options(self):
     compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
       num_replicas=1, num_partitions=1)
@@ -40,7 +38,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertEqual(filled_hash1, filled_hash2)
     self.assertNotEqual(filled_hash1, not_filled_hash3)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_executable_build_options(self):
     compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
       num_replicas=1, num_partitions=1)
@@ -66,7 +63,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options)
     self.assertNotEqual(hash1, hash3)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_hash_platform(self):
     hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
     hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
@@ -97,7 +93,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertEqual(hash2, hash3)
     self.assertNotEqual(hash1, hash2)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_same_hash_key(self):
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     compile_options = jax.lib.xla_bridge.get_compile_options(
@@ -105,7 +100,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertEqual(cc.get_cache_key(computation, compile_options),
                      cc.get_cache_key(computation, compile_options))
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_different_hash_key(self):
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
@@ -114,7 +108,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled),
                         cc.get_cache_key(computation, compile_options_filled))
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def test_different_computations(self):
     computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
     computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
@@ -123,7 +116,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertNotEqual(cc.get_cache_key(computation1, compile_options),
                         cc.get_cache_key(computation2, compile_options))
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
   def test_get_no_executable(self):
     if jtu.device_under_test() != "tpu":
       raise SkipTest("serialize executable only works on TPU")
@@ -137,7 +129,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
       num_replicas=1, num_partitions=1)
     self.assertEqual(cc.get_executable(computation, compile_options), None)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
   def test_diff_executables(self):
     if jtu.device_under_test() != "tpu":
       raise SkipTest("serialize executable only works on TPU")
@@ -158,7 +149,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
     self.assertNotEqual(cc.get_executable(computation1, compile_options),
                         cc.get_executable(computation2, compile_options))
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
   def test_put_executable(self):
     if jtu.device_under_test() != "tpu":
       raise SkipTest("serialize executable only works on TPU")
@@ -26,7 +26,6 @@ import jax.numpy as jnp
 from jax import test_util as jtu
 from jax._src import source_info_util
 from jax._src import traceback_util
-from jax.lib import xla_extension
 
 
 from jax.config import config
@@ -58,11 +57,6 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[],
     if filter_mode == "tracebackhide":
       if "__tracebackhide__" in frame.f_locals.keys():
         continue
-    elif filter_mode == "remove_frames":
-      # TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
-      if (not hasattr(xla_extension, "replace_thread_exc_traceback") and
-          frame.f_code.co_name == "reraise_with_filtered_traceback"):
-        continue
     frames.append((frame, lineno))
 
   c_tb = traceback.format_list(traceback.StackSummary.extract(frames))
@@ -341,12 +335,9 @@ class UserContextTracebackTest(jtu.JaxTestCase):
       e = exc
     self.assertIsNot(e, None)
     self.assertIn("invalid value", str(e))
-    # TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the
-    # minimum.
-    if jax.lib._xla_extension_version >= 19:
-      self.assertIsInstance(
-          e.__cause__.__cause__,
-          source_info_util.JaxStackTraceBeforeTransformation)
+    self.assertIsInstance(
+        e.__cause__.__cause__,
+        source_info_util.JaxStackTraceBeforeTransformation)
 
 
 class CustomErrorsTest(jtu.JaxTestCase):
@@ -324,10 +324,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     if not config.x64_enabled:
       raise unittest.SkipTest("requires x64 mode")
 
-    # The LLVM bug that caused this appears to be fixed in jaxlib 0.1.67.
-    if jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 66):
-      raise unittest.SkipTest("test fails on CPU jaxlib <= 0.1.66")
-
     rng = jtu.rand_default(self.rng())
     A = rng(shape, dtype)
     b = rng(shape[:1], dtype)
@@ -2417,16 +2417,10 @@ class LaxTest(jtu.JaxTestCase):
       self.assertEqual(dtypes.result_type(val), dtypes.result_type(const))
 
 
-  # TODO(phawkins): make this test unconditional after jaxlib 0.1.67 is the
-  # default.
-  @unittest.skipIf(jax.lib._xla_extension_version < 22,
-                   "Test requires jaxlib 0.1.67 or newer")
   def testIgammaSpecial(self):
     self.assertEqual(lax.igamma(1., np.inf), 1.)
     self.assertEqual(lax.igammac(1., np.inf), 0.)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 66),
-                   "Test fails on jaxlib 0.1.65 or earlier.")
   def testRegressionIssue5728(self):
     # The computation in this test gave garbage data on CPU due to an LLVM bug.
     @jax.jit
@@ -2642,8 +2636,6 @@ class LazyConstantTest(jtu.JaxTestCase):
     out = lax.cumsum(x)
     self.assertArraysEqual(out, x)
 
-  @unittest.skipIf(jax.lib._xla_extension_version < 24,
-                   "Test requires Jaxlib 0.1.68")
   def testLog1pNearOne(self):
     np.testing.assert_array_almost_equal_nulp(
         np.log1p(np.float32(1e-5)), lax.log1p(np.float32(1e-5)))
@@ -356,7 +356,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
                           rtol=1e-3)
 
-  @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
   def testEighZeroDiagonal(self):
     a = np.array([[0., -1., -1., 1.],
                   [-1., 0., 1., -1.],
@@ -76,8 +76,7 @@ class ProfilerTest(unittest.TestCase):
     self.assertIn(b"/host:CPU", proto)
     if jtu.device_under_test() == "tpu":
       self.assertIn(b"/device:TPU", proto)
-    if jax.lib.version >= (0, 1, 65):
-      self.assertIn(b"pxla.py", proto)
+    self.assertIn(b"pxla.py", proto)
 
   def testProgrammaticProfilingErrors(self):
     with self.assertRaisesRegex(RuntimeError, "No profile started"):
@@ -15,7 +15,6 @@
 
 import collections
 import re
-import unittest
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -25,7 +24,6 @@ from jax import tree_util
 from jax._src.tree_util import _process_pytree
 from jax import flatten_util
 import jax.numpy as jnp
-from jax import lib
 
 
 def _dummy_func(*args, **kwargs):
@@ -306,8 +304,6 @@ class TreeTest(jtu.JaxTestCase):
                      FlatCache({"a": [3, 4], "b": [5, 6]}))
     self.assertEqual(expected, actual)
 
-  @unittest.skipIf(lib._xla_extension_version < 17,
-                   "Test requires jaxlib 0.1.66.")
   @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
   def testStringRepresentation(self, tree, correct_string):
     """Checks that the string representation of a tree works."""
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
-
 from absl.testing import absltest
 
 from jax import test_util as jtu
@@ -23,7 +21,6 @@ from jax.interpreters import xla
 
 class XlaInterpreterTest(jtu.JaxTestCase):
 
-  @unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
   def test_prune_jit_args(self):
     def f(*args):
       return args[0]