mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Increase minimum jaxlib version to 0.1.38. (#2120)
This commit is contained in:
parent
09d2421f28
commit
991324f8df
@ -28,7 +28,7 @@ install:
|
||||
# The jaxlib version should match the minimum jaxlib version in
|
||||
# jax/lib/__init__.py. This tests JAX PRs against the oldest permitted
|
||||
# jaxlib.
|
||||
- pip install jaxlib==0.1.37
|
||||
- pip install jaxlib==0.1.38
|
||||
- pip install -v .
|
||||
# The following are needed to test the Colab notebooks and the documentation building
|
||||
- if [[ "$JAX_ONLY_DOCUMENTATION" != "" ]]; then
|
||||
|
11
CHANGELOG.md
11
CHANGELOG.md
@ -4,7 +4,16 @@ These are the release notes for JAX.
|
||||
|
||||
## jax 0.1.59 (unreleased)
|
||||
|
||||
## jax 0.1.58
|
||||
### Breaking changes
|
||||
|
||||
* The minimum jaxlib version is now 0.1.38.
|
||||
|
||||
## jaxlib 0.1.38 (January 29, 2020)
|
||||
|
||||
* CUDA 9.0 is no longer supported.
|
||||
* CUDA 10.2 wheels are now built by default.
|
||||
|
||||
## jax 0.1.58 (January 28, 2020)
|
||||
|
||||
### Breaking changes
|
||||
|
||||
|
@ -132,31 +132,21 @@ def _nan_like(c, operand):
|
||||
nan = c.Constant(onp.array(onp.nan, dtype=dtype))
|
||||
return c.Broadcast(nan, shape.dimensions())
|
||||
|
||||
# TODO(phawkins): remove supports_batching argument after the minimum jaxlib
|
||||
# version is 0.1.38.
|
||||
def _cholesky_cpu_gpu_translation_rule(potrf_impl, potrf_supports_batching, c,
|
||||
operand):
|
||||
def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
|
||||
shape = c.GetShape(operand)
|
||||
batch_dims = shape.dimensions()[:-2]
|
||||
dtype = shape.element_type().type
|
||||
if len(batch_dims) == 0 or potrf_supports_batching:
|
||||
result, info = potrf_impl(c, operand, lower=True)
|
||||
ok = c.Eq(info, c.ConstantS32Scalar(0))
|
||||
return _broadcasting_select(c,
|
||||
c.Reshape(ok, None, batch_dims + (1, 1)), result,
|
||||
_nan_like(c, result))
|
||||
else:
|
||||
# Fall back to the HLO implementation for batched Cholesky decomposition.
|
||||
return c.Cholesky(operand)
|
||||
result, info = potrf_impl(c, operand, lower=True)
|
||||
ok = c.Eq(info, c.ConstantS32Scalar(0))
|
||||
return _broadcasting_select(c,
|
||||
c.Reshape(ok, None, batch_dims + (1, 1)), result,
|
||||
_nan_like(c, result))
|
||||
|
||||
xla.backend_specific_translations['cpu'][cholesky_p] = partial(
|
||||
_cholesky_cpu_gpu_translation_rule, lapack.potrf,
|
||||
not hasattr(lapack, "jax_potrf"))
|
||||
_cholesky_cpu_gpu_translation_rule, lapack.potrf)
|
||||
|
||||
# TODO(phawkins): remove after the minimum jaxlib version is 0.1.38.
|
||||
if hasattr(cusolver, "potrf"):
|
||||
xla.backend_specific_translations['gpu'][cholesky_p] = partial(
|
||||
_cholesky_cpu_gpu_translation_rule, cusolver.potrf, True)
|
||||
xla.backend_specific_translations['gpu'][cholesky_p] = partial(
|
||||
_cholesky_cpu_gpu_translation_rule, cusolver.potrf)
|
||||
|
||||
# Asymmetric eigendecomposition
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
import jaxlib
|
||||
|
||||
_minimum_jaxlib_version = (0, 1, 37)
|
||||
_minimum_jaxlib_version = (0, 1, 38)
|
||||
try:
|
||||
from jaxlib import version as jaxlib_version
|
||||
except:
|
||||
|
@ -185,12 +185,10 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive,
|
||||
[], tolerance={onp.float64: 1e-8}, inexact=True),
|
||||
op_record("fix", 1, float_dtypes, all_shapes, jtu.rand_default, []),
|
||||
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero,
|
||||
["rev"]),
|
||||
# TODO(phawkins): merge this with the preceding entry after the minimum
|
||||
# Jaxlib version is increased to 0.1.38.
|
||||
op_record("floor_divide", 2, uint_dtypes, all_shapes, jtu.rand_nonzero,
|
||||
["rev"]),
|
||||
op_record("floor_divide", 2, number_dtypes, all_shapes,
|
||||
jtu.rand_nonzero, ["rev"]),
|
||||
op_record("floor_divide", 2, uint_dtypes, all_shapes,
|
||||
jtu.rand_nonzero, ["rev"]),
|
||||
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
||||
inexact=True),
|
||||
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [],
|
||||
|
@ -58,7 +58,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, test_name=None):
|
||||
JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
# TODO: digamma has no JVP implemented.
|
||||
op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("digamma", 1, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("gammainc", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
@ -74,12 +77,6 @@ JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record("entr", 1, float_dtypes, jtu.rand_default, False),
|
||||
]
|
||||
|
||||
if lib.version > (0, 1, 37):
|
||||
JAX_SPECIAL_FUNCTION_RECORDS += [
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("gammainc", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, False),
|
||||
]
|
||||
|
||||
CombosWithReplacement = itertools.combinations_with_replacement
|
||||
|
||||
|
@ -118,6 +118,14 @@ LAX_OPS = [
|
||||
onp.float64: 1e-14}),
|
||||
op_record("digamma", 1, float_dtypes, jtu.rand_positive,
|
||||
{onp.float64: 1e-14}),
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive,
|
||||
{onp.float64: 1e-14}),
|
||||
op_record("igamma", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
op_record("igammac", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
op_record("erf", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("erfc", 1, float_dtypes, jtu.rand_small),
|
||||
# TODO(b/142976030): the approximation of erfinf used by XLA is only
|
||||
@ -156,17 +164,6 @@ LAX_OPS = [
|
||||
op_record("le", 2, default_dtypes, jtu.rand_small),
|
||||
op_record("lt", 2, default_dtypes, jtu.rand_small),
|
||||
]
|
||||
if lib.version > (0, 1, 37):
|
||||
LAX_OPS += [
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive,
|
||||
{onp.float64: 1e-14}),
|
||||
op_record("igamma", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
op_record("igammac", 2,
|
||||
[f for f in float_dtypes if f not in [dtypes.bfloat16, onp.float16]],
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
]
|
||||
|
||||
CombosWithReplacement = itertools.combinations_with_replacement
|
||||
|
||||
|
@ -70,8 +70,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
return [onp.matmul(a, np.conj(T(a)))]
|
||||
|
||||
if (np.issubdtype(dtype, np.complexfloating) and
|
||||
(jtu.device_under_test() == "tpu" or
|
||||
(jtu.device_under_test() == "cpu" and jax.lib.version < (0, 1, 38)))):
|
||||
jtu.device_under_test() == "tpu"):
|
||||
self.skipTest("Unimplemented case for complex Cholesky decomposition.")
|
||||
|
||||
self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
|
||||
|
Loading…
x
Reference in New Issue
Block a user