Update minimum jaxlib version to 0.1.69.

This commit is contained in:
Peter Hawkins 2021-07-15 16:39:18 -04:00
parent 6aa20d8f8f
commit 3ddcec27f2
24 changed files with 71 additions and 241 deletions

View File

@ -15,12 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Support for Python 3.6 has been dropped, per the * Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version. Please upgrade to a supported Python version.
* The minimum jaxlib version is now 0.1.69.
* The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been
removed.
* Bug fixes: * Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are * Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with invalid `axis` value, or with an empty reduction dimension. not used with invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`) ({jax-issue}`#7196`)
## jaxlib 0.1.70 (unreleased) ## jaxlib 0.1.70 (unreleased)
* Breaking changes: * Breaking changes:
* Support for Python 3.6 has been dropped, per the * Support for Python 3.6 has been dropped, per the

View File

@ -1,8 +1,6 @@
flake8 flake8
flatbuffers==1.12 flatbuffers==1.12
# For now, we pin the numpy version here numpy>=1.17
# TODO(jakevdp): unpin maximum version when minimum jaxlib supports newer numpy
numpy>=1.17,<1.21
mypy==0.902 mypy==0.902
pillow>=8.3.1 pillow>=8.3.1
pytest-benchmark pytest-benchmark

View File

@ -13,6 +13,5 @@ pytest-xdist
# Packages used for notebook execution # Packages used for notebook execution
matplotlib matplotlib
scikit-learn scikit-learn
# TODO(jakevdp) remove numpy pinning when minimum jaxlib supports newer numpy. numpy
numpy<1.21
.[cpu] # Install jax from the current directory; jaxlib from pypi. .[cpu] # Install jax from the current directory; jaxlib from pypi.

View File

