# rocm_jax/tests/lax_autodiff_test.py
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from functools import partial
import itertools
import math
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import dtypes
from jax import lax
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
from jax.test_util import check_grads
jax.config.parse_flags_with_absl()
compatible_shapes = [[(3,)],
[(), (3, 4), (3, 1), (1, 4)],
[(2, 3, 4), (2, 1, 4)]]
GradTestSpec = collections.namedtuple(
"GradTestSpec",
["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"])
def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
return GradTestSpec(
op, nargs, order, rng_factory, dtypes, name or op.__name__, tol)
float_dtypes = jtu.dtypes.all_floating
inexact_dtypes = jtu.dtypes.all_inexact
grad_float_dtypes = jtu.dtypes.floating
grad_complex_dtypes = jtu.dtypes.complex
grad_inexact_dtypes = jtu.dtypes.inexact
LAX_GRAD_OPS = [
grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.floor, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
dtypes=grad_float_dtypes),
grad_test_spec(lax.ceil, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
dtypes=grad_float_dtypes),
grad_test_spec(lax.round, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
dtypes=grad_float_dtypes),
grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.log1p, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.sinh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes + [np.complex64], tol=1e-5),
grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=1e-5),
grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=2e-4),
grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.tan, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3),
dtypes=grad_inexact_dtypes, tol=1e-3),
grad_test_spec(lax.asin, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3),
dtypes=grad_float_dtypes, tol=1e-3),
grad_test_spec(lax.acos, nargs=1, order=2,
rng_factory=partial(jtu.rand_uniform, low=-1.3, high=1.3),
dtypes=grad_float_dtypes, tol=2e-2),
# TODO(proteneer): atan2 input is already a representation of a
# complex number. Need to think harder about what this even means
# if each input itself is a complex number.
grad_test_spec(lax.atan2, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
grad_test_spec(lax.erf, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_float_dtypes),
grad_test_spec(lax.erfc, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_float_dtypes),
grad_test_spec(lax.erf_inv, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_float_dtypes),
# grad_test_spec(lax.lgamma, nargs=1, order=2, rng_factory=jtu.rand_small,
# dtypes=grad_float_dtypes), # TODO(mattjj): enable
grad_test_spec(lax.bessel_i0e, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
grad_test_spec(lax.bessel_i1e, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
grad_test_spec(lax.real, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.imag, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.complex, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
grad_test_spec(lax.conj, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.abs, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.pow, nargs=2, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_inexact_dtypes, tol={np.float32: 3e-1}),
grad_test_spec(lax.sqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_float_dtypes),
grad_test_spec(lax.sqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes),
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_positive,
dtypes=grad_float_dtypes),
grad_test_spec(lax.rsqrt, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_complex_dtypes, tol={np.float64: 2e-3}),
grad_test_spec(lax.cbrt, nargs=1, order=2, rng_factory=jtu.rand_not_small,
dtypes=grad_float_dtypes, tol={np.float64: 5e-3}),
grad_test_spec(lax.logistic, nargs=1, order=2,
rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.add, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.sub, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.mul, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.div, nargs=2, order=1, rng_factory=jtu.rand_not_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.max, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
grad_test_spec(lax.min, nargs=2, order=2, rng_factory=jtu.rand_default,
dtypes=grad_float_dtypes),
# TODO(mattjj): make some-equal checks more robust, enable second-order
# grad_test_spec(lax.max, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
# dtypes=grad_float_dtypes, name="MaxSomeEqual"),
# grad_test_spec(lax.min, nargs=2, order=1, rng_factory=jtu.rand_some_equal,
# dtypes=grad_float_dtypes, name="MinSomeEqual"),
]
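# Each GradTestSpec above is expanded by LaxAutodiffTest.testOpGrad below into
# one test case per compatible shape combination and dtype, and every expanded
# case reduces to a single check_grads call. A minimal hand-written sketch of
# one such case (the dtype, shapes, and RNG seed here are illustrative only):
#
#   rng = jtu.rand_default(np.random.RandomState(0))
#   args = (rng((3, 4), np.float32), rng((3, 4), np.float32))
#   check_grads(lax.add, args, order=2, modes=["fwd", "rev"])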
GradSpecialValuesTestSpec = collections.namedtuple(
"GradSpecialValuesTestSpec", ["op", "values", "tol"])
def grad_special_values_test_spec(op, values, tol=None):
return GradSpecialValuesTestSpec(op, values, tol)
LAX_GRAD_SPECIAL_VALUE_TESTS = [
grad_special_values_test_spec(lax.sinh, [0.]),
grad_special_values_test_spec(lax.cosh, [0.]),
grad_special_values_test_spec(lax.tanh, [0., 1000.], tol=5e-3),
grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
grad_special_values_test_spec(lax.tan, [0.]),
grad_special_values_test_spec(lax.asin, [0.]),
grad_special_values_test_spec(lax.acos, [0.]),
grad_special_values_test_spec(lax.atan, [0., 1000.]),
grad_special_values_test_spec(lax.erf, [0., 10.]),
grad_special_values_test_spec(lax.erfc, [0., 10.]),
]
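# The special-value cases above probe points where gradients are numerically
# delicate: e.g. d/dx tanh(x) = 1 - tanh(x)**2 underflows to ~0 at x = 1000,
# and sin/cos are checked at multiples of pi where one of the pair vanishes.
# check_grads compares forward- and reverse-mode derivatives against numerical
# differencing at each of these points.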
def check_grads_bilinear(f, args, order,
modes=("fwd", "rev"), atol=None, rtol=None):
# Can use large eps to make up for numerical inaccuracies since the op is
# bilinear (relying on the fact that we only check one arg at a time)
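  # With rhs held fixed, lhs -> f(lhs, rhs) is linear, so a finite difference
  # in lhs alone is exact up to floating-point round-off (and symmetrically for
  # rhs); a large eps therefore only improves the signal-to-noise ratio.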
lhs, rhs = args
check_grads(lambda lhs: f(lhs, rhs), (lhs,), order,
modes=modes, atol=atol, rtol=rtol, eps=1.)
check_grads(lambda rhs: f(lhs, rhs), (rhs,), order,
modes=modes, atol=atol, rtol=rtol, eps=1.)
class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, rng_factory=rec.rng_factory, order=rec.order, tol=rec.tol)],
shapes=[
shapes for shape_group in compatible_shapes
for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
],
dtype=rec.dtypes,
)
for rec in LAX_GRAD_OPS
))
def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
rng = rng_factory(self.rng())
if jtu.test_device_matches(["cpu"]):
if op is lax.cosh and dtype == np.complex64:
tol = 3e-1 # 2nd-order gradients are noisy on CPU
if jtu.test_device_matches(["tpu"]):
if op is lax.pow:
raise SkipTest("pow grad imprecise on tpu")
if op is lax.cos:
order = 1 # 2nd-order gradient is imprecise on TPU.
if op is lax.log:
order = 1 # 2nd-order gradient is imprecise on TPU.
tol = jtu.join_tolerance(1.5e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.op, tol=rec.tol)],
special_value=rec.values,
)
for rec in LAX_GRAD_SPECIAL_VALUE_TESTS
))
def testOpGradSpecialValue(self, op, special_value, tol):
if op in (lax.sinh, lax.cosh) and jtu.test_device_matches(["tpu"]):
tol = {np.float32: 1e-2}
check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)
@jtu.sample_product(
from_dtype=inexact_dtypes,
to_dtype=inexact_dtypes,
)
def testConvertElementTypeGrad(self, from_dtype, to_dtype):
rng = jtu.rand_default(self.rng())
tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
args = (rng((2, 3), from_dtype),)
convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)(
convert_element_type)
check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=grad_float_dtypes,
)
def testClampGrad(self, shape, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
low = operand - dtype(10)
high = operand + dtype(10)
# Avoids points near the boundary where the gradient may be inaccurate.
check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)
@jtu.sample_product(
[dict(base_shape=base_shape, dim=dim)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for dim in range(len(base_shape))
],
num_arrs=[3],
dtype=float_dtypes,
)
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
rng = jtu.rand_default(self.rng())
shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
operands = tuple(rng(shape, dtype) for shape in shapes)
concatenate = lambda *args: lax.concatenate(args, dim)
check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(len(base_shape))
],
num_pieces=range(3),
dtype=float_dtypes,
)
def testSplitGrad(self, axis, base_shape, dtype, num_pieces):
sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64)
shape = list(base_shape)
shape[axis] = np.sum(sizes)
rng = jtu.rand_default(self.rng())
operands = (rng(shape, dtype),)
split = lambda x: lax.split(x, sizes, axis)
check_grads(split, operands, 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides)
for lhs_shape, rhs_shape, all_strides in itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
for b, i, j in itertools.product([2, 3], repeat=3)],
[((4, 2, 1), (3, 2, 1), [(1,)])])
for strides in all_strides
],
dtype=float_dtypes,
padding=["VALID", "SAME"],
)
def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
rng = jtu.rand_small(self.rng())
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
conv = partial(lax.conv, window_strides=strides, padding=padding,
precision=lax.Precision.HIGHEST)
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil)
for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
itertools.chain(
[((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
[((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
[(1, 1), (2, 1)], [(1, 1)])
for b, i, j in itertools.product([2, 3], repeat=3)],
[((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
[(1,), (2,)], [(1,), (2,)])])
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for padding in all_pads
],
dtype=float_dtypes,
)
def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil):
rng = jtu.rand_small(self.rng())
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
conv = partial(lax.conv_with_general_padding, window_strides=strides,
padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
precision=lax.Precision.HIGHEST)
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=1e-2, rtol=1e-2)
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides,
padding=padding, lhs_dil=lhs_dil, rhs_dil=rhs_dil,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
([(b * batch_group_count, i * feature_group_count, 6, 7),
(b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape
(j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape
[(1, 1), (1, 2), (2, 1)], # strides
[(1, 1), (2, 1)], # lhs_dils
[(1, 1), (2, 2)]) # rhs_dils
for b, i, j in itertools.product([1, 2], repeat=3)]
for lhs_shape in lhs_shapes
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
],
[dict(dimension_numbers=dim_nums, perms=perms)
for dim_nums, perms in [
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
],
dtype=grad_inexact_dtypes,
)
def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil, dimension_numbers,
perms, feature_group_count, batch_group_count):
if dtype == np.float16:
raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve
rng = jtu.rand_default(self.rng())
tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 1e-3}
# permute shapes to match dim_spec, scale by feature_group_count
lhs_perm, rhs_perm = perms
lhs_shape = list(np.take(lhs_shape, lhs_perm))
rhs_shape = list(np.take(rhs_shape, rhs_perm))
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
conv = partial(lax.conv_general_dilated, window_strides=strides,
padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=lax.Precision.HIGHEST)
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
@jtu.sample_product(
lhs_shape=[(2,), (3, 2)],
rhs_shape=[(2,), (2, 4)],
dtype=float_dtypes,
)
def testDotGrad(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
tol = {np.float16: 1e-1, np.float32: 1e-4}
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
# check that precision config is preserved
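    # The VJP of a dot taken with precision=HIGHEST should itself be a dot with
    # precision=HIGHEST; stringifying the pullback's jaxpr is a simple way to
    # assert that the precision setting survived transposition.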
result, pullback = jax.vjp(dot, lhs, rhs)
s = str(jax.make_jaxpr(pullback)(result))
assert "Precision.HIGHEST" in s
@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 2), (2, 4), (([1], [0]), ([], []))),
((3, 5), (2, 5), (([1], [1]), ([], []))),
((5, 3), (5, 2), (([0], [0]), ([], []))),
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
]
],
dtype=float_dtypes,
)
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
dimension_numbers):
rng = jtu.rand_small(self.rng())
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
precision=lax.Precision.HIGHEST)
atol = {np.float16: 5E-2} if jtu.test_device_matches(['tpu']) else None
check_grads_bilinear(dot_general, (lhs, rhs), order=2,
modes=["fwd", "rev"], atol=atol)
# check that precision config is preserved
result, pullback = jax.vjp(dot_general, lhs, rhs)
s = str(jax.make_jaxpr(pullback)(result))
assert "Precision.HIGHEST" in s
def testDotPreferredElementType(self):
# https://github.com/jax-ml/jax/issues/10818
x = jax.numpy.ones((), jax.numpy.float16)
def f(x):
return jax.lax.dot_general(x, x, (((), ()), ((), ())),
preferred_element_type=jax.numpy.float32)
jax.jacrev(f)(x) # don't crash!
@jtu.sample_product(
shape=[(), (2, 3)],
dtype=float_dtypes,
broadcast_sizes=[(), (2,), (1, 2)],
)
def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(inshape=inshape, outshape=outshape, dimensions=broadcast_dimensions)
for inshape, outshape, broadcast_dimensions in [
([2], [2, 2], [0]),
([2], [2, 2], [1]),
([2], [2, 3], [0]),
([], [2, 3], []),
]
],
dtype=float_dtypes,
)
def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
rng = jtu.rand_default(self.rng())
operand = rng(inshape, dtype)
broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape, permutation=permutation)
for arg_shape, out_shape, permutation in [
[(3, 4), (12,), None],
[(2, 1, 4), (8,), None],
[(2, 2, 4), (2, 8), None],
[(3, 4), (12,), (0, 1)],
[(3, 4), (12,), (1, 0)],
[(2, 1, 4), (8,), (0, 2, 1)],
[(2, 1, 4), (8,), (2, 0, 1)],
[(2, 2, 4), (2, 8), (0, 2, 1)],
[(2, 2, 4), (2, 8), (2, 0, 1)],
]
],
dtype=float_dtypes,
)
def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
rng = jtu.rand_default(self.rng())
operand = rng(arg_shape, dtype)
reshape = lambda x: lax.reshape(x, out_shape, permutation)
check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(shape=shape, pads=pads)
for shape, paddings in [
[(), [()]],
((2, 3), [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]),
]
for pads in paddings
],
dtype=float_dtypes,
)
def testPadGrad(self, shape, dtype, pads):
rng = jtu.rand_small(self.rng())
operand = rng(shape, dtype)
pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)
operand = rng(shape, dtype)
padding_value = np.array(0., dtype)
pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)
def testReverseGrad(self):
rev = lambda operand: lax.rev(operand, dimensions)
dimensions = [0]
check_grads(rev, (np.array([3., 2., 1.]),), 2)
dimensions = [0, 1]
check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
rtol={np.float32: 3e-3})
def testPowSecondDerivative(self):
# https://github.com/jax-ml/jax/issues/12033
x, y = 4.0, 0.0
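    # For f(x, y) = x ** y, evaluated at y = 0:
    #   d2f/dx2  = y * (y - 1) * x ** (y - 2)      = 0
    #   d2f/dxdy = x ** (y - 1) * (1 + y * log(x)) = 1 / x
    #   d2f/dy2  = x ** y * log(x) ** 2            = log(x) ** 2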
expected = ((0.0, 1/x), (1/x, np.log(x) ** 2))
with self.subTest("jacfwd"):
result_fwd = jax.jacfwd(jax.jacfwd(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_fwd, expected)
with self.subTest("jacrev"):
result_rev = jax.jacrev(jax.jacrev(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_rev, expected)
with self.subTest("zero to the zero"):
result = jax.grad(lax.pow)(0.0, 0.0)
# TODO(jakevdp) special-case zero in a way that doesn't break other cases
# See https://github.com/jax-ml/jax/pull/12041#issuecomment-1222766191
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)
def testPowIntPowerAtZero(self):
# https://github.com/jax-ml/jax/issues/14397
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
self.assertAllClose(ans, 0., check_dtypes=False)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testPowIntPowerAtZero2(self):
# https://github.com/jax-ml/jax/issues/17995
a = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=int))
b = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=float))
c = lambda z: 1 + z
d = lambda z: z ** 0 + z
e = lambda z: z ** 0. + z
self.assertAllClose(jax.grad(a)(3.14), 1., check_dtypes=False)
self.assertAllClose(jax.grad(b)(3.14), 1., check_dtypes=False)
self.assertAllClose(jax.grad(c)(3.14), 1., check_dtypes=False)
self.assertAllClose(jax.grad(d)(3.14), 1., check_dtypes=False)
self.assertAllClose(jax.grad(e)(3.14), 1., check_dtypes=False)
@jtu.sample_product(
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
for arg_shape in [(), (3,), (2, 3)]
for pred_shape in ([(), arg_shape] if arg_shape else [()])
],
dtype=float_dtypes,
)
def testSelectGrad(self, pred_shape, arg_shape, dtype):
rng = jtu.rand_default(self.rng())
pred = rng(pred_shape, np.bool_)
on_true = rng(arg_shape, dtype)
on_false = rng(arg_shape, dtype)
select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(shape=shape, starts=start_indices, limits=limit_indices,
strides=strides)
for shape, start_indices, limit_indices, strides in [
[(3,), (1,), (2,), None],
[(7,), (4,), (7,), None],
[(5,), (1,), (5,), (2,)],
[(8,), (1,), (6,), (2,)],
[(5, 3), (1, 1), (3, 2), None],
[(5, 3), (1, 1), (3, 1), None],
[(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
[(5, 3), (1, 1), (2, 1), (1, 1)],
[(5, 3), (1, 1), (5, 3), (2, 1)],
[(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)]
]
],
dtype=float_dtypes,
)
def testSliceGrad(self, shape, dtype, starts, limits, strides):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
slice = lambda x: lax.slice(x, starts, limits, strides)
check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, size_indices=size_indices)
for shape, start_indices, size_indices in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
],
dtype=float_dtypes,
)
def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(shape=shape, start_indices=start_indices, update_shape=update_shape)
for shape, start_indices, update_shape in [
[(3,), (1,), (1,)],
[(5, 3), (1, 1), (3, 1)],
[(7, 5, 3), (4, 1, 0), (2, 0, 1)],
]
],
dtype=float_dtypes,
)
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
update_shape):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
update = rng(update_shape, dtype)
start_indices = np.array(start_indices)
dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)
dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)
dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
def testDynamicSliceValueAndGrad(self):
# Regression test for https://github.com/jax-ml/jax/issues/10984
# Issue arose due to an out-of-range negative index.
rng = jtu.rand_default(self.rng())
shape = (5, 5)
axis = 0
index = -(shape[axis] + 3)
def f(x):
return lax.dynamic_index_in_dim(x, index, axis).sum()
x = rng(shape, np.float32)
result1 = f(x)
result2, _ = jax.value_and_grad(f, 0)(x)
self.assertAllClose(result1, result2)
def testDynamicUpdateSliceValueAndGrad(self):
# Regression test for https://github.com/jax-ml/jax/issues/10984
# Issue arose due to an out-of-range negative index.
rng = jtu.rand_default(self.rng())
shape = (5, 5)
axis = 0
index = -(shape[axis] + 3)
def f(x, y):
return lax.dynamic_update_index_in_dim(x, y, index, axis).sum()
x = rng(shape, np.float32)
y = rng([1 for s in shape], np.float32)
result1 = f(x, y)
result2, _ = jax.value_and_grad(f, 0)(x, y)
self.assertAllClose(result1, result2)
@jtu.sample_product(
[dict(shape=shape, perm=perm)
for shape, perm in [
[(3, 4), (1, 0)],
[(3, 4), (0, 1)],
[(3, 4, 5), (2, 1, 0)],
[(3, 4, 5), (1, 0, 2)],
]
],
dtype=float_dtypes,
)
def testTransposeGrad(self, shape, dtype, perm):
rng = jtu.rand_default(self.rng())
operand = rng(shape, dtype)
transpose = lambda x: lax.transpose(x, perm)
check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
(-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
(1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
]
for dtype in dtypes
],
[dict(shape=shape, dims=dims)
for shape, dims in [
[(), ()],
[(3, 4, 5), ()],
[(3, 4, 5), (0,)],
[(3, 4, 5), (1, 2)],
[(3, 4, 5), (0, 2)],
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]
],
)
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
rng = rng_factory(self.rng())
if jtu.test_device_matches(["tpu"]) and op is lax.mul:
raise SkipTest("unimplemented case")
tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
np.float64: 1e-3, np.complex64: 1e-1}
operand = rng(shape, dtype)
init_val = np.asarray(init_val, dtype=dtype)
reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
1e-1 if dtype == dtypes.bfloat16 else
1e-2 if dtypes.finfo(dtype).bits == 32 else None)
if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
@jtu.sample_product(
[dict(shape=shape, dims=dims)
for shape, dims in [
[(3, 4, 5), ()],
[(3, 4, 5), (0,)],
[(3, 4, 5), (1, 2)],
[(3, 4, 5), (0, 2)],
[(3, 4, 5), (0, 1, 2)],
[(3, 1), (1,)],
[(3, 0, 5), (1,)],
]
],
dtype=grad_float_dtypes,
)
def testReducePairGrad(self, shape, dtype, dims):
rng = jtu.rand_default(self.rng(), scale=1)
tol = {np.float32: 1e-2, np.float64: 1e-4}
operands = (rng(shape, dtype), rng(shape, dtype))
init_vals = (np.array(0, dtype), np.array(1, dtype))
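    # 0 and 1 are the identities of the (add, mul) pair-reduction defined by
    # `op` below, so `reduce` computes a simultaneous (sum, product) reduction
    # over `dims`.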
def op(xs, ys):
return (xs[0] + ys[0], xs[1] * ys[1])
reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
@jtu.sample_product(
[dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory,
shape=shape, dims=dims, strides=strides, padding=padding,
base_dilation=base_dilation, window_dilation=window_dilation)
for init_val, op, dtypes, rng_factory in [
(0, lax.add, grad_float_dtypes, jtu.rand_small),
(-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
(np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
]
for dtype in dtypes
for shape, dims, strides, padding, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(2, 1), (1, 2)],
[(1, 1), (2, 1), (1, 2)],
["VALID", "SAME", [(0, 3), (1, 2)]],
[(1, 1)] + ([(2, 3)]),
[(1, 1)] + ([(1, 2)] if op is lax.add else [])),
itertools.product(
[(3, 2, 4, 6)],
[(1, 1, 2, 1), (2, 1, 2, 1)],
[(1, 2, 2, 1), (1, 1, 1, 1)],
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
[(1, 1, 1, 1)] + ([(2, 1, 3, 2)]),
[(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
],
)
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")
def testReduceWindowGrad(
self, op, init_val, dtype, shape, dims, strides,
padding, base_dilation, window_dilation, rng_factory):
rng = rng_factory(self.rng())
init_val = np.asarray(init_val, dtype=dtype)
gradient_order = 3
# We need this conditional and the corresponding loop logic to be in the
# test method, rather than at the parameterized test level, because it
# depends on FLAGS for the device under test.
# TODO(b/31565929): enable when fixed.
if jtu.test_device_matches(["tpu"]) and op is not lax.add:
if (len(shape) != 4 or dims != (1, 1, 2, 1)
or not isinstance(padding, str)):
raise SkipTest("Only R4 SelectAndScatter implemented on TPU")
def fun(operand):
return lax.reduce_window(operand, init_val, op, dims, strides, padding,
base_dilation, window_dilation)
operand = rng(shape, dtype)
if op is lax.add:
eps = 1.
tol = None
else:
# this test can fail if there are duplicates in operand
self.assertEqual(np.unique(operand).size, operand.size,
msg="test requires operand elements to be unique.")
eps = 1e-2
tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
eps)
@jtu.sample_product(
[dict(op=op, dtype=dtype)
for op, types in [
(lax.cumsum, [np.float32, np.float64]),
(lax.cumprod, [np.float32, np.float64]),
]
for dtype in types
],
[dict(shape=shape, axis=axis)
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
],
reverse=[False, True],
)
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
order=2)
# TODO(b/205052657): enable more tests when supported
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(5,), (5, 7), (4, 9, 3)]
for axis in [len(shape) - 1]
],
dtype=[np.float32],
is_stable=[False, True],
)
def testSortGrad(self, shape, dtype, axis, is_stable):
rng = jtu.rand_unique_int(self.rng())
operand = rng(shape, dtype)
sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)
# TODO(b/205052657): enable more tests when supported
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(3,), (5, 3), (4, 9, 3)]
for axis in [len(shape) - 1]
],
key_dtype=[np.float32],
val_dtype=[np.float32],
is_stable=[False, True],
)
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
rng = jtu.rand_default(self.rng())
# This test relies on the property that wherever keys are tied, values are
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
keys, values = args_maker()
fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)
@jtu.sample_product(
dtype=[np.float32,],
shape=[(4,), (5, 5), (2, 1, 4)],
k=[1, 3],
)
def testTopKGrad(self, shape, dtype, k):
flat_values = np.arange(math.prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
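    # Using a permutation of arange keeps all values distinct, so the top-k
    # selection (and therefore its gradient) is unambiguous under the small
    # perturbations used by the numerical gradient check.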
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
@jtu.sample_product(
[dict(shape=shape, idxs=idxs, axes=axes)
for shape, idxs, axes in [
[(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
[(3, 4, 5), (np.array([-1, -2]),), (0,)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
[(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
]
],
dtype=float_dtypes,
)
@jax.numpy_rank_promotion('allow') # Test explicitly exercises implicit rank promotion.
def testIndexTakeGrad(self, shape, dtype, idxs, axes):
rng = jtu.rand_default(self.rng())
src = rng(shape, dtype)
index_take = lambda src: lax.index_take(src, idxs, axes)
check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)
@jtu.sample_product(
[dict(shape=shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, slice_sizes=slice_sizes, max_idx=max_idx)
for shape, idxs, dnums, slice_sizes, max_idx in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,), 5),
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,), 9),
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3), 3),
]
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testGatherGrad(self, shape, dtype, idxs_shape, idxs_dtype, dnums,
slice_sizes, mode, max_idx, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
else:
# Only test out-of-bounds indices if using a mode that guarantees correct
# gradients for out-of-bounds indices.
rng_idx = jtu.rand_int(self.rng(), low=-max_idx, high=2 * max_idx)
idxs = rng_idx(idxs_shape, idxs_dtype)
# Use an arbitrary finite fill_value, since NaNs won't work in a numerical
# gradient test.
gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
slice_sizes=slice_sizes, mode=mode,
fill_value=-1)
x = rng(shape, dtype)
check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
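  # Note on the gather cases above: e.g. the first spec (slice_sizes=(1,),
  # offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) gathers
  # single elements, roughly x[idxs[:, 0]], and its reverse-mode gradient
  # scatter-adds the cotangents back to those positions.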
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs_shape=idxs.shape, idxs_dtype=idxs.dtype,
dnums=dnums, update_shape=update_shape, max_idx=max_idx)
for arg_shape, idxs, update_shape, dnums, max_idx in [
((5,), np.array([[0], [2]]), (2,),
lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 4),
((10,), np.array([[0], [0], [0]]), (3, 2),
lax.ScatterDimensionNumbers(update_window_dims=(1,),
inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,)), 9),
((10, 5,), np.array([[0], [2], [1]]), (3, 3),
lax.ScatterDimensionNumbers(update_window_dims=(1,),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 3),
]
],
dtype=grad_float_dtypes,
mode=["clip", "fill", "promise_in_bounds"],
iteration=range(5),
)
def testScatterAddGrad(self, arg_shape, dtype, idxs_shape, idxs_dtype,
update_shape, dnums, max_idx, mode, iteration):
rng = jtu.rand_default(self.rng())
if mode == "promise_in_bounds":
rng_idx = jtu.rand_int(self.rng(), high=max_idx)
else:
# Only test out-of-bounds indices if using a mode that guarantees correct
# gradients for out-of-bounds indices.
rng_idx = jtu.rand_int(self.rng(), low=-max_idx, high=2 * max_idx)
idxs = rng_idx(idxs_shape, idxs_dtype)
scatter_add = lambda x, y: lax.scatter_add(
x, idxs, y, dimension_numbers=dnums, mode=mode)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape, max_idx=max_idx, multiplier=multiplier)
for arg_shape, idxs, update_shape, dnums, max_idx, multiplier in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 4, 1),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,)), 4, 2),
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,)), 9, 1),
]
],
dtype=grad_float_dtypes,
)
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
max_idx, multiplier):
# Scatters with conflicting indices are not deterministic on GPU, so we
# use indices that do not collide.
rng_idx = jtu.rand_unique_int(self.rng(), high=max_idx)
rng = jtu.rand_default(self.rng())
# The multiplier ensures we don't pick overlapping windows if the update
# window is not of size 1.
idxs = rng_idx(idxs.shape, idxs.dtype) * multiplier
scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
def testScatterGradSymbolicZeroUpdate(self):
# https://github.com/jax-ml/jax/issues/1901
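    # The diagonal update values come from np.arange and do not depend on x, so
    # their tangent is a symbolic zero; this is meant to exercise the scatter
    # JVP/transpose handling of symbolic-zero updates.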
def f(x):
n = x.shape[0]
y = np.arange(n, dtype=x.dtype)
return jax.device_put(x).at[np.diag_indices(n)].set(y)
rng = jtu.rand_default(self.rng())
check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
1.)
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
],
dtype=grad_float_dtypes,
)
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
rng = jtu.rand_default(self.rng())
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
idxs = rng_idx(idxs.shape, idxs.dtype)
scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
@jtu.sample_product(
[dict(arg_shape=arg_shape, idxs=idxs, dnums=dnums,
update_shape=update_shape)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]
],
dtype=grad_float_dtypes,
)
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
rng = jtu.rand_default(self.rng())
rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
idxs = rng_idx(idxs.shape, idxs.dtype)
scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
x = rng(arg_shape, dtype)
y = rng(update_shape, dtype)
check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)
def testStopGradient(self):
def f(x):
return lax.sin(x) * lax.cos(lax.stop_gradient(x))
def f2(x, y):
return lax.sin(x) * lax.cos(y)
x = 3.14
ans = jax.grad(f)(x)
expected = jax.grad(f2)(x, x)
self.assertAllClose(ans, expected)
ans = jax.grad(jax.grad(f))(x)
expected = jax.grad(jax.grad(f2))(x, x)
self.assertAllClose(ans, expected)
ans = jax.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
expected = np.array(0.0)
self.assertAllClose(ans, expected, check_dtypes=False)
with jax.enable_checks(False):
with self.assertRaises(TypeError):
lax.stop_gradient(lambda x: x)
# TODO(mattjj): make this a more systematic test
def testRemainder(self):
def gen_x(rng, size):
return rng.uniform(-9, 9, size=size)
def gen_y(rng, size):
# avoid values near zero because gradients diverge
return rng.uniform(0.1, 5, size=size) * rng.choice([-1, 1], size=size)
rng = self.rng()
x = gen_x(rng, (5, 8))
y = gen_y(rng, (1, 8))
assert not set(np.unique(x)) & set(np.unique(y))
check_grads(lax.rem, (x, y), 2, ["fwd", "rev"])
rng = self.rng()
x = gen_x(rng, (1, 8))
y = gen_y(rng, (5, 8))
assert not set(np.unique(x)) & set(np.unique(y))
check_grads(lax.rem, (x, y), 2, ["fwd", "rev"])
def testHigherOrderGradientOfReciprocal(self):
# Regression test for https://github.com/jax-ml/jax/issues/3136
def inv(x):
# N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
return 1 / x
grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
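    # The sixth derivative of 1/x is 720 / x**7, so at x = 4 the expected value
    # is 720 / 4**7 = 0.0439453125.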
self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))
def test_linear_transpose_real(self):
f = lambda x: x.real
transpose = jax.linear_transpose(f, 1.j)
actual, = transpose(1.)
expected = 1.
self.assertEqual(actual, expected)
def test_linear_transpose_imag(self):
f = lambda x: x.imag
transpose = jax.linear_transpose(f, 1.j)
actual, = transpose(1.)
expected = -1.j
self.assertEqual(actual, expected)
def test_scatter_apply_jvp(self):
def f(x):
return x.at[1].apply(jax.numpy.sin)
x = jax.numpy.array([1.0, 2.0])
with self.assertRaises(NotImplementedError):
jax.jacfwd(f)(x)
def test_scatter_apply_vjp(self):
def f(x):
return x.at[1].apply(jax.numpy.sin)
x = jax.numpy.array([1.0, 2.0])
with self.assertRaises(NotImplementedError):
jax.jacrev(f)(x)
def testPowShapeMismatch(self):
# Regression test for https://github.com/jax-ml/jax/issues/17294
x = lax.iota('float32', 4)
y = 2
actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error
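    # pow acts elementwise, so the Jacobian is diagonal with entries
    # d(x**y)/dx = y * x**(y - 1).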
expected = jax.numpy.diag(y * x ** (y - 1))
self.assertArraysEqual(actual, expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())