Increase minimum jaxlib version to 0.1.62.

This commit is contained in:
Peter Hawkins 2021-03-16 12:13:41 -04:00
parent d326b077d9
commit 328930b917
12 changed files with 58 additions and 132 deletions

View File

@ -14,6 +14,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* Bug fixes:
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
## jaxlib 0.1.63 (Unreleased)

View File

@ -2,7 +2,7 @@ flake8
# For now, we pin the numpy version here
numpy>=1.16
# Must be kept in sync with the minimum jaxlib version in jax/lib/__init__.py
jaxlib==0.1.60
jaxlib==0.1.62
mypy==0.790
pillow
pytest-benchmark

View File

@ -6086,75 +6086,66 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule
if jax.lib.version >= (0, 1, 62):
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
_ = dtype, algorithm
return (key.shape, tuple(shape))
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
_ = dtype, algorithm
return (key.shape, tuple(shape))
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
_ = key, shape, algorithm
return (key.dtype, dtype)
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
_ = key, shape, algorithm
return (key.dtype, dtype)
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
_ = shape, dtype, algorithm
return (key.weak_type, False)
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
_ = shape, dtype, algorithm
return (key.weak_type, False)
def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
_ = c
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
return xops.RngBitGenerator(algorithm, key, xla_shape)
def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
_ = c
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
return xops.RngBitGenerator(algorithm, key, xla_shape)
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule
RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
RandomAlgorithm = xops.RandomAlgorithm
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
def rng_bit_generator(key,
shape,
dtype=np.uint32,
algorithm=RandomAlgorithm.RNG_DEFAULT):
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
def rng_bit_generator(key,
shape,
dtype=np.uint32,
algorithm=RandomAlgorithm.RNG_DEFAULT):
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
Returns uniformly distributed random bits with the specified shape and dtype
(what is requirted to be an integer type) using the platform specific
default algorithm or the one specified.
Returns uniformly distributed random bits with the specified shape and dtype
(what is requirted to be an integer type) using the platform specific
default algorithm or the one specified.
It provides direct acces to the RngBitGenerator primitive exposed by XLA
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
level API access.
It provides direct acces to the RngBitGenerator primitive exposed by XLA
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
level API access.
Most users should use `jax.random` instead for a stable and more user
friendly API.
"""
shape = jax.core.canonicalize_shape(shape)
return tuple(
rng_bit_generator_p.bind(
key, shape=shape, dtype=dtype, algorithm=algorithm))
else:
# TODO(tberghammer): Remove when minimum jaxlib version is past (0, 1, 62).
rng_bit_generator_p = Primitive("rng_bit_generator")
class RandomAlgorithm: pass # type: ignore
def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=None):
raise "rng_bit_generator needs jaxlib 0.1.62 or newer"
Most users should use `jax.random` instead for a stable and more user
friendly API.
"""
shape = jax.core.canonicalize_shape(shape)
return tuple(
rng_bit_generator_p.bind(
key, shape=shape, dtype=dtype, algorithm=algorithm))
def _iota_abstract_eval(*, dtype, shape, dimension):

View File

@ -98,9 +98,7 @@ zip = safe_zip
FLAGS = flags.FLAGS
flags.DEFINE_bool("jax_disable_jit", bool_env("JAX_DISABLE_JIT", False),
"Disable JIT compilation and just call original Python.")
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))
flags.DEFINE_bool(
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
@ -347,39 +345,9 @@ def _cpp_jit(
return _BackendAndDeviceInfo(default_device, committed_to_device)
# TODO(jblespiau): Delete `get_jax_enable_x64` and `get_jax_disable_jit_flag`
# when jaxlib 0.1.62 is the minimal version.
def get_jax_enable_x64():
"""Returns the value of the flag after GoogleInit.
We must wait until flags have been parsed (in particular for top-level
functions decorated with jax.jit), so we delay inspecting the value
of the jax_enable_x64 flag until JIT time.
"""
# TODO(jblespiau): Delete when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 4:
return config.read("jax_enable_x64")
else:
return config.x64_enabled
def get_jax_disable_jit_flag():
"""Returns the value of the `jax_disable_jit` flag.
Both a flag and the `disable_jit` context manager can disable jit. We access
the flag only once, when jitting the function, and the context manager
modifies a C++ thread-local value.
"""
return config.read("jax_disable_jit")
static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums_)
else:
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
get_jax_enable_x64, get_jax_disable_jit_flag,
static_argnums_)
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
static_argnums_)
# TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
@wraps(fun)

View File

