mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Increase minimum jaxlib version to 0.1.62.
This commit is contained in:
parent
d326b077d9
commit
328930b917
@ -14,6 +14,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* New features:
|
||||
* Bug fixes:
|
||||
* Breaking changes:
|
||||
* The minimum jaxlib version is now 0.1.62.
|
||||
|
||||
## jaxlib 0.1.63 (Unreleased)
|
||||
|
||||
|
@ -2,7 +2,7 @@ flake8
|
||||
# For now, we pin the numpy version here
|
||||
numpy>=1.16
|
||||
# Must be kept in sync with the minimum jaxlib version in jax/lib/__init__.py
|
||||
jaxlib==0.1.60
|
||||
jaxlib==0.1.62
|
||||
mypy==0.790
|
||||
pillow
|
||||
pytest-benchmark
|
||||
|
@ -6086,75 +6086,66 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
|
||||
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule
|
||||
|
||||
|
||||
if jax.lib.version >= (0, 1, 62):
|
||||
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
|
||||
_ = dtype, algorithm
|
||||
return (key.shape, tuple(shape))
|
||||
def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm):
|
||||
_ = dtype, algorithm
|
||||
return (key.shape, tuple(shape))
|
||||
|
||||
|
||||
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
|
||||
_ = key, shape, algorithm
|
||||
return (key.dtype, dtype)
|
||||
def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm):
|
||||
_ = key, shape, algorithm
|
||||
return (key.dtype, dtype)
|
||||
|
||||
|
||||
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
|
||||
_ = shape, dtype, algorithm
|
||||
return (key.weak_type, False)
|
||||
def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
|
||||
_ = shape, dtype, algorithm
|
||||
return (key.weak_type, False)
|
||||
|
||||
|
||||
def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
|
||||
_ = c
|
||||
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
|
||||
return xops.RngBitGenerator(algorithm, key, xla_shape)
|
||||
def _rng_bit_generator_translation_rule(c, key, *, shape, dtype, algorithm):
|
||||
_ = c
|
||||
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
|
||||
return xops.RngBitGenerator(algorithm, key, xla_shape)
|
||||
|
||||
|
||||
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
||||
return [key.named_shape, key.named_shape]
|
||||
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
||||
return [key.named_shape, key.named_shape]
|
||||
|
||||
rng_bit_generator_p = Primitive("rng_bit_generator")
|
||||
rng_bit_generator_p.multiple_results = True
|
||||
rng_bit_generator_p.def_impl(
|
||||
partial(xla.apply_primitive, rng_bit_generator_p))
|
||||
rng_bit_generator_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
|
||||
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
|
||||
_rng_bit_generator_weak_type_rule,
|
||||
_rng_bit_generator_named_shape_rule))
|
||||
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule
|
||||
rng_bit_generator_p = Primitive("rng_bit_generator")
|
||||
rng_bit_generator_p.multiple_results = True
|
||||
rng_bit_generator_p.def_impl(
|
||||
partial(xla.apply_primitive, rng_bit_generator_p))
|
||||
rng_bit_generator_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
|
||||
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
|
||||
_rng_bit_generator_weak_type_rule,
|
||||
_rng_bit_generator_named_shape_rule))
|
||||
xla.translations[rng_bit_generator_p] = _rng_bit_generator_translation_rule
|
||||
|
||||
RandomAlgorithm = xops.RandomAlgorithm
|
||||
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
|
||||
RandomAlgorithm = xops.RandomAlgorithm
|
||||
RandomAlgorithm.__str__ = lambda algorithm: algorithm.name
|
||||
|
||||
|
||||
def rng_bit_generator(key,
|
||||
shape,
|
||||
dtype=np.uint32,
|
||||
algorithm=RandomAlgorithm.RNG_DEFAULT):
|
||||
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
|
||||
def rng_bit_generator(key,
|
||||
shape,
|
||||
dtype=np.uint32,
|
||||
algorithm=RandomAlgorithm.RNG_DEFAULT):
|
||||
"""Stateless PRNG bit generator. Experimental and its use is discouraged.
|
||||
|
||||
Returns uniformly distributed random bits with the specified shape and dtype
|
||||
(what is requirted to be an integer type) using the platform specific
|
||||
default algorithm or the one specified.
|
||||
Returns uniformly distributed random bits with the specified shape and dtype
|
||||
(what is requirted to be an integer type) using the platform specific
|
||||
default algorithm or the one specified.
|
||||
|
||||
It provides direct acces to the RngBitGenerator primitive exposed by XLA
|
||||
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
|
||||
level API access.
|
||||
It provides direct acces to the RngBitGenerator primitive exposed by XLA
|
||||
(https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) for low
|
||||
level API access.
|
||||
|
||||
Most users should use `jax.random` instead for a stable and more user
|
||||
friendly API.
|
||||
"""
|
||||
shape = jax.core.canonicalize_shape(shape)
|
||||
return tuple(
|
||||
rng_bit_generator_p.bind(
|
||||
key, shape=shape, dtype=dtype, algorithm=algorithm))
|
||||
else:
|
||||
# TODO(tberghammer): Remove when minimum jaxlib version is past (0, 1, 62).
|
||||
rng_bit_generator_p = Primitive("rng_bit_generator")
|
||||
class RandomAlgorithm: pass # type: ignore
|
||||
|
||||
|
||||
def rng_bit_generator(key, shape, dtype=np.uint32, algorithm=None):
|
||||
raise "rng_bit_generator needs jaxlib 0.1.62 or newer"
|
||||
Most users should use `jax.random` instead for a stable and more user
|
||||
friendly API.
|
||||
"""
|
||||
shape = jax.core.canonicalize_shape(shape)
|
||||
return tuple(
|
||||
rng_bit_generator_p.bind(
|
||||
key, shape=shape, dtype=dtype, algorithm=algorithm))
|
||||
|
||||
|
||||
def _iota_abstract_eval(*, dtype, shape, dimension):
|
||||
|
38
jax/api.py
38
jax/api.py
@ -98,9 +98,7 @@ zip = safe_zip
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool("jax_disable_jit", bool_env("JAX_DISABLE_JIT", False),
|
||||
"Disable JIT compilation and just call original Python.")
|
||||
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
|
||||
if lib._xla_extension_version >= 5:
|
||||
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))
|
||||
jax_jit.set_disable_jit_cpp_flag(bool_env("JAX_DISABLE_JIT", False))
|
||||
|
||||
flags.DEFINE_bool(
|
||||
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
|
||||
@ -347,39 +345,9 @@ def _cpp_jit(
|
||||
|
||||
return _BackendAndDeviceInfo(default_device, committed_to_device)
|
||||
|
||||
# TODO(jblespiau): Delete `get_jax_enable_x64` and `get_jax_disable_jit_flag`
|
||||
# when jaxlib 0.1.62 is the minimal version.
|
||||
def get_jax_enable_x64():
|
||||
"""Returns the value of the flag after GoogleInit.
|
||||
|
||||
We must wait until flags have been parsed (in particular for top-level
|
||||
functions decorated with jax.jit), so we delay inspecting the value
|
||||
of the jax_enable_x64 flag until JIT time.
|
||||
"""
|
||||
# TODO(jblespiau): Delete when jaxlib 0.1.62 is the minimal version.
|
||||
if lib._xla_extension_version >= 4:
|
||||
return config.read("jax_enable_x64")
|
||||
else:
|
||||
return config.x64_enabled
|
||||
|
||||
def get_jax_disable_jit_flag():
|
||||
"""Returns the value of the `jax_disable_jit` flag.
|
||||
|
||||
Both a flag and the `disable_jit` context manager can disable jit. We access
|
||||
the flag only once, when jitting the function, and the context manager
|
||||
modifies a C++ thread-local value.
|
||||
"""
|
||||
return config.read("jax_disable_jit")
|
||||
|
||||
static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums)
|
||||
# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
|
||||
if lib._xla_extension_version >= 5:
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
static_argnums_)
|
||||
else:
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
get_jax_enable_x64, get_jax_disable_jit_flag,
|
||||
static_argnums_)
|
||||
cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info,
|
||||
static_argnums_)
|
||||
|
||||
# TODO(mattjj): make cpp callable follow descriptor protocol for bound methods
|
||||
@wraps(fun)
|
||||
|
@ -50,9 +50,6 @@ class _ThreadLocalState(threading.local):
|
||||
|
||||
|
||||
class Config:
|
||||
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
|
||||
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
def __init__(self):
|
||||
self.values = {}
|
||||
self.meta = {}
|
||||
@ -70,10 +67,9 @@ class Config:
|
||||
raise Exception("Unrecognized config option: {}".format(name))
|
||||
self.values[name] = val
|
||||
|
||||
# TODO(jblespiau): Remove when jaxlib 0.1.62 is the minimal version.
|
||||
if lib._xla_extension_version >= 5 and name == "jax_disable_jit":
|
||||
if name == "jax_disable_jit":
|
||||
lib.jax_jit.set_disable_jit_cpp_flag(val)
|
||||
elif lib._xla_extension_version >= 5 and name == "jax_enable_x64":
|
||||
elif name == "jax_enable_x64":
|
||||
lib.jax_jit.set_enable_x64_cpp_flag(val)
|
||||
|
||||
def read(self, name):
|
||||
@ -157,21 +153,11 @@ class Config:
|
||||
|
||||
@property
|
||||
def x64_enabled(self):
|
||||
if lib._xla_extension_version >= 5:
|
||||
return lib.jax_jit.get_enable_x64()
|
||||
else:
|
||||
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
|
||||
if self._thread_local_state.enable_x64 is None:
|
||||
self._thread_local_state.enable_x64 = bool(self.read('jax_enable_x64'))
|
||||
return self._thread_local_state.enable_x64
|
||||
return lib.jax_jit.get_enable_x64()
|
||||
|
||||
# TODO(jakevdp): make this public when thread-local x64 is fully implemented.
|
||||
def _set_x64_enabled(self, state):
|
||||
if lib._xla_extension_version >= 5:
|
||||
lib.jax_jit.set_enable_x64_thread_local(bool(state))
|
||||
else:
|
||||
# TODO(jakevdp): Remove when minimum jaxlib is has extension version 4
|
||||
self._thread_local_state.enable_x64 = bool(state)
|
||||
lib.jax_jit.set_enable_x64_thread_local(bool(state))
|
||||
|
||||
|
||||
class NameSpace(object):
|
||||
|
@ -39,10 +39,8 @@ FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_enable_x64',
|
||||
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
|
||||
'Enable 64-bit types to be used.')
|
||||
# TODO(jblespiau): Remove the `if` when jaxlib 0.1.62 is the minimal version.
|
||||
if lib._xla_extension_version >= 5:
|
||||
lib.jax_jit.set_enable_x64_cpp_flag(
|
||||
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))
|
||||
lib.jax_jit.set_enable_x64_cpp_flag(
|
||||
strtobool(os.getenv('JAX_ENABLE_X64', 'False')))
|
||||
|
||||
# bfloat16 support
|
||||
bfloat16: type = xla_client.bfloat16
|
||||
|
@ -29,7 +29,7 @@ except ModuleNotFoundError as err:
|
||||
) from err
|
||||
|
||||
# Must be kept in sync with the jaxlib version in build/test-requirements.txt
|
||||
_minimum_jaxlib_version = (0, 1, 60)
|
||||
_minimum_jaxlib_version = (0, 1, 62)
|
||||
try:
|
||||
from jaxlib import version as jaxlib_version
|
||||
except Exception as err:
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -29,8 +28,6 @@ import numpy as np
|
||||
# It covers all JAX numpy types types except bfloat16 and numpy array.
|
||||
# TODO(jblespiau): Add support for float0 in the C++ path.
|
||||
_EXCLUDED_TYPES = [np.ndarray]
|
||||
if jax.lib._xla_extension_version < 6:
|
||||
_EXCLUDED_TYPES.append(jax.dtypes.bfloat16)
|
||||
|
||||
_SCALAR_NUMPY_TYPES = [
|
||||
x for x in jax.abstract_arrays.array_types if x not in _EXCLUDED_TYPES
|
||||
@ -138,10 +135,9 @@ class JaxJitTest(parameterized.TestCase):
|
||||
self.assertEqual(res.dtype, complex_type)
|
||||
self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
|
||||
|
||||
@unittest.skipIf(jax.lib._xla_extension_version < 3, "jaxlib too old")
|
||||
def test_convert_int_overflow(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError if jax.lib._xla_extension_version >= 6 else OverflowError,
|
||||
RuntimeError,
|
||||
"(Python int too large|Unable to convert Python scalar).*"):
|
||||
jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])
|
||||
|
||||
|
@ -2582,9 +2582,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
def test_xla_cpu_gpu_loop_cond_bug(self):
|
||||
# https://github.com/google/jax/issues/5900
|
||||
if jax.lib.version < (0, 1, 62):
|
||||
raise SkipTest("test is broken on jaxlib==0.1.61 and 0.1.60")
|
||||
|
||||
def deriv(f):
|
||||
return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)
|
||||
|
||||
|
@ -17,7 +17,6 @@ import collections
|
||||
from functools import partial
|
||||
import itertools
|
||||
import operator
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -2285,7 +2284,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
(x,), (1.,)))(1.)
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 2)
|
||||
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 62), "Needs jaxlib 0.1.62 or newer")
|
||||
def testRngBitGenerator(self):
|
||||
if not config.x64_enabled:
|
||||
raise SkipTest("RngBitGenerator requires 64bit key")
|
||||
|
@ -19,13 +19,11 @@ import time
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import api
|
||||
from jax import lax
|
||||
from jax import partial
|
||||
from jax import random
|
||||
from jax.config import config
|
||||
from jax.config import FLAGS
|
||||
from jax.experimental import enable_x64, disable_x64
|
||||
import jax.numpy as jnp
|
||||
import jax.test_util as jtu
|
||||
@ -145,9 +143,6 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
def test_jit_cache(self):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.skipTest("64-bit random not available on TPU")
|
||||
if jax.lib._xla_extension_version < 4 and FLAGS.experimental_cpp_jit:
|
||||
self.skipTest(
|
||||
"Known failure due to https://github.com/google/jax/issues/5532")
|
||||
|
||||
f = partial(random.uniform, random.PRNGKey(0), (1,), 'float64', -1, 1)
|
||||
with disable_x64():
|
||||
|
@ -236,8 +236,6 @@ def schedules(sizes: Dict[str, int]
|
||||
|
||||
class XMapTestCase(jtu.BufferDonationTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
@ -691,8 +689,6 @@ class NamedNNTest(XMapTestCase):
|
||||
|
||||
class NewPrimitiveTest(XMapTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user