From 328930b91792166bda63886971193dc7c5102d27 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 16 Mar 2021 12:13:41 -0400 Subject: [PATCH] Increase minimum jaxlib version to 0.1.62. --- CHANGELOG.md | 1 + build/test-requirements.txt | 2 +- jax/_src/lax/lax.py | 99 ++++++++++++++++------------------ jax/api.py | 38 ++----------- jax/config.py | 22 ++------ jax/dtypes.py | 6 +-- jax/lib/__init__.py | 2 +- tests/jax_jit_test.py | 6 +-- tests/lax_control_flow_test.py | 3 -- tests/lax_test.py | 2 - tests/x64_context_test.py | 5 -- tests/xmap_test.py | 4 -- 12 files changed, 58 insertions(+), 132 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 843688467..6550a2597 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * New features: * Bug fixes: * Breaking changes: + * The minimum jaxlib version is now 0.1.62. ## jaxlib 0.1.63 (Unreleased) diff --git a/build/test-requirements.txt b/build/test-requirements.txt index c1fadfa55..a5254a264 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -2,7 +2,7 @@ flake8 # For now, we pin the numpy version here numpy>=1.16 # Must be kept in sync with the minimum jaxlib version in jax/lib/__init__.py -jaxlib==0.1.60 +jaxlib==0.1.62 mypy==0.790 pillow pytest-benchmark diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 49bb72213..d895ccc16 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6086,75 +6086,66 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval) xla.translations[rng_uniform_p] = _rng_uniform_translation_rule -if jax.lib.version >= (0, 1, 62): - def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): - _ = dtype, algorithm - return (key.shape, tuple(shape)) +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): + _ = dtype, algorithm + return (key.shape, tuple(shape)) - def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): - _ = key, shape, algorithm - return (key.dtype, dtype) +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): + _ = key, shape, algorithm + return (key.dtype, dtype) - def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): - _ = shape, dtype, algorithm - return (key.weak_type, False) +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): + _ = shape, dtype, algorithm + return (key.weak_type, False) - def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm): - _ = c - xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape) - return xops.RngBitGenerator(algorithm, key, xla_shape) +def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm): + _ = c + xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape) + return xops.RngBitGenerator(algorithm, key, xla_shape) - def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): - return [key.named_shape, key.named_shape] +def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm): + return [key.named_shape, key.named_shape] - rng_bit_generator_p = Primitive("rng_bit_generator") - rng_bit_generator_p.multiple_results = True - rng_bit_generator_p.def_impl( - partial(xla.apply_primitive, rng_bit_generator_p)) - rng_bit_generator_p.def_abstract_eval( - partial(standard_multi_result_abstract_eval, rng_bit_generator_p, - _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, - _rng_bit_generator_named_shape_rule)) - xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule +rng_bit_generator_p = Primitive("rng_bit_generator") +rng_bit_generator_p.multiple_results = True +rng_bit_generator_p.def_impl( + partial(xla.apply_primitive, rng_bit_generator_p)) +rng_bit_generator_p.def_abstract_eval( + partial(standard_multi_result_abstract_eval, rng_bit_generator_p, + _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, + _rng_bit_generator_weak_type_rule, + _rng_bit_generator_named_shape_rule)) +xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule - RandomAlgorithm = xops.RandomAlgorithm - RandomAlgorithm.__str__ = lambda algorithm: algorithm.name +RandomAlgorithm = xops.RandomAlgorithm +RandomAlgorithm.__str__ = lambda algorithm: algorithm.name - def rng_bit_generator(key, - shape, - dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): - """Stateless PRNG bit generator. Experimental and its use is discouraged. +def rng_bit_generator(key, + shape, + dtype=np.uint32, + algorithm=RandomAlgorithm.RNG_DEFAULT): + """Stateless PRNG bit generator. Experimental and its use is discouraged. - Returns uniformly distributed random bits with the specified shape and dtype - (what is requirted to be an integer type) using the platform specific - default algorithm or the one specified. + Returns uniformly distributed random bits with the specified shape and dtype + (what is requirted to be an integer type) using the platform specific + default algorithm or the one specified. - It provides direct acces to the RngBitGenerator primitive exposed by XLA - (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low - level API access. + It provides direct acces to the RngBitGenerator primitive exposed by XLA + (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low + level API access. - Most users should use `jax.random` instead for a stable and more user - friendly API. - """ - shape = jax.core.canonicalize_shape(shape) - return tuple( - rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) -else: - # TODO(tberghammer): Remove when minimum jaxlib version is past (0, 1, 62). - rng_bit_generator_p = Primitive("rng_bit_generator") - class RandomAlgorithm: pass # type: ignore - - - def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=None): - raise "rng_bit_generator needs jaxlib 0.1.62 or newer" + Most users should use `jax.random` instead for a stable and more user + friendly API. + """ + shape = jax.core.canonicalize_shape(shape) + return tuple( + rng_bit_generator_p.bind( + key, shape=shape, dtype=dtype, algorithm=algorithm)) def _iota_abstract_eval(*, dtype, shape, dimension): diff --git a/jax/api.py b/jax/api.py index 0d1c48018..90043df38 100644 --- a/jax/api.py +++ b/jax/api.py @@ -98,9 +98,7 @@ zip = safe_zip FLAGS = flags.FLAGS flags.DEFINE_bool("jax_disable_jit", bool_env("JAX_DISABLE_JIT", False), "Disable JIT compilation and just call original Python.") -# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version. -if lib._xla_extension_version >= 5: - jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False)) +jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False)) flags.DEFINE_bool( "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True), @@ -347,39 +345,9 @@ def _cpp_jit( return _BackendAndDeviceInfo(default_device, committed_to_device) - # TODO(jblespiau): Delete `get_jax_enable_x64` and `get_jax_disable_jit_flag` - # when jaxlib 0.1.62 is the minimal version. - def get_jax_enable_x64(): - """Returns the value of the flag after GoogleInit. - - We must wait until flags have been parsed (in particular for top-level - functions decorated with jax.jit), so we delay inspecting the value - of the jax_enable_x64 flag until JIT time. - """ - # TODO(jblespiau): Delete when jaxlib 0.1.62 is the minimal version. - if lib._xla_extension_version >= 4: - return config.read("jax_enable_x64") - else: - return config.x64_enabled - - def get_jax_disable_jit_flag(): - """Returns the value of the `jax_disable_jit` flag. - - Both a flag and the `disable_jit` context manager can disable jit. We access - the flag only once, when jitting the function, and the context manager - modifies a C++ thread-local value. - """ - return config.read("jax_disable_jit") - static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums) - # TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version. - if lib._xla_extension_version >= 5: - cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, - static_argnums_) - else: - cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, - get_jax_enable_x64, get_jax_disable_jit_flag, - static_argnums_) + cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, + static_argnums_) # TODO(mattjj): make cpp callable follow descriptor protocol for bound methods @wraps(fun) diff --git a/jax/config.py b/jax/config.py index 4331ad766..acbf5a76a 100644 --- a/jax/config.py +++ b/jax/config.py @@ -50,9 +50,6 @@ class _ThreadLocalState(threading.local): class Config: - # TODO(jakevdp): Remove when minimum jaxlib is has extension version 4 - _thread_local_state = _ThreadLocalState() - def __init__(self): self.values = {} self.meta = {} @@ -70,10 +67,9 @@ class Config: raise Exception("Unrecognized config option: {}".format(name)) self.values[name] = val - # TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version. - if lib._xla_extension_version >= 5 and name == "jax_disable_jit": + if name == "jax_disable_jit": lib.jax_jit.set_disable_jit_cpp_flag(val) - elif lib._xla_extension_version >= 5 and name == "jax_enable_x64": + elif name == "jax_enable_x64": lib.jax_jit.set_enable_x64_cpp_flag(val) def read(self, name): @@ -157,21 +153,11 @@ class Config: @property def x64_enabled(self): - if lib._xla_extension_version >= 5: - return lib.jax_jit.get_enable_x64() - else: - # TODO(jakevdp): Remove when minimum jaxlib is has extension version 4 - if self._thread_local_state.enable_x64 is None: - self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64')) - return self._thread_local_state.enable_x64 + return lib.jax_jit.get_enable_x64() # TODO(jakevdp): make this public when thread-local x64 is fully implemented. def _set_x64_enabled(self, state): - if lib._xla_extension_version >= 5: - lib.jax_jit.set_enable_x64_thread_local(bool(state)) - else: - # TODO(jakevdp): Remove when minimum jaxlib is has extension version 4 - self._thread_local_state.enable_x64 = bool(state) + lib.jax_jit.set_enable_x64_thread_local(bool(state)) class NameSpace(object): diff --git a/jax/dtypes.py b/jax/dtypes.py index 58dc864b2..ba448ac3a 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -39,10 +39,8 @@ FLAGS = flags.FLAGS flags.DEFINE_bool('jax_enable_x64', strtobool(os.getenv('JAX_ENABLE_X64', 'False')), 'Enable 64-bit types to be used.') -# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version. -if lib._xla_extension_version >= 5: - lib.jax_jit.set_enable_x64_cpp_flag( - strtobool(os.getenv('JAX_ENABLE_X64', 'False'))) +lib.jax_jit.set_enable_x64_cpp_flag( + strtobool(os.getenv('JAX_ENABLE_X64', 'False'))) # bfloat16 support bfloat16: type = xla_client.bfloat16 diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 894e6a785..355776e17 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -29,7 +29,7 @@ except ModuleNotFoundError as err: ) from err # Must be kept in sync with the jaxlib version in build/test-requirements.txt -_minimum_jaxlib_version = (0, 1, 60) +_minimum_jaxlib_version = (0, 1, 62) try: from jaxlib import version as jaxlib_version except Exception as err: diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index db5b37f07..9121a366a 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -29,8 +28,6 @@ import numpy as np # It covers all JAX numpy types types except bfloat16 and numpy array. # TODO(jblespiau): Add support for float0 in the C++ path. _EXCLUDED_TYPES = [np.ndarray] -if jax.lib._xla_extension_version < 6: - _EXCLUDED_TYPES.append(jax.dtypes.bfloat16) _SCALAR_NUMPY_TYPES = [ x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES @@ -138,10 +135,9 @@ class JaxJitTest(parameterized.TestCase): self.assertEqual(res.dtype, complex_type) self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype) - @unittest.skipIf(jax.lib._xla_extension_version < 3, "jaxlib too old") def test_convert_int_overflow(self): with self.assertRaisesRegex( - RuntimeError if jax.lib._xla_extension_version >= 6 else OverflowError, + RuntimeError, "(Python int too large|Unable to convert Python scalar).*"): jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0]) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 21cca21ac..2a8f23ef6 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2582,9 +2582,6 @@ class LaxControlFlowTest(jtu.JaxTestCase): def test_xla_cpu_gpu_loop_cond_bug(self): # https://github.com/google/jax/issues/5900 - if jax.lib.version < (0, 1, 62): - raise SkipTest("test is broken on jaxlib==0.1.61 and 0.1.60") - def deriv(f): return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8de180d3d..a672c7b54 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -17,7 +17,6 @@ import collections from functools import partial import itertools import operator -import unittest from unittest import SkipTest from absl.testing import absltest @@ -2285,7 +2284,6 @@ class LaxTest(jtu.JaxTestCase): (x,), (1.,)))(1.) self.assertLen(jaxpr.jaxpr.eqns, 2) - @unittest.skipIf(jax.lib.version < (0, 1, 62), "Needs jaxlib 0.1.62 or newer") def testRngBitGenerator(self): if not config.x64_enabled: raise SkipTest("RngBitGenerator requires 64bit key") diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 6b653e477..84b370139 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -19,13 +19,11 @@ import time from absl.testing import absltest from absl.testing import parameterized -import jax from jax import api from jax import lax from jax import partial from jax import random from jax.config import config -from jax.config import FLAGS from jax.experimental import enable_x64, disable_x64 import jax.numpy as jnp import jax.test_util as jtu @@ -145,9 +143,6 @@ class X64ContextTests(jtu.JaxTestCase): def test_jit_cache(self): if jtu.device_under_test() == "tpu": self.skipTest("64-bit random not available on TPU") - if jax.lib._xla_extension_version < 4 and FLAGS.experimental_cpp_jit: - self.skipTest( - "Known failure due to https://github.com/google/jax/issues/5532") f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1) with disable_x64(): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 2b080cf41..2f92390ee 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -236,8 +236,6 @@ def schedules(sizes: Dict[str, int] class XMapTestCase(jtu.BufferDonationTestCase): def setUp(self): - if jax.lib.version < (0, 1, 58): - raise SkipTest("xmap requires jaxlib version >= 0.1.58") if not config.omnistaging_enabled: raise SkipTest("xmap requires omnistaging") super().setUp() @@ -691,8 +689,6 @@ class NamedNNTest(XMapTestCase): class NewPrimitiveTest(XMapTestCase): def setUp(self): - if jax.lib.version < (0, 1, 58): - raise SkipTest("xmap requires jaxlib version >= 0.1.58") if not config.omnistaging_enabled: raise SkipTest("xmap requires omnistaging")