Update minimum jaxlib version to 0.1.69.

This commit is contained in:
Peter Hawkins 2021-07-15 16:39:18 -04:00
parent 6aa20d8f8f
commit 3ddcec27f2
24 changed files with 71 additions and 241 deletions

View File

@ -15,12 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* The minimum jaxlib version is now 0.1.69.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
## jaxlib 0.1.70 (unreleased)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the

View File

@ -1,8 +1,6 @@
flake8
flatbuffers==1.12
# For now, we pin the numpy version here
# TODO(jakevdp): unpin maximum version when minimum jaxlib supports newer numpy
numpy>=1.17,<1.21
numpy>=1.17
mypy==0.902
pillow>=8.3.1
pytest-benchmark

View File

@ -13,6 +13,5 @@ pytest-xdist
# Packages used for notebook execution
matplotlib
scikit-learn
# TODO(jakevdp) remove numpy pinning when minimum jaxlib supports newer numpy.
numpy<1.21
numpy
.[cpu] # Install jax from the current directory; jaxlib from pypi.

View File

@ -197,12 +197,6 @@ def _infer_argnums_and_argnames(
return argnums, argnames
# Static kwargs require jaxlib 0.1.66 or newer. (Technically the Python jit
# implementation works with any jaxlib, but we want consistency between C++ and
# Python implementations.)
# TODO(phawkins): remove when jaxlib 0.1.66 is the minimum.
_ALLOW_STATIC_ARGNAMES = lib._xla_extension_version >= 14
def jit(
fun: F,
*,
@ -285,10 +279,6 @@ def jit(
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ]
"""
if not _ALLOW_STATIC_ARGNAMES:
if static_argnames is not None:
raise ValueError("static_argnames requires jaxlib 0.1.66 or newer")
static_argnames = ()
if FLAGS.experimental_cpp_jit:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline)
@ -319,17 +309,10 @@ def _python_jit(
def f_jitted(*args, **kwargs):
if config.jax_disable_jit:
return fun(*args, **kwargs)
if _ALLOW_STATIC_ARGNAMES:
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
else:
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has static_argnums={static_argnums}, "
f"donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun)
f, args = argnums_partial_except(f, static_argnums, args,
@ -362,8 +345,7 @@ class _FastpathData(NamedTuple):
lazy_exprs: Iterable[Any]
kept_var_bitvec: Iterable[bool]
if lib._xla_extension_version >= 16:
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
def _cpp_jit(
fun: F,
@ -398,17 +380,10 @@ def _cpp_jit(
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
if _ALLOW_STATIC_ARGNAMES:
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
else:
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has static_argnums={static_argnums}, "
f"donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
if max(donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun)
f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
f, kwargs = argnames_partial_except(f, static_argnames, kwargs)
@ -452,14 +427,10 @@ def _cpp_jit(
aval, sticky_device = result_handler.args
avals.append(aval)
assert len(avals) == len(out_flat)
if xla._ALLOW_ARG_PRUNING:
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
fastpath_data = _FastpathData(xla_executable, out_pytree_def,
sticky_device, avals, lazy_exprs,
kept_var_bitvec)
else:
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals,
lazy_exprs)
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
fastpath_data = _FastpathData(xla_executable, out_pytree_def,
sticky_device, avals, lazy_exprs,
kept_var_bitvec)
else:
fastpath_data = None
@ -477,21 +448,12 @@ def _cpp_jit(
return _BackendAndDeviceInfo(default_device, committed_to_device)
if lib._xla_extension_version < 14:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums)
f_jitted = wraps(fun)(cpp_jitted_f)
elif lib._xla_extension_version < 16:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames)
f_jitted = wraps(fun)(cpp_jitted_f)
else:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
cache=_cpp_jit_cache)
f_jitted = wraps(fun)(cpp_jitted_f)
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
cache=_cpp_jit_cache)
f_jitted = wraps(fun)(cpp_jitted_f)
return f_jitted
@ -627,7 +589,7 @@ def xla_computation(fun: Callable,
Alternatively, the assignment to ``c`` above could be written:
>>> import types
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)
@ -2063,7 +2025,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
@ -2441,7 +2403,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
>>> class MyArgArray(object):
... def __init__(self, shape, dtype):
... self.shape = shape
... self.dtype = dtype
... self.dtype = jnp.dtype(dtype)
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)

View File

@ -15,7 +15,6 @@
from jax import core
from jax import numpy as jnp
from jax.interpreters import xla
import jax.lib
from jax.lib import xla_client
from jax.lib import xla_bridge
@ -44,28 +43,21 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.device_buffer, take_ownership=take_ownership)
def from_dlpack(dlpack, backend=None):
def from_dlpack(dlpack):
"""Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
The returned `DeviceArray` shares memory with `dlpack`.
Args:
dlpack: a DLPack tensor, on either CPU or GPU.
backend: deprecated, do not use.
"""
if jax.lib._xla_extension_version >= 25:
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("gpu")
except RuntimeError:
gpu_backend = None
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
else:
# TODO(phawkins): drop the backend argument after deleting this case.
backend = backend or xla_bridge.get_backend()
client = getattr(backend, "client", backend)
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
cpu_backend = xla_bridge.get_backend("cpu")
try:
gpu_backend = xla_bridge.get_backend("gpu")
except RuntimeError:
gpu_backend = None
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
xla_shape = buf.xla_shape()
assert not xla_shape.is_tuple()

View File

@ -2443,10 +2443,7 @@ acos_p = standard_unop(_float | _complex, 'acos',
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
def atan_translation_rule(x):
if jax.lib._xla_extension_version < 26 and dtypes.issubdtype(_dtype(x), np.complexfloating):
return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x)))
else:
return atan2(x, _const(x, 1))
return atan2(x, _const(x, 1))
atan_p = standard_unop(_float | _complex, 'atan',
translation_rule=xla.lower_fun(atan_translation_rule,
@ -6149,14 +6146,10 @@ ad.primitive_transposes[select_and_gather_add_p] = \
batching.primitive_batchers[select_and_gather_add_p] = \
_select_and_gather_add_batching_rule
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
if jax.lib._xla_extension_version >= 15:
xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_variadic_reducewindow
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_variadic_reducewindow
else:
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
_select_and_gather_add_translation, max_bits=32)
xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_variadic_reducewindow
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_variadic_reducewindow
def _sort_abstract_eval(*args, **kwargs):

View File

@ -437,15 +437,10 @@ def block_diag(*arrs):
return acc
# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and
# remove this wrapper.
@_wraps(scipy.linalg.eigh_tridiagonal)
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
select_range=None, tol=None):
return _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol)
@partial(jit, static_argnums=(2, 3, 4))
def _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol):
if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented")

View File

@ -90,10 +90,6 @@ def user_context(c):
except Exception as e:
if c is None or has_user_context(e):
raise
# TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the
# minimum.
if not hasattr(c, 'as_python_traceback'):
raise
filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
if filtered_tb:
msg = traceback_util.format_exception_only(e)

View File

@ -208,15 +208,8 @@ def api_boundary(fun):
# There seems to be no way to alter the currently raised exception's
# traceback, except via the C API. The currently raised exception
# is part of the interpreter's thread state: value `e` is a copy.
if hasattr(xla_extension, 'replace_thread_exc_traceback'):
xla_extension.replace_thread_exc_traceback(filtered_tb)
raise
else:
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
# minimum.
# Fallback case for older jaxlibs; includes the current frame.
raise e.with_traceback(filtered_tb)
xla_extension.replace_thread_exc_traceback(filtered_tb)
raise
finally:
del filtered_tb
del unfiltered

View File

@ -23,8 +23,8 @@ product, sparse matrix/matrix product) for two common sparse representations
These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly
performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0
or newer, each operation is computed efficiently via cusparse.
performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is
computed efficiently via cusparse.
Further down are some examples of potential high-level wrappers for sparse objects.
(API should be considered unstable and subject to change).

View File

@ -38,7 +38,6 @@ from .._src.util import (partial, partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import _xla_extension_version
from . import partial_eval as pe
from . import ad
from . import masking
@ -778,18 +777,8 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
return tuple(out_donated_args)
# Pruning unused JIT arguments require jaxlib 0.1.66 or newer.
# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum.
_ALLOW_ARG_PRUNING = _xla_extension_version >= 20
def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
if not _ALLOW_ARG_PRUNING:
kept_const_idx = range(len(jaxpr.constvars))
kept_var_idx = range(len(jaxpr.invars))
return jaxpr, set(kept_const_idx), set(kept_var_idx)
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
# applications that do not produce used outputs. Must handle side-effecting
@ -1111,11 +1100,9 @@ class _DeviceArray(DeviceArray): # type: ignore
"""A DeviceArray is an ndarray backed by a single device memory buffer."""
# We don't subclass ndarray because that would open up a host of issues,
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
# TODO(phawkins): make __weakref__ an unconditional slot when jaxlib 0.1.66
# is the minimum version.
__slots__ = [
"aval", "device_buffer", "_npy_value", "_device"
] + ([] if device_array_supports_weakrefs() else ["__weakref__"])
"aval", "device_buffer", "_npy_value", "_device", "__weakref__"
]
__array_priority__ = 100
# DeviceArray has methods that are dynamically populated in lax_numpy.py,

View File

@ -67,9 +67,8 @@ def _check_jaxlib_version():
_check_jaxlib_version()
if version >= (0, 1, 68):
from jaxlib import cpu_feature_guard
cpu_feature_guard.check_cpu_features()
from jaxlib import cpu_feature_guard
cpu_feature_guard.check_cpu_features()
from jaxlib import xla_client
from jaxlib import lapack

View File

@ -155,40 +155,23 @@ def register_backend_factory(name, factory, *, priority=0):
_backend_factories[name] = (factory, priority)
if jax.lib._xla_extension_version >= 23:
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
if jax.lib._xla_extension_version >= 27:
if FLAGS.jax_cpu_backend_variant == 'stream_executor':
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
else:
assert FLAGS.jax_cpu_backend_variant == 'tfrt'
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
else:
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
if FLAGS.jax_cpu_backend_variant == 'stream_executor':
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
else:
register_backend_factory('interpreter',
xla_client._interpreter_backend_factory,
priority=-100)
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory,
priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory,
priority=300)
assert FLAGS.jax_cpu_backend_variant == 'tfrt'
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
_default_backend = None
_backends = None
@ -258,11 +241,6 @@ def get_backend(platform=None):
def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend."""
if device is not None:
# TODO(phawkins): remove this workaround after jaxlib 0.1.68 becomes the
# minimum and it is safe to call `.client` on a tpu_driver TpuDevice.
if tpu_driver_client and isinstance(
device, tpu_driver_client._tpu_client.TpuDevice):
return get_backend('tpu_driver')
return device.client
return get_backend()

View File

@ -13,4 +13,4 @@
# limitations under the License.
__version__ = "0.2.17"
_minimum_jaxlib_version = "0.1.65"
_minimum_jaxlib_version = "0.1.69"

View File

@ -50,7 +50,6 @@ from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
import jax._src.util
from jax._src.api import _ALLOW_STATIC_ARGNAMES
from jax.config import config
config.parse_flags_with_absl()
@ -309,15 +308,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f(2)
assert len(effects) == 3
# TODO(phawkins): delete this test after jaxlib 0.1.66 is the minimum.
@unittest.skipIf(_ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_static_argnum_errors_on_keyword_arguments(self):
f = self.jit(lambda x: x, static_argnums=0)
msg = ("jitted function has static_argnums=(0,), donate_argnums=() but was "
"called with only 0 positional arguments.")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
f(x=4)
def test_static_argnum_on_method(self):
class A:
@ -587,7 +577,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert argnums == ()
assert argnames == ('foo', 'bar')
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_jit_with_static_argnames(self):
def f(x):
@ -602,19 +591,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert f_names('foo') == 1
assert f_names(x='foo') == 1
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_new_static_argnum_on_keyword_arguments(self):
f = self.jit(lambda x: x, static_argnums=0)
y = f(x=4)
assert y == 4
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_new_static_argnum_with_default_arguments(self):
f = self.jit(lambda x=4: x, static_argnums=0)
y = f()
assert y == 4
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_jit_with_mismatched_static_argnames(self):
x_is_tracer, y_is_tracer = False, False
def f(x, y):
@ -646,7 +632,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
# applications.
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_num_args={}".format(num_args),
"num_args": num_args}
@ -2697,8 +2682,6 @@ class APITest(jtu.JaxTestCase):
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
@unittest.skipIf(jax.lib._xla_extension_version <= 17,
"Test requires jaxlib 0.1.66")
def test_dot_precision_forces_retrace(self):
num_traces = 0
@ -2738,9 +2721,6 @@ class APITest(jtu.JaxTestCase):
finally:
FLAGS.jax_default_matmul_precision = precision
@unittest.skipIf(jax.lib._xla_extension_version <= 17,
"Test requires jaxlib 0.1.66")
def test_rank_promotion_forces_retrace(self):
num_traces = 0
@ -5508,8 +5488,6 @@ class NamedCallTest(jtu.JaxTestCase):
for jit_type in [None, "python", "cpp"]
if not (jit_type is None and func == 'identity')))
def test_integer_overflow(self, jit_type, func):
if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65):
self.skipTest("int32 overflow detection not yet implemented in CPP JIT.")
funcdict = {
'identity': lambda x: x,
'asarray': jnp.asarray,

View File

@ -80,10 +80,6 @@ class DLPackTest(jtu.JaxTestCase):
np = rng(shape, dtype)
if gpu and jax.default_backend() == "cpu":
raise unittest.SkipTest("Skipping GPU test case on CPU")
if (not gpu and jax.default_backend() == "gpu" and
jax.lib._xla_extension_version < 25):
raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib "
"0.1.68 or newer")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)