@ -197,12 +197,6 @@ def _infer_argnums_and_argnames(
return argnums, argnames return argnums, argnames
# Static kwargs require jaxlib 0.1.66 or newer. (Technically the Python jit
# implementation works with any jaxlib, but we want consistency between C++ and
# Python implementations.)
# TODO(phawkins): remove when jaxlib 0.1.66 is the minimum.
_ALLOW_STATIC_ARGNAMES = lib._xla_extension_version >= 14
def jit( def jit(
fun: F, fun: F,
*, *,
@ -285,10 +279,6 @@ def jit(
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ] -0.85743 -0.78232 0.76827 0.59566 ]
""" """
if not _ALLOW_STATIC_ARGNAMES:
if static_argnames is not None:
raise ValueError("static_argnames requires jaxlib 0.1.66 or newer")
static_argnames = ()
if FLAGS.experimental_cpp_jit: if FLAGS.experimental_cpp_jit:
return _cpp_jit(fun, static_argnums, static_argnames, device, backend, return _cpp_jit(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline) donate_argnums, inline)
@ -319,17 +309,10 @@ def _python_jit(
def f_jitted(*args, **kwargs): def f_jitted(*args, **kwargs):
if config.jax_disable_jit: if config.jax_disable_jit:
return fun(*args, **kwargs) return fun(*args, **kwargs)
if _ALLOW_STATIC_ARGNAMES: if max(donate_argnums, default=-1) >= len(args):
if max(donate_argnums, default=-1) >= len(args): raise ValueError(
raise ValueError( f"jitted function has donate_argnums={donate_argnums} but "
f"jitted function has donate_argnums={donate_argnums} but " f"was called with only {len(args)} positional arguments.")
f"was called with only {len(args)} positional arguments.")
else:
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has static_argnums={static_argnums}, "
f"donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun) f = lu.wrap_init(fun)
f, args = argnums_partial_except(f, static_argnums, args, f, args = argnums_partial_except(f, static_argnums, args,
@ -362,8 +345,7 @@ class _FastpathData(NamedTuple):
lazy_exprs: Iterable[Any] lazy_exprs: Iterable[Any]
kept_var_bitvec: Iterable[bool] kept_var_bitvec: Iterable[bool]
if lib._xla_extension_version >= 16: _cpp_jit_cache = jax_jit.CompiledFunctionCache()
_cpp_jit_cache = jax_jit.CompiledFunctionCache()
def _cpp_jit( def _cpp_jit(
fun: F, fun: F,
@ -398,17 +380,10 @@ def _cpp_jit(
# An alternative would be for cache_miss to accept from C++ the arguments # An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have # (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later. # work/code that is redundant between C++ and Python. We can try that later.
if _ALLOW_STATIC_ARGNAMES: if max(donate_argnums, default=-1) >= len(args):
if max(donate_argnums, default=-1) >= len(args): raise ValueError(
raise ValueError( f"jitted function has donate_argnums={donate_argnums} but "
f"jitted function has donate_argnums={donate_argnums} but " f"was called with only {len(args)} positional arguments.")
f"was called with only {len(args)} positional arguments.")
else:
if max(static_argnums + donate_argnums, default=-1) >= len(args):
raise ValueError(
f"jitted function has static_argnums={static_argnums}, "
f"donate_argnums={donate_argnums} but "
f"was called with only {len(args)} positional arguments.")
f = lu.wrap_init(fun) f = lu.wrap_init(fun)
f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True) f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
f, kwargs = argnames_partial_except(f, static_argnames, kwargs) f, kwargs = argnames_partial_except(f, static_argnames, kwargs)
@ -452,14 +427,10 @@ def _cpp_jit(
aval, sticky_device = result_handler.args aval, sticky_device = result_handler.args
avals.append(aval) avals.append(aval)
assert len(avals) == len(out_flat) assert len(avals) == len(out_flat)
if xla._ALLOW_ARG_PRUNING: kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))]
kept_var_bitvec = [i in kept_var_idx for i in range(len(args_flat))] fastpath_data = _FastpathData(xla_executable, out_pytree_def,
fastpath_data = _FastpathData(xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs,
sticky_device, avals, lazy_exprs, kept_var_bitvec)
kept_var_bitvec)
else:
fastpath_data = (xla_executable, out_pytree_def, sticky_device, avals,
lazy_exprs)
else: else:
fastpath_data = None fastpath_data = None
@ -477,21 +448,12 @@ def _cpp_jit(
return _BackendAndDeviceInfo(default_device, committed_to_device) return _BackendAndDeviceInfo(default_device, committed_to_device)
if lib._xla_extension_version < 14: cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, static_argnums) static_argnums=static_argnums,
f_jitted = wraps(fun)(cpp_jitted_f) static_argnames=static_argnames,
elif lib._xla_extension_version < 16: donate_argnums=donate_argnums,
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, cache=_cpp_jit_cache)
static_argnums=static_argnums, f_jitted = wraps(fun)(cpp_jitted_f)
static_argnames=static_argnames)
f_jitted = wraps(fun)(cpp_jitted_f)
else:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
cache=_cpp_jit_cache)
f_jitted = wraps(fun)(cpp_jitted_f)
return f_jitted return f_jitted
@ -627,7 +589,7 @@ def xla_computation(fun: Callable,
Alternatively, the assignment to ``c`` above could be written: Alternatively, the assignment to ``c`` above could be written:
>>> import types >>> import types
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32) >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar) >>> c = jax.xla_computation(f)(scalar)
@ -2063,7 +2025,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
>>> import types >>> import types
>>> >>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y >>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32) >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar) >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0) >>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32)) (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
@ -2441,7 +2403,7 @@ def eval_shape(fun: Callable, *args, **kwargs):
>>> class MyArgArray(object): >>> class MyArgArray(object):
... def __init__(self, shape, dtype): ... def __init__(self, shape, dtype):
... self.shape = shape ... self.shape = shape
... self.dtype = dtype ... self.dtype = jnp.dtype(dtype)
... ...
>>> A = MyArgArray((2000, 3000), jnp.float32) >>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32) >>> x = MyArgArray((3000, 1000), jnp.float32)

View File

