From 3ddcec27f2d55bd55c68766167a49610f4cd0541 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 15 Jul 2021 16:39:18 -0400 Subject: [PATCH] Update minimum jaxlib version to 0.1.69. --- CHANGELOG.md | 4 ++ build/test-requirements.txt | 4 +- docs/requirements.txt | 3 +- jax/_src/api.py | 82 ++++++++-------------------- jax/_src/dlpack.py | 24 +++----- jax/_src/lax/lax.py | 17 ++---- jax/_src/scipy/linalg.py | 7 +-- jax/_src/source_info_util.py | 4 -- jax/_src/traceback_util.py | 11 +--- jax/experimental/sparse/ops.py | 4 +- jax/interpreters/xla.py | 17 +----- jax/lib/__init__.py | 5 +- jax/lib/xla_bridge.py | 54 ++++++------------ jax/version.py | 2 +- tests/api_test.py | 22 -------- tests/array_interoperability_test.py | 4 -- tests/compilation_cache_test.py | 10 ---- tests/errors_test.py | 15 +---- tests/lax_scipy_sparse_test.py | 4 -- tests/lax_test.py | 8 --- tests/linalg_test.py | 1 - tests/profiler_test.py | 3 +- tests/tree_util_test.py | 4 -- tests/xla_interpreter_test.py | 3 - 24 files changed, 71 insertions(+), 241 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d265d7ce..856d80f97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Support for Python 3.6 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). Please upgrade to a supported Python version. + * The minimum jaxlib version is now 0.1.69. + * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been + removed. * Bug fixes: * Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with invalid `axis` value, or with an empty reduction dimension. ({jax-issue}`#7196`) + ## jaxlib 0.1.70 (unreleased) * Breaking changes: * Support for Python 3.6 has been dropped, per the diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 4259f2b7b..fb9fbcf39 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,8 +1,6 @@ flake8 flatbuffers==1.12 -# For now, we pin the numpy version here -# TODO(jakevdp): unpin maximum version when minimum jaxlib supports newer numpy -numpy>=1.17,<1.21 +numpy>=1.17 mypy==0.902 pillow>=8.3.1 pytest-benchmark diff --git a/docs/requirements.txt b/docs/requirements.txt index 0a1c73342..6ebda8815 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -13,6 +13,5 @@ pytest-xdist # Packages used for notebook execution matplotlib scikit-learn -# TODO(jakevdp) remove numpy pinning when minimum jaxlib supports newer numpy. -numpy<1.21 +numpy .[cpu] # Install jax from the current directory; jaxlib from pypi. diff --git a/jax/_src/api.py b/jax/_src/api.py index 49b31a2a1..49358c1b3 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -197,12 +197,6 @@ def _infer_argnums_and_argnames( return argnums, argnames -# Static kwargs require jaxlib 0.1.66 or newer. (Technically the Python jit -# implementation works with any jaxlib, but we want consistency between C++ and -# Python implementations.) -# TODO(phawkins): remove when jaxlib 0.1.66 is the minimum. -_ALLOW_STATIC_ARGNAMES = lib._xla_extension_version >= 14 - def jit( fun: F, *, @@ -285,10 +279,6 @@ def jit( [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ] """ - if not _ALLOW_STATIC_ARGNAMES: - if static_argnames is not None: - raise ValueError("static_argnames requires jaxlib 0.1.66 or newer") - static_argnames = () if FLAGS.experimental_cpp_jit: return _cpp_jit(fun, static_argnums, static_argnames, device, backend, donate_argnums, inline) @@ -319,17 +309,10 @@ def _python_jit( def f_jitted(*args, **kwargs): if config.jax_disable_jit: return fun(*args, **kwargs) - if _ALLOW_STATIC_ARGNAMES: - if max(donate_argnums, default=-1) >= len(args): - raise ValueError( - f"jitted function has donate_argnums={donate_argnums} but " - f"was called with only {len(args)} positional arguments.") - else: - if max(static_argnums + donate_argnums, default=-1) >= len(args): - raise ValueError( - f"jitted function has static_argnums={static_argnums}, " - f"donate_argnums={donate_argnums} but " - f"was called with only {len(args)} positional arguments.") + if max(donate_argnums, default=-1) >= len(args): + raise ValueError( + f"jitted function has donate_argnums={donate_argnums} but " + f"was called with only {len(args)} positional arguments.") f = lu.wrap_init(fun) f, args = argnums_partial_except(f, static_argnums, args, @@ -362,8 +345,7 @@ class _FastpathData(NamedTuple): lazy_exprs: Iterable[Any] kept_var_bitvec: Iterable[bool] -if lib._xla_extension_version >= 16: - _cpp_jit_cache = jax_jit.CompiledFunctionCache() +_cpp_jit_cache = jax_jit.CompiledFunctionCache() def _cpp_jit( fun: F, @@ -398,17 +380,10 @@ def _cpp_jit( # An alternative would be for cache_miss to accept from C++ the arguments # (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have # work/code that is redundant between C++ and Python. We can try that later. - if _ALLOW_STATIC_ARGNAMES: - if max(donate_argnums, default=-1) >= len(args): - raise ValueError( - f"jitted function has donate_argnums={donate_argnums} but " - f"was called with only {len(args)} positional arguments.") - else: - if max(static_argnums + donate_argnums, default=-1) >= len(args): - raise ValueError( - f"jitted function has static_argnums={static_argnums}, " - f"donate_argnums={donate_argnums} but " - f"was called with only {len(args)} positional arguments.") + if max(donate_argnums, default=-1) >= len(args): + raise ValueError( + f"jitted function has donate_argnums={donate_argnums} but " + f"was called with only {len(args)} positional arguments.") f = lu.wrap_init(fun) f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) f, kwargs = argnames_partial_except(f, static_argnames, kwargs) @@ -452,14 +427,10 @@ def _cpp_jit( aval, sticky_device = result_handler.args avals.append(aval) assert len(avals) == len(out_flat) - if xla._ALLOW_ARG_PRUNING: - kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))] - fastpath_data = _FastpathData(xla_executable, out_pytree_def, - sticky_device, avals, lazy_exprs, - kept_var_bitvec) - else: - fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals, - lazy_exprs) + kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))] + fastpath_data = _FastpathData(xla_executable, out_pytree_def, + sticky_device, avals, lazy_exprs, + kept_var_bitvec) else: fastpath_data = None @@ -477,21 +448,12 @@ def _cpp_jit( return _BackendAndDeviceInfo(default_device, committed_to_device) - if lib._xla_extension_version < 14: - cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums) - f_jitted = wraps(fun)(cpp_jitted_f) - elif lib._xla_extension_version < 16: - cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, - static_argnums=static_argnums, - static_argnames=static_argnames) - f_jitted = wraps(fun)(cpp_jitted_f) - else: - cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - cache=_cpp_jit_cache) - f_jitted = wraps(fun)(cpp_jitted_f) + cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + cache=_cpp_jit_cache) + f_jitted = wraps(fun)(cpp_jitted_f) return f_jitted @@ -627,7 +589,7 @@ def xla_computation(fun: Callable, Alternatively, the assignment to ``c`` above could be written: >>> import types - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32) + >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> c = jax.xla_computation(f)(scalar) @@ -2063,7 +2025,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: >>> import types >>> >>> f = lambda x, y: 0.5 * x - 0.5 * y - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32) + >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> f_transpose = jax.linear_transpose(f, scalar, scalar) >>> f_transpose(1.0) (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32)) @@ -2441,7 +2403,7 @@ def eval_shape(fun: Callable, *args, **kwargs): >>> class MyArgArray(object): ... def __init__(self, shape, dtype): ... self.shape = shape - ... self.dtype = dtype + ... self.dtype = jnp.dtype(dtype) ... >>> A = MyArgArray((2000, 3000), jnp.float32) >>> x = MyArgArray((3000, 1000), jnp.float32) diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 2d5625677..442894043 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -15,7 +15,6 @@ from jax import core from jax import numpy as jnp from jax.interpreters import xla -import jax.lib from jax.lib import xla_client from jax.lib import xla_bridge @@ -44,28 +43,21 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False): return xla_client._xla.buffer_to_dlpack_managed_tensor( x.device_buffer, take_ownership=take_ownership) -def from_dlpack(dlpack, backend=None): +def from_dlpack(dlpack): """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`. The returned `DeviceArray` shares memory with `dlpack`. Args: dlpack: a DLPack tensor, on either CPU or GPU. - backend: deprecated, do not use. """ - if jax.lib._xla_extension_version >= 25: - cpu_backend = xla_bridge.get_backend("cpu") - try: - gpu_backend = xla_bridge.get_backend("gpu") - except RuntimeError: - gpu_backend = None - buf = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, cpu_backend, gpu_backend) - else: - # TODO(phawkins): drop the backend argument after deleting this case. - backend = backend or xla_bridge.get_backend() - client = getattr(backend, "client", backend) - buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client) + cpu_backend = xla_bridge.get_backend("cpu") + try: + gpu_backend = xla_bridge.get_backend("gpu") + except RuntimeError: + gpu_backend = None + buf = xla_client._xla.dlpack_managed_tensor_to_buffer( + dlpack, cpu_backend, gpu_backend) xla_shape = buf.xla_shape() assert not xla_shape.is_tuple() diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 50f8149aa..e9480f302 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2443,10 +2443,7 @@ acos_p = standard_unop(_float | _complex, 'acos', ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) def atan_translation_rule(x): - if jax.lib._xla_extension_version < 26 and dtypes.issubdtype(_dtype(x), np.complexfloating): - return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x))) - else: - return atan2(x, _const(x, 1)) + return atan2(x, _const(x, 1)) atan_p = standard_unop(_float | _complex, 'atan', translation_rule=xla.lower_fun(atan_translation_rule, @@ -6149,14 +6146,10 @@ ad.primitive_transposes[select_and_gather_add_p] = \ batching.primitive_batchers[select_and_gather_add_p] = \ _select_and_gather_add_batching_rule # TODO(b/183233858): use variadic reducewindow on GPU, when implemented. -if jax.lib._xla_extension_version >= 15: - xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \ - _select_and_gather_add_translation_using_variadic_reducewindow - xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \ - _select_and_gather_add_translation_using_variadic_reducewindow -else: - xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial( - _select_and_gather_add_translation, max_bits=32) +xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \ + _select_and_gather_add_translation_using_variadic_reducewindow +xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \ + _select_and_gather_add_translation_using_variadic_reducewindow def _sort_abstract_eval(*args, **kwargs): diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index e3447729a..acd42f468 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -437,15 +437,10 @@ def block_diag(*arrs): return acc -# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and -# remove this wrapper. @_wraps(scipy.linalg.eigh_tridiagonal) +@partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', select_range=None, tol=None): - return _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol) - -@partial(jit, static_argnums=(2, 3, 4)) -def _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol): if not eigvals_only: raise NotImplementedError("Calculation of eigenvectors is not implemented") diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index 7b78b63ab..320f6fdea 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -90,10 +90,6 @@ def user_context(c): except Exception as e: if c is None or has_user_context(e): raise - # TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the - # minimum. - if not hasattr(c, 'as_python_traceback'): - raise filtered_tb = traceback_util.filter_traceback(c.as_python_traceback()) if filtered_tb: msg = traceback_util.format_exception_only(e) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index aeef289b7..e073e362f 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -208,15 +208,8 @@ def api_boundary(fun): # There seems to be no way to alter the currently raised exception's # traceback, except via the C API. The currently raised exception # is part of the interpreter's thread state: value `e` is a copy. - if hasattr(xla_extension, 'replace_thread_exc_traceback'): - xla_extension.replace_thread_exc_traceback(filtered_tb) - raise - else: - # TODO(phawkins): remove this case when jaxlib 0.1.66 is the - # minimum. - - # Fallback case for older jaxlibs; includes the current frame. - raise e.with_traceback(filtered_tb) + xla_extension.replace_thread_exc_traceback(filtered_tb) + raise finally: del filtered_tb del unfiltered diff --git a/jax/experimental/sparse/ops.py b/jax/experimental/sparse/ops.py index 57643b744..fdd98cfa1 100644 --- a/jax/experimental/sparse/ops.py +++ b/jax/experimental/sparse/ops.py @@ -23,8 +23,8 @@ product, sparse matrix/matrix product) for two common sparse representations These routines have reference implementations defined via XLA scatter/gather operations that will work on any backend, although they are not particularly -performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0 -or newer, each operation is computed efficiently via cusparse. +performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is +computed efficiently via cusparse. Further down are some examples of potential high-level wrappers for sparse objects. (API should be considered unstable and subject to change). diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index b476c3043..2e2dd4d65 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -38,7 +38,6 @@ from .._src.util import (partial, partialmethod, cache, prod, unzip2, extend_name_stack, wrap_name, safe_zip, safe_map) from ..lib import xla_bridge as xb from ..lib import xla_client as xc -from ..lib import _xla_extension_version from . import partial_eval as pe from . import ad from . import masking @@ -778,18 +777,8 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args): return tuple(out_donated_args) -# Pruning unused JIT arguments require jaxlib 0.1.66 or newer. -# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum. -_ALLOW_ARG_PRUNING = _xla_extension_version >= 20 - - def _prune_unused_inputs( jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]: - if not _ALLOW_ARG_PRUNING: - kept_const_idx = range(len(jaxpr.constvars)) - kept_var_idx = range(len(jaxpr.invars)) - return jaxpr, set(kept_const_idx), set(kept_var_idx) - used = {v for v in jaxpr.outvars if isinstance(v, core.Var)} # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive # applications that do not produce used outputs. Must handle side-effecting @@ -1111,11 +1100,9 @@ class _DeviceArray(DeviceArray): # type: ignore """A DeviceArray is an ndarray backed by a single device memory buffer.""" # We don't subclass ndarray because that would open up a host of issues, # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods. - # TODO(phawkins): make __weakref__ an unconditional slot when jaxlib 0.1.66 - # is the minimum version. __slots__ = [ - "aval", "device_buffer", "_npy_value", "_device" - ] + ([] if device_array_supports_weakrefs() else ["__weakref__"]) + "aval", "device_buffer", "_npy_value", "_device", "__weakref__" + ] __array_priority__ = 100 # DeviceArray has methods that are dynamically populated in lax_numpy.py, diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index ef216619f..e92c5d950 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -67,9 +67,8 @@ def _check_jaxlib_version(): _check_jaxlib_version() -if version >= (0, 1, 68): - from jaxlib import cpu_feature_guard - cpu_feature_guard.check_cpu_features() +from jaxlib import cpu_feature_guard +cpu_feature_guard.check_cpu_features() from jaxlib import xla_client from jaxlib import lapack diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 305ce7816..aa174486b 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -155,40 +155,23 @@ def register_backend_factory(name, factory, *, priority=0): _backend_factories[name] = (factory, priority) -if jax.lib._xla_extension_version >= 23: - register_backend_factory('interpreter', xla_client.make_interpreter_client, - priority=-100) - if jax.lib._xla_extension_version >= 27: - if FLAGS.jax_cpu_backend_variant == 'stream_executor': - register_backend_factory('cpu', - partial(xla_client.make_cpu_client, use_tfrt=False), - priority=0) - else: - assert FLAGS.jax_cpu_backend_variant == 'tfrt' - register_backend_factory('cpu', - partial(xla_client.make_cpu_client, use_tfrt=True), - priority=0) - else: - register_backend_factory('cpu', - partial(xla_client.make_cpu_client, use_tfrt=False), - priority=0) - register_backend_factory('tpu_driver', _make_tpu_driver_client, - priority=100) - register_backend_factory('gpu', xla_client.make_gpu_client, - priority=200) - register_backend_factory('tpu', xla_client.make_tpu_client, - priority=300) +register_backend_factory('interpreter', xla_client.make_interpreter_client, + priority=-100) +if FLAGS.jax_cpu_backend_variant == 'stream_executor': + register_backend_factory('cpu', + partial(xla_client.make_cpu_client, use_tfrt=False), + priority=0) else: - register_backend_factory('interpreter', - xla_client._interpreter_backend_factory, - priority=-100) - register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0) - register_backend_factory('tpu_driver', _make_tpu_driver_client, - priority=100) - register_backend_factory('gpu', xla_client._gpu_backend_factory, - priority=200) - register_backend_factory('tpu', xla_client._tpu_backend_factory, - priority=300) + assert FLAGS.jax_cpu_backend_variant == 'tfrt' + register_backend_factory('cpu', + partial(xla_client.make_cpu_client, use_tfrt=True), + priority=0) +register_backend_factory('tpu_driver', _make_tpu_driver_client, + priority=100) +register_backend_factory('gpu', xla_client.make_gpu_client, + priority=200) +register_backend_factory('tpu', xla_client.make_tpu_client, + priority=300) _default_backend = None _backends = None @@ -258,11 +241,6 @@ def get_backend(platform=None): def get_device_backend(device=None): """Returns the Backend associated with `device`, or the default Backend.""" if device is not None: - # TODO(phawkins): remove this workaround after jaxlib 0.1.68 becomes the - # minimum and it is safe to call `.client` on a tpu_driver TpuDevice. - if tpu_driver_client and isinstance( - device, tpu_driver_client._tpu_client.TpuDevice): - return get_backend('tpu_driver') return device.client return get_backend() diff --git a/jax/version.py b/jax/version.py index 195533a78..6ecc80f20 100644 --- a/jax/version.py +++ b/jax/version.py @@ -13,4 +13,4 @@ # limitations under the License. __version__ = "0.2.17" -_minimum_jaxlib_version = "0.1.65" +_minimum_jaxlib_version = "0.1.69" diff --git a/tests/api_test.py b/tests/api_test.py index 3344cb732..2cbc5d9b0 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -50,7 +50,6 @@ from jax import test_util as jtu from jax import tree_util from jax import linear_util as lu import jax._src.util -from jax._src.api import _ALLOW_STATIC_ARGNAMES from jax.config import config config.parse_flags_with_absl() @@ -309,15 +308,6 @@ class CPPJitTest(jtu.BufferDonationTestCase): f(2) assert len(effects) == 3 - # TODO(phawkins): delete this test after jaxlib 0.1.66 is the minimum. - @unittest.skipIf(_ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66") - def test_static_argnum_errors_on_keyword_arguments(self): - f = self.jit(lambda x: x, static_argnums=0) - msg = ("jitted function has static_argnums=(0,), donate_argnums=() but was " - "called with only 0 positional arguments.") - with self.assertRaisesRegex(ValueError, re.escape(msg)): - f(x=4) - def test_static_argnum_on_method(self): class A: @@ -587,7 +577,6 @@ class CPPJitTest(jtu.BufferDonationTestCase): assert argnums == () assert argnames == ('foo', 'bar') - @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66") def test_jit_with_static_argnames(self): def f(x): @@ -602,19 +591,16 @@ class CPPJitTest(jtu.BufferDonationTestCase): assert f_names('foo') == 1 assert f_names(x='foo') == 1 - @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66") def test_new_static_argnum_on_keyword_arguments(self): f = self.jit(lambda x: x, static_argnums=0) y = f(x=4) assert y == 4 - @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66") def test_new_static_argnum_with_default_arguments(self): f = self.jit(lambda x=4: x, static_argnums=0) y = f() assert y == 4 - @unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66") def test_jit_with_mismatched_static_argnames(self): x_is_tracer, y_is_tracer = False, False def f(x, y): @@ -646,7 +632,6 @@ class CPPJitTest(jtu.BufferDonationTestCase): # TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive # applications. - @unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66") @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_num_args={}".format(num_args), "num_args": num_args} @@ -2697,8 +2682,6 @@ class APITest(jtu.JaxTestCase): self.assertIn('Precision.HIGH', str(jaxpr)) self.assertEqual(prev_val, config._read("jax_default_matmul_precision")) - @unittest.skipIf(jax.lib._xla_extension_version <= 17, - "Test requires jaxlib 0.1.66") def test_dot_precision_forces_retrace(self): num_traces = 0 @@ -2738,9 +2721,6 @@ class APITest(jtu.JaxTestCase): finally: FLAGS.jax_default_matmul_precision = precision - - @unittest.skipIf(jax.lib._xla_extension_version <= 17, - "Test requires jaxlib 0.1.66") def test_rank_promotion_forces_retrace(self): num_traces = 0 @@ -5508,8 +5488,6 @@ class NamedCallTest(jtu.JaxTestCase): for jit_type in [None, "python", "cpp"] if not (jit_type is None and func == 'identity'))) def test_integer_overflow(self, jit_type, func): - if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65): - self.skipTest("int32 overflow detection not yet implemented in CPP JIT.") funcdict = { 'identity': lambda x: x, 'asarray': jnp.asarray, diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 201f251ea..4d2770ff9 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -80,10 +80,6 @@ class DLPackTest(jtu.JaxTestCase): np = rng(shape, dtype) if gpu and jax.default_backend() == "cpu": raise unittest.SkipTest("Skipping GPU test case on CPU") - if (not gpu and jax.default_backend() == "gpu" and - jax.lib._xla_extension_version < 25): - raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib " - "0.1.68 or newer") device = jax.devices("gpu" if gpu else "cpu")[0] x = jax.device_put(np, device) dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index f8162338d..579e5ef07 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -20,7 +20,6 @@ import jax.test_util as jtu import numpy as np import random import tempfile -import unittest from unittest import SkipTest from jax.config import config @@ -29,7 +28,6 @@ FLAGS = config.FLAGS class CompilationCacheTest(jtu.JaxTestCase): - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_compile_options(self): compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( num_replicas=1, num_partitions=1) @@ -40,7 +38,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertEqual(filled_hash1, filled_hash2) self.assertNotEqual(filled_hash1, not_filled_hash3) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_executable_build_options(self): compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( num_replicas=1, num_partitions=1) @@ -66,7 +63,6 @@ class CompilationCacheTest(jtu.JaxTestCase): hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options) self.assertNotEqual(hash1, hash3) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_hash_platform(self): hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend()) hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend()) @@ -97,7 +93,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertEqual(hash2, hash3) self.assertNotEqual(hash1, hash2) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_same_hash_key(self): computation = jax.xla_computation(lambda x, y: x + y)(1, 1) compile_options = jax.lib.xla_bridge.get_compile_options( @@ -105,7 +100,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertEqual(cc.get_cache_key(computation, compile_options), cc.get_cache_key(computation, compile_options)) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_different_hash_key(self): computation = jax.xla_computation(lambda x, y: x + y)(1, 1) compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( @@ -114,7 +108,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled), cc.get_cache_key(computation, compile_options_filled)) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def test_different_computations(self): computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1) computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2) @@ -123,7 +116,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertNotEqual(cc.get_cache_key(computation1, compile_options), cc.get_cache_key(computation2, compile_options)) - @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs") def test_get_no_executable(self): if jtu.device_under_test() != "tpu": raise SkipTest("serialize executable only works on TPU") @@ -137,7 +129,6 @@ class CompilationCacheTest(jtu.JaxTestCase): num_replicas=1, num_partitions=1) self.assertEqual(cc.get_executable(computation, compile_options), None) - @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs") def test_diff_executables(self): if jtu.device_under_test() != "tpu": raise SkipTest("serialize executable only works on TPU") @@ -158,7 +149,6 @@ class CompilationCacheTest(jtu.JaxTestCase): self.assertNotEqual(cc.get_executable(computation1, compile_options), cc.get_executable(computation2, compile_options)) - @unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs") def test_put_executable(self): if jtu.device_under_test() != "tpu": raise SkipTest("serialize executable only works on TPU") diff --git a/tests/errors_test.py b/tests/errors_test.py index a62880a14..f46f20390 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -26,7 +26,6 @@ import jax.numpy as jnp from jax import test_util as jtu from jax._src import source_info_util from jax._src import traceback_util -from jax.lib import xla_extension from jax.config import config @@ -58,11 +57,6 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[], if filter_mode == "tracebackhide": if "__tracebackhide__" in frame.f_locals.keys(): continue - elif filter_mode == "remove_frames": - # TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum. - if (not hasattr(xla_extension, "replace_thread_exc_traceback") and - frame.f_code.co_name == "reraise_with_filtered_traceback"): - continue frames.append((frame, lineno)) c_tb = traceback.format_list(traceback.StackSummary.extract(frames)) @@ -341,12 +335,9 @@ class UserContextTracebackTest(jtu.JaxTestCase): e = exc self.assertIsNot(e, None) self.assertIn("invalid value", str(e)) - # TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the - # minimum. - if jax.lib._xla_extension_version >= 19: - self.assertIsInstance( - e.__cause__.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) + self.assertIsInstance( + e.__cause__.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) class CustomErrorsTest(jtu.JaxTestCase): diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 1adb06594..3e9991a99 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -324,10 +324,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase): if not config.x64_enabled: raise unittest.SkipTest("requires x64 mode") - # The LLVM bug that caused this appears to be fixed in jaxlib 0.1.67. - if jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 66): - raise unittest.SkipTest("test fails on CPU jaxlib <= 0.1.66") - rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) b = rng(shape[:1], dtype) diff --git a/tests/lax_test.py b/tests/lax_test.py index 58414f041..9b52b2d32 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2417,16 +2417,10 @@ class LaxTest(jtu.JaxTestCase): self.assertEqual(dtypes.result_type(val), dtypes.result_type(const)) - # TODO(phawkins): make this test unconditional after jaxlib 0.1.67 is the - # default. - @unittest.skipIf(jax.lib._xla_extension_version < 22, - "Test requires jaxlib 0.1.67 or newer") def testIgammaSpecial(self): self.assertEqual(lax.igamma(1., np.inf), 1.) self.assertEqual(lax.igammac(1., np.inf), 0.) - @unittest.skipIf(jax.lib.version < (0, 1, 66), - "Test fails on jaxlib 0.1.65 or earlier.") def testRegressionIssue5728(self): # The computation in this test gave garbage data on CPU due to an LLVM bug. @jax.jit @@ -2642,8 +2636,6 @@ class LazyConstantTest(jtu.JaxTestCase): out = lax.cumsum(x) self.assertArraysEqual(out, x) - @unittest.skipIf(jax.lib._xla_extension_version < 24, - "Test requires Jaxlib 0.1.68") def testLog1pNearOne(self): np.testing.assert_array_almost_equal_nulp( np.log1p(np.float32(1e-5)), lax.log1p(np.float32(1e-5))) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index b23708810..b41b4e69a 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -356,7 +356,6 @@ class NumpyLinalgTest(jtu.JaxTestCase): self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=1e-3) - @unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs") def testEighZeroDiagonal(self): a = np.array([[0., -1., -1., 1.], [-1., 0., 1., -1.], diff --git a/tests/profiler_test.py b/tests/profiler_test.py index e2f86afc3..49a8c07d8 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -76,8 +76,7 @@ class ProfilerTest(unittest.TestCase): self.assertIn(b"/host:CPU", proto) if jtu.device_under_test() == "tpu": self.assertIn(b"/device:TPU", proto) - if jax.lib.version >= (0, 1, 65): - self.assertIn(b"pxla.py", proto) + self.assertIn(b"pxla.py", proto) def testProgrammaticProfilingErrors(self): with self.assertRaisesRegex(RuntimeError, "No profile started"): diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 24fd67183..61ff7b1f0 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -15,7 +15,6 @@ import collections import re -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -25,7 +24,6 @@ from jax import tree_util from jax._src.tree_util import _process_pytree from jax import flatten_util import jax.numpy as jnp -from jax import lib def _dummy_func(*args, **kwargs): @@ -306,8 +304,6 @@ class TreeTest(jtu.JaxTestCase): FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) - @unittest.skipIf(lib._xla_extension_version < 17, - "Test requires jaxlib 0.1.66.") @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" diff --git a/tests/xla_interpreter_test.py b/tests/xla_interpreter_test.py index d3b758aa0..bb719876e 100644 --- a/tests/xla_interpreter_test.py +++ b/tests/xla_interpreter_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - from absl.testing import absltest from jax import test_util as jtu @@ -23,7 +21,6 @@ from jax.interpreters import xla class XlaInterpreterTest(jtu.JaxTestCase): - @unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66") def test_prune_jit_args(self): def f(*args): return args[0]