mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Updates to jax that were deferred until after jaxlib 0.1.40 became th… (#2362)
* Updates to jax that were deferred until after jaxlib 0.1.40 became the minimum version. * Remove backward compatibility code. * Use CustomCallWithLayout instead of CustomCall. * Mention jaxlib version bump in changelog.
This commit is contained in:
parent
ddd76b803b
commit
64b1da9d48
@ -22,6 +22,7 @@ jax 0.1.60 (unreleased)
|
||||
``static_argnums`` in :py:func:`jax.jit`.
|
||||
* Improved error messages for when tracers are mistakenly saved in global state.
|
||||
* Added :py:func:`jax.nn.one_hot` utility function.
|
||||
* The minimum jaxlib version is now 0.1.40.
|
||||
|
||||
jaxlib 0.1.40 (March 4, 2020)
|
||||
--------------------------------
|
||||
|
@ -46,7 +46,7 @@ def threefry2x32(c, keys, data):
|
||||
opaque = cuda_prng_kernels.cuda_threefry2x32_descriptor(_prod(dims))
|
||||
layout = tuple(range(ndims - 1, -1, -1))
|
||||
shape = xla_client.Shape.array_shape(dtype, dims, layout)
|
||||
return c.CustomCall(
|
||||
return c.CustomCallWithLayout(
|
||||
b"cuda_threefry2x32",
|
||||
operands=(keys[0], keys[1], data[0], data[1]),
|
||||
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
|
||||
|
@ -81,7 +81,7 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
|
||||
lwork, opaque = cublas_kernels.build_trsm_batched_descriptor(
|
||||
np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
b"cublas_trsm_batched",
|
||||
operands=(a, b),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -111,7 +111,7 @@ def potrf(c, a, lower):
|
||||
np.dtype(dtype), lower, batch, n)
|
||||
kernel = b"cusolver_potrf"
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
kernel,
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -151,7 +151,7 @@ def getrf(c, a):
|
||||
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
|
||||
kernel = b"cusolver_getrf"
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
kernel,
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -188,7 +188,7 @@ def geqrf(c, a):
|
||||
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
|
||||
kernel = b"cusolver_geqrf"
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
kernel,
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -229,7 +229,7 @@ def orgqr(c, a, tau):
|
||||
workspace = _Shape.array_shape(dtype, (lwork,), (0,))
|
||||
kernel = b"cusolver_orgqr"
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
kernel,
|
||||
operands=(a, tau),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -276,7 +276,7 @@ def syevd(c, a, lower=False):
|
||||
np.dtype(dtype), lower, batch, n)
|
||||
eigvals_type = _real_type(dtype)
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
kernel,
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -316,7 +316,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
|
||||
scalar_layout = tuple(range(num_bd - 1, -1, -1))
|
||||
vector_layout = (num_bd,) + scalar_layout
|
||||
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
b"cusolver_gesvdj",
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -345,7 +345,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
|
||||
scalar_layout = tuple(range(num_bd - 1, -1, -1))
|
||||
vector_layout = (num_bd,) + scalar_layout
|
||||
matrix_layout = (num_bd + 1, num_bd) + scalar_layout
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
b"cusolver_gesvd",
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -372,7 +372,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
|
||||
scalar_layout = tuple(range(num_bd - 1, -1, -1))
|
||||
vector_layout = (num_bd,) + scalar_layout
|
||||
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
b"cusolver_gesvd",
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
|
@ -16,7 +16,7 @@
|
||||
# distutils: language = c++
|
||||
|
||||
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
|
||||
# via CustomCall.
|
||||
# via CustomCallWithLayout.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
@ -240,7 +240,7 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
raise NotImplementedError("Conjugation without transposition not supported")
|
||||
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
return c.CustomCall(
|
||||
return c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(
|
||||
c.ConstantS32Scalar(int(left_side)),
|
||||
@ -384,7 +384,7 @@ def getrf(c, a):
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(
|
||||
c.ConstantS32Scalar(b),
|
||||
@ -576,7 +576,7 @@ def geqrf(c, a):
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(
|
||||
c.ConstantS32Scalar(b),
|
||||
@ -777,7 +777,7 @@ def orgqr(c, a, tau):
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(
|
||||
c.ConstantS32Scalar(b),
|
||||
@ -929,7 +929,7 @@ def potrf(c, a, lower=False):
|
||||
b *= d
|
||||
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(c.ConstantS32Scalar(int(lower)),
|
||||
c.ConstantS32Scalar(b), c.ConstantS32Scalar(n), a),
|
||||
@ -1234,7 +1234,7 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
|
||||
scalar_layout = tuple(range(num_bd - 1, -1, -1))
|
||||
vector_layout = (num_bd,) + scalar_layout
|
||||
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(c.ConstantS32Scalar(int(full_matrices)),
|
||||
c.ConstantS32Scalar(int(compute_uv)),
|
||||
@ -1458,7 +1458,7 @@ def syevd(c, a, lower=False):
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(c.ConstantS32Scalar(1 if lower else 0),
|
||||
c.ConstantS32Scalar(b),
|
||||
@ -1750,7 +1750,7 @@ def geev(c, a):
|
||||
else:
|
||||
raise NotImplementedError("Unsupported dtype {}".format(dtype))
|
||||
|
||||
out = c.CustomCall(
|
||||
out = c.CustomCallWithLayout(
|
||||
fn,
|
||||
operands=(c.ConstantS32Scalar(b), c.ConstantS32Scalar(n), a),
|
||||
shape_with_layout=Shape.tuple_shape(workspaces + eigvals + (
|
||||
|
@ -46,11 +46,8 @@ torch_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
|
||||
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
|
||||
empty_array_shapes = []
|
||||
# TODO(phawkins): size 0 and 1 dimensions are mishandled (with an error) when
|
||||
# being imported to JAX in jaxlib 0.1.38.
|
||||
if jax.lib.version > (0, 1, 38):
|
||||
empty_array_shapes += [(0,), (0, 4), (3, 0),]
|
||||
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]
|
||||
empty_array_shapes += [(0,), (0, 4), (3, 0),]
|
||||
nonempty_nonscalar_array_shapes += [(3, 1), (1, 4), (2, 1, 4)]
|
||||
|
||||
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
|
||||
all_shapes = nonempty_array_shapes + empty_array_shapes
|
||||
@ -100,8 +97,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in torch_dtypes))
|
||||
@unittest.skipIf(not torch or jax.lib.version <= (0, 1, 38),
|
||||
"Test requires PyTorch and jaxlib >= 0.1.39")
|
||||
@unittest.skipIf(not torch, "Test requires PyTorch")
|
||||
# TODO(phawkins): the dlpack destructor issues errors in jaxlib 0.1.38.
|
||||
def testJaxToTorch(self, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
@ -124,8 +120,7 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in all_shapes
|
||||
for dtype in dlpack_dtypes))
|
||||
@unittest.skipIf(not cupy or jax.lib.version <= (0, 1, 38),
|
||||
"Test requires CuPy and jaxlib >= 0.1.39")
|
||||
@unittest.skipIf(not cupy, "Test requires CuPy")
|
||||
def testJaxToCuPy(self, shape, dtype):
|
||||
rng = jtu.rand_default()
|
||||
x = rng(shape, dtype)
|
||||
|
@ -1328,7 +1328,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for shape in [(3,), (5, 3)]
|
||||
for k in [1, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 40), "Test requires jaxlib 0.1.40")
|
||||
def testTopK(self, shape, dtype, k, rng_factory):
|
||||
rng = rng_factory()
|
||||
perm_rng = onp.random.RandomState(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user