@ -15,7 +15,6 @@
from jax import core from jax import core
from jax import numpy as jnp from jax import numpy as jnp
from jax.interpreters import xla from jax.interpreters import xla
import jax.lib
from jax.lib import xla_client from jax.lib import xla_client
from jax.lib import xla_bridge from jax.lib import xla_bridge
@ -44,28 +43,21 @@ def to_dlpack(x: xla.DeviceArrayProtocol, take_ownership: bool = False):
return xla_client._xla.buffer_to_dlpack_managed_tensor( return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.device_buffer, take_ownership=take_ownership) x.device_buffer, take_ownership=take_ownership)
def from_dlpack(dlpack, backend=None): def from_dlpack(dlpack):
"""Returns a `DeviceArray` representation of a DLPack tensor `dlpack`. """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
The returned `DeviceArray` shares memory with `dlpack`. The returned `DeviceArray` shares memory with `dlpack`.
Args: Args:
dlpack: a DLPack tensor, on either CPU or GPU. dlpack: a DLPack tensor, on either CPU or GPU.
backend: deprecated, do not use.
""" """
if jax.lib._xla_extension_version >= 25: cpu_backend = xla_bridge.get_backend("cpu")
cpu_backend = xla_bridge.get_backend("cpu") try:
try: gpu_backend = xla_bridge.get_backend("gpu")
gpu_backend = xla_bridge.get_backend("gpu") except RuntimeError:
except RuntimeError: gpu_backend = None
gpu_backend = None buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
buf = xla_client._xla.dlpack_managed_tensor_to_buffer( dlpack, cpu_backend, gpu_backend)
dlpack, cpu_backend, gpu_backend)
else:
# TODO(phawkins): drop the backend argument after deleting this case.
backend = backend or xla_bridge.get_backend()
client = getattr(backend, "client", backend)
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
xla_shape = buf.xla_shape() xla_shape = buf.xla_shape()
assert not xla_shape.is_tuple() assert not xla_shape.is_tuple()

View File

@ -2443,10 +2443,7 @@ acos_p = standard_unop(_float | _complex, 'acos',
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x)))) ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
def atan_translation_rule(x): def atan_translation_rule(x):
if jax.lib._xla_extension_version < 26 and dtypes.issubdtype(_dtype(x), np.complexfloating): return atan2(x, _const(x, 1))
return mul(_const(x, -1j), atanh(mul(_const(x, 1j), x)))
else:
return atan2(x, _const(x, 1))
atan_p = standard_unop(_float | _complex, 'atan', atan_p = standard_unop(_float | _complex, 'atan',
translation_rule=xla.lower_fun(atan_translation_rule, translation_rule=xla.lower_fun(atan_translation_rule,
@ -6149,14 +6146,10 @@ ad.primitive_transposes[select_and_gather_add_p] = \
batching.primitive_batchers[select_and_gather_add_p] = \ batching.primitive_batchers[select_and_gather_add_p] = \
_select_and_gather_add_batching_rule _select_and_gather_add_batching_rule
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented. # TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
if jax.lib._xla_extension_version >= 15: xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \ _select_and_gather_add_translation_using_variadic_reducewindow
_select_and_gather_add_translation_using_variadic_reducewindow xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \ _select_and_gather_add_translation_using_variadic_reducewindow
_select_and_gather_add_translation_using_variadic_reducewindow
else:
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
_select_and_gather_add_translation, max_bits=32)
def _sort_abstract_eval(*args, **kwargs): def _sort_abstract_eval(*args, **kwargs):

View File

@ -437,15 +437,10 @@ def block_diag(*arrs):
return acc return acc
# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and
# remove this wrapper.
@_wraps(scipy.linalg.eigh_tridiagonal) @_wraps(scipy.linalg.eigh_tridiagonal)
@partial(jit, static_argnames=("eigvals_only", "select", "select_range"))
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a', def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
select_range=None, tol=None): select_range=None, tol=None):
return _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol)
@partial(jit, static_argnums=(2, 3, 4))
def _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol):
if not eigvals_only: if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented") raise NotImplementedError("Calculation of eigenvectors is not implemented")

View File

@ -90,10 +90,6 @@ def user_context(c):
except Exception as e: except Exception as e:
if c is None or has_user_context(e): if c is None or has_user_context(e):
raise raise
# TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the
# minimum.
if not hasattr(c, 'as_python_traceback'):
raise
filtered_tb = traceback_util.filter_traceback(c.as_python_traceback()) filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
if filtered_tb: if filtered_tb:
msg = traceback_util.format_exception_only(e) msg = traceback_util.format_exception_only(e)

View File

