Increase minimum jaxlib version to 0.1.38. (#2120)

Peter Hawkins authored 2020-01-29 14:16:58 -05:00, committed by GitHub
parent 09d2421f28
commit 991324f8df
8 changed files with 37 additions and 47 deletions

@@ -28,7 +28,7 @@ install:
   # The jaxlib version should match the minimum jaxlib version in
   # jax/lib/__init__.py. This tests JAX PRs against the oldest permitted
   # jaxlib.
-  - pip install jaxlib==0.1.37
+  - pip install jaxlib==0.1.38
   - pip install -v .
   # The following are needed to test the Colab notebooks and the documentation building
   - if [[ "$JAX_ONLY_DOCUMENTATION" != "" ]]; then
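
The pinned jaxlib above is meant to match the minimum version declared in jax/lib/__init__.py, so CI exercises JAX against the oldest permitted jaxlib. As a hedged illustration (not part of this commit), a check like the following could keep the two in sync; the PINNED constant and the use of the private _minimum_jaxlib_version attribute are assumptions of this sketch.

# Hypothetical consistency check between the CI pin and jax's declared minimum
# jaxlib version; illustrative only, not part of this commit.
from jax import lib as jax_lib

PINNED = "0.1.38"  # the version pinned by the pip install line above

def check_pin_matches_minimum():
  pinned = tuple(int(part) for part in PINNED.split("."))
  # _minimum_jaxlib_version is a private detail of jax/lib/__init__.py.
  assert pinned == jax_lib._minimum_jaxlib_version, (
      "CI pins jaxlib %s but jax declares minimum %s" %
      (PINNED, jax_lib._minimum_jaxlib_version))

check_pin_matches_minimum()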

@@ -4,7 +4,16 @@ These are the release notes for JAX.
 
 ## jax 0.1.59 (unreleased)
 
-## jax 0.1.58
+### Breaking changes
+
+* The minimum jaxlib version is now 0.1.38.
+
+## jaxlib 0.1.38 (January 29, 2020)
+
+* CUDA 9.0 is no longer supported.
+* CUDA 10.2 wheels are now built by default.
+
+## jax 0.1.58 (January 28, 2020)
 
 ### Breaking changes
 

@@ -132,31 +132,21 @@ def _nan_like(c, operand):
   nan = c.Constant(onp.array(onp.nan, dtype=dtype))
   return c.Broadcast(nan, shape.dimensions())
 
-# TODO(phawkins): remove supports_batching argument after the minimum jaxlib
-# version is 0.1.38.
-def _cholesky_cpu_gpu_translation_rule(potrf_impl, potrf_supports_batching, c,
-                                       operand):
+def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
   shape = c.GetShape(operand)
   batch_dims = shape.dimensions()[:-2]
   dtype = shape.element_type().type
-  if len(batch_dims) == 0 or potrf_supports_batching:
-    result, info = potrf_impl(c, operand, lower=True)
-    ok = c.Eq(info, c.ConstantS32Scalar(0))
-    return _broadcasting_select(c,
-                                c.Reshape(ok, None, batch_dims + (1, 1)), result,
-                                _nan_like(c, result))
-  else:
-    # Fall back to the HLO implementation for batched Cholesky decomposition.
-    return c.Cholesky(operand)
+  result, info = potrf_impl(c, operand, lower=True)
+  ok = c.Eq(info, c.ConstantS32Scalar(0))
+  return _broadcasting_select(c,
+                              c.Reshape(ok, None, batch_dims + (1, 1)), result,
+                              _nan_like(c, result))
 
 xla.backend_specific_translations['cpu'][cholesky_p] = partial(
-    _cholesky_cpu_gpu_translation_rule, lapack.potrf,
-    not hasattr(lapack, "jax_potrf"))
+    _cholesky_cpu_gpu_translation_rule, lapack.potrf)
 
-# TODO(phawkins): remove after the minimum jaxlib version is 0.1.38.
-if hasattr(cusolver, "potrf"):
-  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
-      _cholesky_cpu_gpu_translation_rule, cusolver.potrf, True)
+xla.backend_specific_translations['gpu'][cholesky_p] = partial(
+    _cholesky_cpu_gpu_translation_rule, cusolver.potrf)
 
 # Asymmetric eigendecomposition
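
The simplified rule always invokes the batched potrf kernel and relies on its info output: wherever info != 0, _broadcasting_select overwrites the corresponding block of the result with NaNs. A rough NumPy analogue of that masking pattern, for illustration only (the real rule emits XLA ops rather than calling NumPy):

# NumPy sketch of "factor each batch element, mask failures with NaN";
# the names and the per-element loop are illustrative, not JAX internals.
import numpy as onp

def batched_cholesky_or_nan(a):
  out = onp.empty_like(a)
  for idx in onp.ndindex(a.shape[:-2]):
    try:
      out[idx] = onp.linalg.cholesky(a[idx])  # corresponds to info == 0
    except onp.linalg.LinAlgError:
      out[idx] = onp.nan                      # failed block becomes all-NaN
  return out

a = onp.stack([onp.eye(2), onp.array([[1., 2.], [2., 1.]])])  # second matrix is not PD
print(batched_cholesky_or_nan(a))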

@@ -17,7 +17,7 @@
 
 import jaxlib
 
-_minimum_jaxlib_version = (0, 1, 37)
+_minimum_jaxlib_version = (0, 1, 38)
 try:
   from jaxlib import version as jaxlib_version
 except:
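
For reference, a minimal sketch of the kind of check this constant feeds, assuming a jaxlib install whose version string parses as plain integers; the actual parsing and error message in jax/lib/__init__.py differ.

# Minimal sketch of a minimum-jaxlib-version check; illustrative only.
from jaxlib import version as jaxlib_version

_minimum_jaxlib_version = (0, 1, 38)

def _check_jaxlib_version():
  installed = tuple(int(x) for x in jaxlib_version.__version__.split(".")[:3])
  if installed < _minimum_jaxlib_version:
    raise ValueError(
        "jaxlib is version %s, but jax requires at least version %s." %
        (jaxlib_version.__version__, ".".join(map(str, _minimum_jaxlib_version))))

_check_jaxlib_version()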

@@ -185,12 +185,10 @@ JAX_COMPOUND_OP_RECORDS = [
     op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
               [], tolerance={onp.float64: 1e-8}, inexact=True),
     op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
-    op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero,
-              ["rev"]),
-    # TODO(phawkins): merge this with the preceding entry after the minimum
-    # Jaxlib version is increased to 0.1.38.
-    op_record("floor_divide", 2, uint_dtypes, all_shapes, jtu.rand_nonzero,
-              ["rev"]),
+    op_record("floor_divide", 2, number_dtypes, all_shapes,
+              jtu.rand_nonzero, ["rev"]),
+    op_record("floor_divide", 2, uint_dtypes, all_shapes,
+              jtu.rand_nonzero, ["rev"]),
     op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
               inexact=True),
     op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],

@@ -58,7 +58,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, test_name=None):
 JAX_SPECIAL_FUNCTION_RECORDS = [
     # TODO: digamma has no JVP implemented.
     op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
+    op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
     op_record("digamma", 1, float_dtypes, jtu.rand_positive, False),
+    op_record("gammainc", 2, float_dtypes, jtu.rand_positive, False),
+    op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, False),
     op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True),
     op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True),
     op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True),
@@ -74,12 +77,6 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
     op_record("entr", 1, float_dtypes, jtu.rand_default, False),
 ]
 
-if lib.version > (0, 1, 37):
-  JAX_SPECIAL_FUNCTION_RECORDS += [
-      op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
-      op_record("gammainc", 2, float_dtypes, jtu.rand_positive, False),
-      op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, False),
-  ]
 
 CombosWithReplacement = itertools.combinations_with_replacement
 
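
The deleted block is the version-gating idiom this commit retires: records were appended only when the installed jaxlib was newer than 0.1.37. Once jaxlib 0.1.38 is the minimum, the guard is always true and the records move into the main list. A small sketch of the idiom, assuming jax.lib.version is the installed jaxlib version tuple, as the deleted code uses it:

# Sketch of the retired version-gating idiom; RECORDS is a hypothetical list.
from jax import lib

RECORDS = ["betaln", "digamma", "erf"]
if lib.version > (0, 1, 37):  # always true once jaxlib >= 0.1.38 is required
  RECORDS += ["betainc", "gammainc", "gammaincc"]

print(RECORDS)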

@@ -118,6 +118,14 @@ LAX_OPS = [
               onp.float64: 1e-14}),
     op_record("digamma", 1, float_dtypes, jtu.rand_positive,
               {onp.float64: 1e-14}),
+    op_record("betainc", 3, float_dtypes, jtu.rand_positive,
+              {onp.float64: 1e-14}),
+    op_record("igamma", 2,
+              [f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
+              jtu.rand_positive, {onp.float64: 1e-14}),
+    op_record("igammac", 2,
+              [f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
+              jtu.rand_positive, {onp.float64: 1e-14}),
     op_record("erf", 1, float_dtypes, jtu.rand_small),
     op_record("erfc", 1, float_dtypes, jtu.rand_small),
     # TODO(b/142976030): the approximation of erfinf used by XLA is only
@@ -156,17 +164,6 @@ LAX_OPS = [
     op_record("le", 2, default_dtypes, jtu.rand_small),
     op_record("lt", 2, default_dtypes, jtu.rand_small),
 ]
-if lib.version > (0, 1, 37):
-  LAX_OPS += [
-      op_record("betainc", 3, float_dtypes, jtu.rand_positive,
-                {onp.float64: 1e-14}),
-      op_record("igamma", 2,
-                [f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
-                jtu.rand_positive, {onp.float64: 1e-14}),
-      op_record("igammac", 2,
-                [f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
-                jtu.rand_positive, {onp.float64: 1e-14}),
-  ]
 
 CombosWithReplacement = itertools.combinations_with_replacement
 
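
With jaxlib 0.1.38 required, betainc, igamma, and igammac are exercised unconditionally. A quick numerical sanity check of these ops through jax.lax, assuming a working jaxlib >= 0.1.38 install and using the standard identities igamma(a, x) + igammac(a, x) = 1 and betainc(a, b, 1) = 1:

# Quick numerical sanity check of the newly-unconditional ops; the identities
# used are standard properties of the regularized incomplete gamma/beta functions.
import numpy as onp
from jax import lax

a = onp.float32(2.5)
x = onp.float32(1.7)
print(lax.igamma(a, x) + lax.igammac(a, x))                               # ~1.0
print(lax.betainc(onp.float32(2.0), onp.float32(3.0), onp.float32(1.0)))  # ~1.0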

@@ -70,8 +70,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       return [onp.matmul(a, np.conj(T(a)))]
 
     if (np.issubdtype(dtype, np.complexfloating) and
-        (jtu.device_under_test() == "tpu" or
-         (jtu.device_under_test() == "cpu" and jax.lib.version < (0, 1, 38)))):
+        jtu.device_under_test() == "tpu"):
       self.skipTest("Unimplemented case for complex Cholesky decomposition.")
     self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,