View File

@ -20,7 +20,6 @@ import jax.test_util as jtu
import numpy as np
import random
import tempfile
import unittest
from unittest import SkipTest
from jax.config import config
@ -29,7 +28,6 @@ FLAGS = config.FLAGS
class CompilationCacheTest(jtu.JaxTestCase):
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_compile_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
@ -40,7 +38,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(filled_hash1, filled_hash2)
self.assertNotEqual(filled_hash1, not_filled_hash3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_executable_build_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
@ -66,7 +63,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options)
self.assertNotEqual(hash1, hash3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_hash_platform(self):
hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
@ -97,7 +93,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(hash2, hash3)
self.assertNotEqual(hash1, hash2)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_same_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options = jax.lib.xla_bridge.get_compile_options(
@ -105,7 +100,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(cc.get_cache_key(computation, compile_options),
cc.get_cache_key(computation, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_different_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
@ -114,7 +108,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled),
cc.get_cache_key(computation, compile_options_filled))
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_different_computations(self):
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
@ -123,7 +116,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_cache_key(computation1, compile_options),
cc.get_cache_key(computation2, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_get_no_executable(self):
if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU")
@ -137,7 +129,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
num_replicas=1, num_partitions=1)
self.assertEqual(cc.get_executable(computation, compile_options), None)
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_diff_executables(self):
if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU")
@ -158,7 +149,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_executable(computation1, compile_options),
cc.get_executable(computation2, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_put_executable(self):
if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU")

View File

@ -26,7 +26,6 @@ import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import source_info_util
from jax._src import traceback_util
from jax.lib import xla_extension
from jax.config import config
@ -58,11 +57,6 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[],
if filter_mode == "tracebackhide":
if "__tracebackhide__" in frame.f_locals.keys():
continue
elif filter_mode == "remove_frames":
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
if (not hasattr(xla_extension, "replace_thread_exc_traceback") and
frame.f_code.co_name == "reraise_with_filtered_traceback"):
continue
frames.append((frame, lineno))
c_tb = traceback.format_list(traceback.StackSummary.extract(frames))
@ -341,12 +335,9 @@ class UserContextTracebackTest(jtu.JaxTestCase):
e = exc
self.assertIsNot(e, None)
self.assertIn("invalid value", str(e))
# TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the
# minimum.
if jax.lib._xla_extension_version >= 19:
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
class CustomErrorsTest(jtu.JaxTestCase):

View File

@ -324,10 +324,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode")
# The LLVM bug that caused this appears to be fixed in jaxlib 0.1.67.
if jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 66):
raise unittest.SkipTest("test fails on CPU jaxlib <= 0.1.66")
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
b = rng(shape[:1], dtype)

View File

@ -2417,16 +2417,10 @@ class LaxTest(jtu.JaxTestCase):
self.assertEqual(dtypes.result_type(val), dtypes.result_type(const))
# TODO(phawkins): make this test unconditional after jaxlib 0.1.67 is the
# default.
@unittest.skipIf(jax.lib._xla_extension_version < 22,
"Test requires jaxlib 0.1.67 or newer")
def testIgammaSpecial(self):
self.assertEqual(lax.igamma(1., np.inf), 1.)
self.assertEqual(lax.igammac(1., np.inf), 0.)
@unittest.skipIf(jax.lib.version < (0, 1, 66),
"Test fails on jaxlib 0.1.65 or earlier.")
def testRegressionIssue5728(self):
# The computation in this test gave garbage data on CPU due to an LLVM bug.
@jax.jit
@ -2642,8 +2636,6 @@ class LazyConstantTest(jtu.JaxTestCase):
out = lax.cumsum(x)
self.assertArraysEqual(out, x)
@unittest.skipIf(jax.lib._xla_extension_version < 24,
"Test requires Jaxlib 0.1.68")
def testLog1pNearOne(self):
np.testing.assert_array_almost_equal_nulp(
np.log1p(np.float32(1e-5)), lax.log1p(np.float32(1e-5)))

View File

@ -356,7 +356,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
rtol=1e-3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def testEighZeroDiagonal(self):
a = np.array([[0., -1., -1., 1.],
[-1., 0., 1., -1.],

View File

@ -76,8 +76,7 @@ class ProfilerTest(unittest.TestCase):
self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto)
if jax.lib.version >= (0, 1, 65):
self.assertIn(b"pxla.py", proto)
self.assertIn(b"pxla.py", proto)
def testProgrammaticProfilingErrors(self):
with self.assertRaisesRegex(RuntimeError, "No profile started"):

View File

@ -15,7 +15,6 @@
import collections
import re
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -25,7 +24,6 @@ from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util
import jax.numpy as jnp
from jax import lib
def _dummy_func(*args, **kwargs):
@ -306,8 +304,6 @@ class TreeTest(jtu.JaxTestCase):
FlatCache({"a": [3, 4], "b": [5, 6]}))
self.assertEqual(expected, actual)
@unittest.skipIf(lib._xla_extension_version < 17,
"Test requires jaxlib 0.1.66.")
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
def testStringRepresentation(self, tree, correct_string):
"""Checks that the string representation of a tree works."""

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from absl.testing import absltest
from jax import test_util as jtu
@ -23,7 +21,6 @@ from jax.interpreters import xla
class XlaInterpreterTest(jtu.JaxTestCase):
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
def test_prune_jit_args(self):
def f(*args):
return args[0]