@ -208,15 +208,8 @@ def api_boundary(fun):
# There seems to be no way to alter the currently raised exception's # There seems to be no way to alter the currently raised exception's
# traceback, except via the C API. The currently raised exception # traceback, except via the C API. The currently raised exception
# is part of the interpreter's thread state: value `e` is a copy. # is part of the interpreter's thread state: value `e` is a copy.
if hasattr(xla_extension, 'replace_thread_exc_traceback'): xla_extension.replace_thread_exc_traceback(filtered_tb)
xla_extension.replace_thread_exc_traceback(filtered_tb) raise
raise
else:
# TODO(phawkins): remove this case when jaxlib 0.1.66 is the
# minimum.
# Fallback case for older jaxlibs; includes the current frame.
raise e.with_traceback(filtered_tb)
finally: finally:
del filtered_tb del filtered_tb
del unfiltered del unfiltered

View File

@ -23,8 +23,8 @@ product, sparse matrix/matrix product) for two common sparse representations
These routines have reference implementations defined via XLA scatter/gather These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly operations that will work on any backend, although they are not particularly
performant. On GPU runtimes with jaxlib 0.1.66 or newer built against CUDA 11.0 performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is
or newer, each operation is computed efficiently via cusparse. computed efficiently via cusparse.
Further down are some examples of potential high-level wrappers for sparse objects. Further down are some examples of potential high-level wrappers for sparse objects.
(API should be considered unstable and subject to change). (API should be considered unstable and subject to change).

View File

