Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version. (#2362)

* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version.
* Remove backward compatibility code.
* Use CustomCallWithLayout instead of CustomCall.

* Mention jaxlib version bump in changelog.
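The rename below follows one mechanical pattern: every wrapper switches from the deprecated builder method CustomCall to CustomCallWithLayout, with the keyword arguments left unchanged. A minimal sketch of the pattern, assuming the xla_client builder API used in the hunks below; the kernel name b"example_kernel", the shape, the layouts, and the operand_shapes_with_layout keyword are illustrative assumptions rather than part of this commit:

import numpy as np
from jaxlib import xla_client

def example_custom_call(c, operand):
  # Result is a single 4x4 float32 array in row-major (descending) layout;
  # the shape and layout here are made up for illustration.
  shape = xla_client.Shape.array_shape(np.dtype(np.float32), (4, 4), (1, 0))
  # Before this commit (the compatibility spelling removed here):
  #   return c.CustomCall(b"example_kernel", operands=(operand,),
  #                       shape_with_layout=shape,
  #                       operand_shapes_with_layout=(shape,))
  # After: identical arguments, new method name.
  return c.CustomCallWithLayout(
      b"example_kernel",
      operands=(operand,),
      shape_with_layout=shape,
      operand_shapes_with_layout=(shape,))

The "WithLayout" variant spells out operand and result layouts explicitly, which is why the wrappers below compute layout tuples such as tuple(range(ndims - 1, -1, -1)) before making the call.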
Peter Hawkins 2020-03-05 13:10:20 -05:00 committed by GitHub
parent ddd76b803b
commit 64b1da9d48
6 changed files with 24 additions and 29 deletions


@@ -22,6 +22,7 @@ jax 0.1.60 (unreleased)
``static_argnums`` in :py:func:`jax.jit`.
* Improved error messages for when tracers are mistakenly saved in global state.
* Added :py:func:`jax.nn.one_hot` utility function.
* The minimum jaxlib version is now 0.1.40.
jaxlib 0.1.40 (March 4, 2020)
--------------------------------


@@ -46,7 +46,7 @@ def threefry2x32(c, keys, data):
opaque = cuda_prng_kernels.cuda_threefry2x32_descriptor(_prod(dims))
layout = tuple(range(ndims - 1, -1, -1))
shape = xla_client.Shape.array_shape(dtype, dims, layout)
return c.CustomCall(
return c.CustomCallWithLayout(
b"cuda_threefry2x32",
operands=(keys[0], keys[1], data[0], data[1]),
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),


@@ -81,7 +81,7 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
lwork, opaque = cublas_kernels.build_trsm_batched_descriptor(
np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = c.CustomCall(
out = c.CustomCallWithLayout(
b"cublas_trsm_batched",
operands=(a, b),
shape_with_layout=_Shape.tuple_shape((
@@ -111,7 +111,7 @@ def potrf(c, a, lower):
np.dtype(dtype), lower, batch, n)
kernel = b"cusolver_potrf"
out = c.CustomCall(
out = c.CustomCallWithLayout(
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -151,7 +151,7 @@ def getrf(c, a):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_getrf"
out = c.CustomCall(
out = c.CustomCallWithLayout(
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -188,7 +188,7 @@ def geqrf(c, a):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_geqrf"
out = c.CustomCall(
out = c.CustomCallWithLayout(
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -229,7 +229,7 @@ def orgqr(c, a, tau):
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
kernel = b"cusolver_orgqr"
out = c.CustomCall(
out = c.CustomCallWithLayout(
kernel,
operands=(a, tau),
shape_with_layout=_Shape.tuple_shape((
@@ -276,7 +276,7 @@ def syevd(c, a, lower=False):
np.dtype(dtype), lower, batch, n)
eigvals_type = _real_type(dtype)
out = c.CustomCall(
out = c.CustomCallWithLayout(
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -316,7 +316,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = c.CustomCall(
out = c.CustomCallWithLayout(
b"cusolver_gesvdj",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -345,7 +345,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd + 1, num_bd) + scalar_layout
out = c.CustomCall(
out = c.CustomCallWithLayout(
b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
@@ -372,7 +372,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = c.CustomCall(
out = c.CustomCallWithLayout(
b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((


@@ -16,7 +16,7 @@
# distutils: language = c++
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
# via CustomCall.
# via CustomCallWithLayout.
from __future__ import print_function
@@ -240,7 +240,7 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
raise NotImplementedError("Conjugation without transposition not supported")
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
return c.CustomCall(
return c.CustomCallWithLayout(
fn,
operands=(
c.ConstantS32Scalar(int(left_side)),
@@ -384,7 +384,7 @@ def getrf(c, a):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(
c.ConstantS32Scalar(b),
@@ -576,7 +576,7 @@ def geqrf(c, a):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(
c.ConstantS32Scalar(b),
@@ -777,7 +777,7 @@ def orgqr(c, a, tau):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(
c.ConstantS32Scalar(b),
@@ -929,7 +929,7 @@ def potrf(c, a, lower=False):
b *= d
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(c.ConstantS32Scalar(int(lower)),
c.ConstantS32Scalar(b), c.ConstantS32Scalar(n), a),
@@ -1234,7 +1234,7 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(c.ConstantS32Scalar(int(full_matrices)),
c.ConstantS32Scalar(int(compute_uv)),
@@ -1458,7 +1458,7 @@ def syevd(c, a, lower=False):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(c.ConstantS32Scalar(1 if lower else 0),
c.ConstantS32Scalar(b),
@@ -1750,7 +1750,7 @@ def geev(c, a):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
out = c.CustomCall(
out = c.CustomCallWithLayout(
fn,
operands=(c.ConstantS32Scalar(b), c.ConstantS32Scalar(n), a),
shape_with_layout=Shape.tuple_shape(workspaces + eigvals + (


@@ -46,11 +46,8 @@ torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
empty_array_shapes = []
# TODO(phawkins): size 0 and 1 dimensions are mishandled (with an error) when
# being imported to JAX in jaxlib 0.1.38.
if jax.lib.version > (0, 1, 38):
empty_array_shapes += [(0,), (0, 4), (3, 0),]
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]
empty_array_shapes += [(0,), (0, 4), (3, 0),]
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
all_shapes = nonempty_array_shapes + empty_array_shapes
@@ -100,8 +97,7 @@ class DLPackTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in torch_dtypes))
@unittest.skipIf(not torch or jax.lib.version <= (0, 1, 38),
"Test requires PyTorch and jaxlib >= 0.1.39")
@unittest.skipIf(not torch, "Test requires PyTorch")
# TODO(phawkins): the dlpack destructor issues errors in jaxlib 0.1.38.
def testJaxToTorch(self, shape, dtype):
rng = jtu.rand_default()
@@ -124,8 +120,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype}
for shape in all_shapes
for dtype in dlpack_dtypes))
@unittest.skipIf(not cupy or jax.lib.version <= (0, 1, 38),
"Test requires CuPy and jaxlib >= 0.1.39")
@unittest.skipIf(not cupy, "Test requires CuPy")
def testJaxToCuPy(self, shape, dtype):
rng = jtu.rand_default()
x = rng(shape, dtype)


@@ -1328,7 +1328,6 @@ class LaxTest(jtu.JaxTestCase):
for shape in [(3,), (5, 3)]
for k in [1, 3]
for rng_factory in [jtu.rand_default]))
@unittest.skipIf(jax.lib.version < (0, 1, 40), "Test requires jaxlib 0.1.40")
def testTopK(self, shape, dtype, k, rng_factory):
rng = rng_factory()
perm_rng = onp.random.RandomState(0)
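
All of the guards removed in the two test files above follow the same version-gating pattern, and they become dead code once jaxlib 0.1.40 is the minimum supported version. A minimal sketch of why the guards can simply be dropped; the test class and method names are illustrative, while jax.lib.version is the tuple the removed decorators compared against:

import unittest
import jax

class MinimumJaxlibTest(unittest.TestCase):
  # Before this commit, tests needing newer jaxlib features carried a guard:
  #   @unittest.skipIf(jax.lib.version < (0, 1, 40), "Test requires jaxlib 0.1.40")
  def test_minimum_jaxlib_version(self):
    # With 0.1.40 as the minimum supported jaxlib, the skip condition can
    # never hold, so the decorator (and the similar > (0, 1, 38) checks)
    # can be deleted outright.
    self.assertGreaterEqual(jax.lib.version, (0, 1, 40))

if __name__ == "__main__":
  unittest.main()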