@ -50,9 +50,6 @@ class _ThreadLocalState(threading.local):
class Config:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
_thread_local_state = _ThreadLocalState()
def __init__(self):
self.values = {}
self.meta = {}
@ -70,10 +67,9 @@ class Config:
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val
# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5 and name == "jax_disable_jit":
if name == "jax_disable_jit":
lib.jax_jit.set_disable_jit_cpp_flag(val)
elif lib._xla_extension_version >= 5 and name == "jax_enable_x64":
elif name == "jax_enable_x64":
lib.jax_jit.set_enable_x64_cpp_flag(val)
def read(self, name):
@ -157,21 +153,11 @@ class Config:
@property
def x64_enabled(self):
if lib._xla_extension_version >= 5:
return lib.jax_jit.get_enable_x64()
else:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
if self._thread_local_state.enable_x64 is None:
self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64'))
return self._thread_local_state.enable_x64
return lib.jax_jit.get_enable_x64()
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
def _set_x64_enabled(self, state):
if lib._xla_extension_version >= 5:
lib.jax_jit.set_enable_x64_thread_local(bool(state))
else:
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
self._thread_local_state.enable_x64 = bool(state)
lib.jax_jit.set_enable_x64_thread_local(bool(state))
class NameSpace(object):

View File

@ -39,10 +39,8 @@ FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_enable_x64',
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
'Enable 64-bit types to be used.')
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
if lib._xla_extension_version >= 5:
lib.jax_jit.set_enable_x64_cpp_flag(
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))
lib.jax_jit.set_enable_x64_cpp_flag(
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))
# bfloat16 support
bfloat16: type = xla_client.bfloat16

View File

@ -29,7 +29,7 @@ except ModuleNotFoundError as err:
) from err
# Must be kept in sync with the jaxlib version in build/test-requirements.txt
_minimum_jaxlib_version = (0, 1, 60)
_minimum_jaxlib_version = (0, 1, 62)
try:
from jaxlib import version as jaxlib_version
except Exception as err:

View File

@ -13,7 +13,6 @@
# limitations under the License.
import inspect
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -29,8 +28,6 @@ import numpy as np
# It covers all JAX numpy types types except bfloat16 and numpy array.
# TODO(jblespiau): Add support for float0 in the C++ path.
_EXCLUDED_TYPES = [np.ndarray]
if jax.lib._xla_extension_version < 6:
_EXCLUDED_TYPES.append(jax.dtypes.bfloat16)
_SCALAR_NUMPY_TYPES = [
x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES
@ -138,10 +135,9 @@ class JaxJitTest(parameterized.TestCase):
self.assertEqual(res.dtype, complex_type)
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
@unittest.skipIf(jax.lib._xla_extension_version < 3, "jaxlib too old")
def test_convert_int_overflow(self):
with self.assertRaisesRegex(
RuntimeError if jax.lib._xla_extension_version >= 6 else OverflowError,
RuntimeError,
"(Python int too large|Unable to convert Python scalar).*"):
jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])

View File

@ -2582,9 +2582,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def test_xla_cpu_gpu_loop_cond_bug(self):
# https://github.com/google/jax/issues/5900
if jax.lib.version < (0, 1, 62):
raise SkipTest("test is broken on jaxlib==0.1.61 and 0.1.60")
def deriv(f):
return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)

View File

@ -17,7 +17,6 @@ import collections
from functools import partial
import itertools
import operator
import unittest
from unittest import SkipTest
from absl.testing import absltest
@ -2285,7 +2284,6 @@ class LaxTest(jtu.JaxTestCase):
(x,), (1.,)))(1.)
self.assertLen(jaxpr.jaxpr.eqns, 2)
@unittest.skipIf(jax.lib.version < (0, 1, 62), "Needs jaxlib 0.1.62 or newer")
def testRngBitGenerator(self):
if not config.x64_enabled:
raise SkipTest("RngBitGenerator requires 64bit key")

View File

@ -19,13 +19,11 @@ import time
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import api
from jax import lax
from jax import partial
from jax import random
from jax.config import config
from jax.config import FLAGS
from jax.experimental import enable_x64, disable_x64
import jax.numpy as jnp
import jax.test_util as jtu
@ -145,9 +143,6 @@ class X64ContextTests(jtu.JaxTestCase):
def test_jit_cache(self):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit random not available on TPU")
if jax.lib._xla_extension_version < 4 and FLAGS.experimental_cpp_jit:
self.skipTest(
"Known failure due to https://github.com/google/jax/issues/5532")
f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
with disable_x64():

View File

@ -236,8 +236,6 @@ def schedules(sizes: Dict[str, int]
class XMapTestCase(jtu.BufferDonationTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
@ -691,8 +689,6 @@ class NamedNNTest(XMapTestCase):
class NewPrimitiveTest(XMapTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")