@ -38,7 +38,6 @@ from .._src.util import (partial, partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map) extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb from ..lib import xla_bridge as xb
from ..lib import xla_client as xc from ..lib import xla_client as xc
from ..lib import _xla_extension_version
from . import partial_eval as pe from . import partial_eval as pe
from . import ad from . import ad
from . import masking from . import masking
@ -778,18 +777,8 @@ def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
return tuple(out_donated_args) return tuple(out_donated_args)
# Pruning unused JIT arguments require jaxlib 0.1.66 or newer.
# TODO(zhangqiaorjc): remove when jaxlib 0.1.66 is the minimum.
_ALLOW_ARG_PRUNING = _xla_extension_version >= 20
def _prune_unused_inputs( def _prune_unused_inputs(
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]: jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
if not _ALLOW_ARG_PRUNING:
kept_const_idx = range(len(jaxpr.constvars))
kept_var_idx = range(len(jaxpr.invars))
return jaxpr, set(kept_const_idx), set(kept_var_idx)
used = {v for v in jaxpr.outvars if isinstance(v, core.Var)} used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
# TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
# applications that do not produce used outputs. Must handle side-effecting # applications that do not produce used outputs. Must handle side-effecting
@ -1111,11 +1100,9 @@ class _DeviceArray(DeviceArray): # type: ignore
"""A DeviceArray is an ndarray backed by a single device memory buffer.""" """A DeviceArray is an ndarray backed by a single device memory buffer."""
# We don't subclass ndarray because that would open up a host of issues, # We don't subclass ndarray because that would open up a host of issues,
# but lax_numpy.py overrides isinstance behavior and attaches ndarray methods. # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
# TODO(phawkins): make __weakref__ an unconditional slot when jaxlib 0.1.66
# is the minimum version.
__slots__ = [ __slots__ = [
"aval", "device_buffer", "_npy_value", "_device" "aval", "device_buffer", "_npy_value", "_device", "__weakref__"
] + ([] if device_array_supports_weakrefs() else ["__weakref__"]) ]
__array_priority__ = 100 __array_priority__ = 100
# DeviceArray has methods that are dynamically populated in lax_numpy.py, # DeviceArray has methods that are dynamically populated in lax_numpy.py,

View File

@ -67,9 +67,8 @@ def _check_jaxlib_version():
_check_jaxlib_version() _check_jaxlib_version()
if version >= (0, 1, 68): from jaxlib import cpu_feature_guard
from jaxlib import cpu_feature_guard cpu_feature_guard.check_cpu_features()
cpu_feature_guard.check_cpu_features()
from jaxlib import xla_client from jaxlib import xla_client
from jaxlib import lapack from jaxlib import lapack

View File

@ -155,40 +155,23 @@ def register_backend_factory(name, factory, *, priority=0):
_backend_factories[name] = (factory, priority) _backend_factories[name] = (factory, priority)
if jax.lib._xla_extension_version >= 23: register_backend_factory('interpreter', xla_client.make_interpreter_client,
register_backend_factory('interpreter', xla_client.make_interpreter_client, priority=-100)
priority=-100) if FLAGS.jax_cpu_backend_variant == 'stream_executor':
if jax.lib._xla_extension_version >= 27: register_backend_factory('cpu',
if FLAGS.jax_cpu_backend_variant == 'stream_executor': partial(xla_client.make_cpu_client, use_tfrt=False),
register_backend_factory('cpu', priority=0)
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
else:
assert FLAGS.jax_cpu_backend_variant == 'tfrt'
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
else:
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=False),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300)
else: else:
register_backend_factory('interpreter', assert FLAGS.jax_cpu_backend_variant == 'tfrt'
xla_client._interpreter_backend_factory, register_backend_factory('cpu',
priority=-100) partial(xla_client.make_cpu_client, use_tfrt=True),
register_backend_factory('cpu', xla_client._cpu_backend_factory, priority=0) priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client, register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100) priority=100)
register_backend_factory('gpu', xla_client._gpu_backend_factory, register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200) priority=200)
register_backend_factory('tpu', xla_client._tpu_backend_factory, register_backend_factory('tpu', xla_client.make_tpu_client,
priority=300) priority=300)
_default_backend = None _default_backend = None
_backends = None _backends = None
@ -258,11 +241,6 @@ def get_backend(platform=None):
def get_device_backend(device=None): def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend.""" """Returns the Backend associated with `device`, or the default Backend."""
if device is not None: if device is not None:
# TODO(phawkins): remove this workaround after jaxlib 0.1.68 becomes the
# minimum and it is safe to call `.client` on a tpu_driver TpuDevice.
if tpu_driver_client and isinstance(
device, tpu_driver_client._tpu_client.TpuDevice):
return get_backend('tpu_driver')
return device.client return device.client
return get_backend() return get_backend()

View File

@ -13,4 +13,4 @@
# limitations under the License. # limitations under the License.
__version__ = "0.2.17" __version__ = "0.2.17"
_minimum_jaxlib_version = "0.1.65" _minimum_jaxlib_version = "0.1.69"

View File

@ -50,7 +50,6 @@ from jax import test_util as jtu
from jax import tree_util from jax import tree_util
from jax import linear_util as lu from jax import linear_util as lu
import jax._src.util import jax._src.util
from jax._src.api import _ALLOW_STATIC_ARGNAMES
from jax.config import config from jax.config import config
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -309,15 +308,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f(2) f(2)
assert len(effects) == 3 assert len(effects) == 3
# TODO(phawkins): delete this test after jaxlib 0.1.66 is the minimum.
@unittest.skipIf(_ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_static_argnum_errors_on_keyword_arguments(self):
f = self.jit(lambda x: x, static_argnums=0)
msg = ("jitted function has static_argnums=(0,), donate_argnums=() but was "
"called with only 0 positional arguments.")
with self.assertRaisesRegex(ValueError, re.escape(msg)):
f(x=4)
def test_static_argnum_on_method(self): def test_static_argnum_on_method(self):
class A: class A:
@ -587,7 +577,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert argnums == () assert argnums == ()
assert argnames == ('foo', 'bar') assert argnames == ('foo', 'bar')
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_jit_with_static_argnames(self): def test_jit_with_static_argnames(self):
def f(x): def f(x):
@ -602,19 +591,16 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert f_names('foo') == 1 assert f_names('foo') == 1
assert f_names(x='foo') == 1 assert f_names(x='foo') == 1
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_new_static_argnum_on_keyword_arguments(self): def test_new_static_argnum_on_keyword_arguments(self):
f = self.jit(lambda x: x, static_argnums=0) f = self.jit(lambda x: x, static_argnums=0)
y = f(x=4) y = f(x=4)
assert y == 4 assert y == 4
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_new_static_argnum_with_default_arguments(self): def test_new_static_argnum_with_default_arguments(self):
f = self.jit(lambda x=4: x, static_argnums=0) f = self.jit(lambda x=4: x, static_argnums=0)
y = f() y = f()
assert y == 4 assert y == 4
@unittest.skipIf(not _ALLOW_STATIC_ARGNAMES, "Test requires jaxlib 0.1.66")
def test_jit_with_mismatched_static_argnames(self): def test_jit_with_mismatched_static_argnames(self):
x_is_tracer, y_is_tracer = False, False x_is_tracer, y_is_tracer = False, False
def f(x, y): def f(x, y):
@ -646,7 +632,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive # TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
# applications. # applications.
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
@parameterized.named_parameters(jtu.cases_from_list( @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_num_args={}".format(num_args), {"testcase_name": "_num_args={}".format(num_args),
"num_args": num_args} "num_args": num_args}
@ -2697,8 +2682,6 @@ class APITest(jtu.JaxTestCase):
self.assertIn('Precision.HIGH', str(jaxpr)) self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision")) self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
@unittest.skipIf(jax.lib._xla_extension_version <= 17,
"Test requires jaxlib 0.1.66")
def test_dot_precision_forces_retrace(self): def test_dot_precision_forces_retrace(self):
num_traces = 0 num_traces = 0
@ -2738,9 +2721,6 @@ class APITest(jtu.JaxTestCase):
finally: finally:
FLAGS.jax_default_matmul_precision = precision FLAGS.jax_default_matmul_precision = precision
@unittest.skipIf(jax.lib._xla_extension_version <= 17,
"Test requires jaxlib 0.1.66")
def test_rank_promotion_forces_retrace(self): def test_rank_promotion_forces_retrace(self):
num_traces = 0 num_traces = 0
@ -5508,8 +5488,6 @@ class NamedCallTest(jtu.JaxTestCase):
for jit_type in [None, "python", "cpp"] for jit_type in [None, "python", "cpp"]
if not (jit_type is None and func == 'identity'))) if not (jit_type is None and func == 'identity')))
def test_integer_overflow(self, jit_type, func): def test_integer_overflow(self, jit_type, func):
if jit_type == "cpp" and not config.x64_enabled and jax.lib.version < (0, 1, 65):
self.skipTest("int32 overflow detection not yet implemented in CPP JIT.")
funcdict = { funcdict = {
'identity': lambda x: x, 'identity': lambda x: x,
'asarray': jnp.asarray, 'asarray': jnp.asarray,

View File

@ -80,10 +80,6 @@ class DLPackTest(jtu.JaxTestCase):
np = rng(shape, dtype) np = rng(shape, dtype)
if gpu and jax.default_backend() == "cpu": if gpu and jax.default_backend() == "cpu":
raise unittest.SkipTest("Skipping GPU test case on CPU") raise unittest.SkipTest("Skipping GPU test case on CPU")
if (not gpu and jax.default_backend() == "gpu" and
jax.lib._xla_extension_version < 25):
raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib "
"0.1.68 or newer")
device = jax.devices("gpu" if gpu else "cpu")[0] device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device) x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership) dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)

View File

@ -20,7 +20,6 @@ import jax.test_util as jtu
import numpy as np import numpy as np
import random import random
import tempfile import tempfile
import unittest
from unittest import SkipTest from unittest import SkipTest
from jax.config import config from jax.config import config
@ -29,7 +28,6 @@ FLAGS = config.FLAGS
class CompilationCacheTest(jtu.JaxTestCase): class CompilationCacheTest(jtu.JaxTestCase):
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_compile_options(self): def test_compile_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1) num_replicas=1, num_partitions=1)
@ -40,7 +38,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(filled_hash1, filled_hash2) self.assertEqual(filled_hash1, filled_hash2)
self.assertNotEqual(filled_hash1, not_filled_hash3) self.assertNotEqual(filled_hash1, not_filled_hash3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_executable_build_options(self): def test_executable_build_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1) num_replicas=1, num_partitions=1)
@ -66,7 +63,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options) hash3 = self.get_hashed_value(cc._hash_debug_options, new_debug_options)
self.assertNotEqual(hash1, hash3) self.assertNotEqual(hash1, hash3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_hash_platform(self): def test_hash_platform(self):
hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend()) hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend()) hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
@ -97,7 +93,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(hash2, hash3) self.assertEqual(hash2, hash3)
self.assertNotEqual(hash1, hash2) self.assertNotEqual(hash1, hash2)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_same_hash_key(self): def test_same_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1) computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options = jax.lib.xla_bridge.get_compile_options( compile_options = jax.lib.xla_bridge.get_compile_options(
@ -105,7 +100,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(cc.get_cache_key(computation, compile_options), self.assertEqual(cc.get_cache_key(computation, compile_options),
cc.get_cache_key(computation, compile_options)) cc.get_cache_key(computation, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_different_hash_key(self): def test_different_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1) computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options( compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
@ -114,7 +108,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled), self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled),
cc.get_cache_key(computation, compile_options_filled)) cc.get_cache_key(computation, compile_options_filled))
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_different_computations(self): def test_different_computations(self):
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1) computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2) computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
@ -123,7 +116,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_cache_key(computation1, compile_options), self.assertNotEqual(cc.get_cache_key(computation1, compile_options),
cc.get_cache_key(computation2, compile_options)) cc.get_cache_key(computation2, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_get_no_executable(self): def test_get_no_executable(self):
if jtu.device_under_test() != "tpu": if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU") raise SkipTest("serialize executable only works on TPU")
@ -137,7 +129,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
num_replicas=1, num_partitions=1) num_replicas=1, num_partitions=1)
self.assertEqual(cc.get_executable(computation, compile_options), None) self.assertEqual(cc.get_executable(computation, compile_options), None)
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_diff_executables(self): def test_diff_executables(self):
if jtu.device_under_test() != "tpu": if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU") raise SkipTest("serialize executable only works on TPU")
@ -158,7 +149,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(cc.get_executable(computation1, compile_options), self.assertNotEqual(cc.get_executable(computation1, compile_options),
cc.get_executable(computation2, compile_options)) cc.get_executable(computation2, compile_options))
@unittest.skipIf(jax.lib.version < (0, 1, 69), "fails with earlier jaxlibs")
def test_put_executable(self): def test_put_executable(self):
if jtu.device_under_test() != "tpu": if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU") raise SkipTest("serialize executable only works on TPU")

View File

@ -26,7 +26,6 @@ import jax.numpy as jnp
from jax import test_util as jtu from jax import test_util as jtu
from jax._src import source_info_util from jax._src import source_info_util
from jax._src import traceback_util from jax._src import traceback_util
from jax.lib import xla_extension
from jax.config import config from jax.config import config
@ -58,11 +57,6 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=[],
if filter_mode == "tracebackhide": if filter_mode == "tracebackhide":
if "__tracebackhide__" in frame.f_locals.keys(): if "__tracebackhide__" in frame.f_locals.keys():
continue continue
elif filter_mode == "remove_frames":
# TODO(phawkins): remove this condition after jaxlib 0.1.66 is the minimum.
if (not hasattr(xla_extension, "replace_thread_exc_traceback") and
frame.f_code.co_name == "reraise_with_filtered_traceback"):
continue
frames.append((frame, lineno)) frames.append((frame, lineno))
c_tb = traceback.format_list(traceback.StackSummary.extract(frames)) c_tb = traceback.format_list(traceback.StackSummary.extract(frames))
@ -341,12 +335,9 @@ class UserContextTracebackTest(jtu.JaxTestCase):
e = exc e = exc
self.assertIsNot(e, None) self.assertIsNot(e, None)
self.assertIn("invalid value", str(e)) self.assertIn("invalid value", str(e))
# TODO(phawkins): make this test unconditional after jaxlib 0.1.66 is the self.assertIsInstance(
# minimum. e.__cause__.__cause__,
if jax.lib._xla_extension_version >= 19: source_info_util.JaxStackTraceBeforeTransformation)
self.assertIsInstance(
e.__cause__.__cause__,
source_info_util.JaxStackTraceBeforeTransformation)
class CustomErrorsTest(jtu.JaxTestCase): class CustomErrorsTest(jtu.JaxTestCase):

View File

@ -324,10 +324,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if not config.x64_enabled: if not config.x64_enabled:
raise unittest.SkipTest("requires x64 mode") raise unittest.SkipTest("requires x64 mode")
# The LLVM bug that caused this appears to be fixed in jaxlib 0.1.67.
if jtu.device_under_test() == "cpu" and jax.lib.version <= (0, 1, 66):
raise unittest.SkipTest("test fails on CPU jaxlib <= 0.1.66")
rng = jtu.rand_default(self.rng()) rng = jtu.rand_default(self.rng())
A = rng(shape, dtype) A = rng(shape, dtype)
b = rng(shape[:1], dtype) b = rng(shape[:1], dtype)

View File

@ -2417,16 +2417,10 @@ class LaxTest(jtu.JaxTestCase):
self.assertEqual(dtypes.result_type(val), dtypes.result_type(const)) self.assertEqual(dtypes.result_type(val), dtypes.result_type(const))
# TODO(phawkins): make this test unconditional after jaxlib 0.1.67 is the
# default.
@unittest.skipIf(jax.lib._xla_extension_version < 22,
"Test requires jaxlib 0.1.67 or newer")
def testIgammaSpecial(self): def testIgammaSpecial(self):
self.assertEqual(lax.igamma(1., np.inf), 1.) self.assertEqual(lax.igamma(1., np.inf), 1.)
self.assertEqual(lax.igammac(1., np.inf), 0.) self.assertEqual(lax.igammac(1., np.inf), 0.)
@unittest.skipIf(jax.lib.version < (0, 1, 66),
"Test fails on jaxlib 0.1.65 or earlier.")
def testRegressionIssue5728(self): def testRegressionIssue5728(self):
# The computation in this test gave garbage data on CPU due to an LLVM bug. # The computation in this test gave garbage data on CPU due to an LLVM bug.
@jax.jit @jax.jit
@ -2642,8 +2636,6 @@ class LazyConstantTest(jtu.JaxTestCase):
out = lax.cumsum(x) out = lax.cumsum(x)
self.assertArraysEqual(out, x) self.assertArraysEqual(out, x)
@unittest.skipIf(jax.lib._xla_extension_version < 24,
"Test requires Jaxlib 0.1.68")
def testLog1pNearOne(self): def testLog1pNearOne(self):
np.testing.assert_array_almost_equal_nulp( np.testing.assert_array_almost_equal_nulp(
np.log1p(np.float32(1e-5)), lax.log1p(np.float32(1e-5))) np.log1p(np.float32(1e-5)), lax.log1p(np.float32(1e-5)))

View File

@ -356,7 +356,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker, self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
rtol=1e-3) rtol=1e-3)
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def testEighZeroDiagonal(self): def testEighZeroDiagonal(self):
a = np.array([[0., -1., -1., 1.], a = np.array([[0., -1., -1., 1.],
[-1., 0., 1., -1.], [-1., 0., 1., -1.],

View File

@ -76,8 +76,7 @@ class ProfilerTest(unittest.TestCase):
self.assertIn(b"/host:CPU", proto) self.assertIn(b"/host:CPU", proto)
if jtu.device_under_test() == "tpu": if jtu.device_under_test() == "tpu":
self.assertIn(b"/device:TPU", proto) self.assertIn(b"/device:TPU", proto)
if jax.lib.version >= (0, 1, 65): self.assertIn(b"pxla.py", proto)
self.assertIn(b"pxla.py", proto)
def testProgrammaticProfilingErrors(self): def testProgrammaticProfilingErrors(self):
with self.assertRaisesRegex(RuntimeError, "No profile started"): with self.assertRaisesRegex(RuntimeError, "No profile started"):

View File

@ -15,7 +15,6 @@
import collections import collections
import re import re
import unittest
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
@ -25,7 +24,6 @@ from jax import tree_util
from jax._src.tree_util import _process_pytree from jax._src.tree_util import _process_pytree
from jax import flatten_util from jax import flatten_util
import jax.numpy as jnp import jax.numpy as jnp
from jax import lib
def _dummy_func(*args, **kwargs): def _dummy_func(*args, **kwargs):
@ -306,8 +304,6 @@ class TreeTest(jtu.JaxTestCase):
FlatCache({"a": [3, 4], "b": [5, 6]})) FlatCache({"a": [3, 4], "b": [5, 6]}))
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
@unittest.skipIf(lib._xla_extension_version < 17,
"Test requires jaxlib 0.1.66.")
@parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)])
def testStringRepresentation(self, tree, correct_string): def testStringRepresentation(self, tree, correct_string):
"""Checks that the string representation of a tree works.""" """Checks that the string representation of a tree works."""

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
from absl.testing import absltest from absl.testing import absltest
from jax import test_util as jtu from jax import test_util as jtu
@ -23,7 +21,6 @@ from jax.interpreters import xla
class XlaInterpreterTest(jtu.JaxTestCase): class XlaInterpreterTest(jtu.JaxTestCase):
@unittest.skipIf(not xla._ALLOW_ARG_PRUNING, "Test requires jaxlib 0.1.66")
def test_prune_jit_args(self): def test_prune_jit_args(self):
def f(*args): def f(*args):
return args[0] return args[0]