# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from functools import partial
import itertools
import unittest

from absl.testing import absltest
from absl.testing import parameterized

import scipy.stats

from jax._src import ad_checkpoint
from jax._src import config
from jax._src import core
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
    quantize,
    shape_normalization,
    BlockScaleConfig,
)
from jax.test_util import check_grads
from jax import nn
from jax import random
import jax
import jax.numpy as jnp

config.parse_flags_with_absl()

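# Skip helper: True only if the GPU's compute capability and the installed
# cuDNN release meet the given minimums (also False when cuda_versions is
# unavailable, e.g. in CPU-only builds).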
def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version):
  return (
      jtu.is_cuda_compute_capability_at_least(min_cc) and
      cuda_versions is not None and
      cuda_versions.cudnn_get_version() >= min_cudnn_version
  )

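# Lowers `fn` to StableHLO and reports whether the cuDNN fused-attention
# custom call ("__cudnn$fmha") appears in the lowered module.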
def _check_cudnn_backend(fn, *args, **kwargs):
  lowered = jax.jit(fn).lower(*args, **kwargs)
  hlo = lowered.as_text('stablehlo', debug_info=True)
  return '__cudnn$fmha' in hlo

_cudnn_dbias_error = 'cuDNN only supports bias gradient'

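# Reference quantization helpers: quantize_to_qtype scales, clips, and casts to
# the low-precision dtype; quantize_dequantize round-trips a tensor so that the
# test inputs are exactly representable in that dtype.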
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, )
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(
      jnp.asarray(scale, dtype=compute_dtype), x.shape
  )
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)

def quantize_dequantize(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(
      jnp.asarray(scale, dtype=x.dtype), qx.shape
  )
  return out

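# Builds random operands `a` and `b`, their block-quantized counterparts, and
# the per-block scales, using the first two BlockScaleConfigs in `configs`.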
def _generate_quantized_tensors(
    batch, lhs_non_contract, contract, rhs_non_contract,
    configs, dtype=jnp.float32,
):
  cast_to_representable = partial(
      quantize_dequantize,
      scale=jnp.ones((1,)),
      compute_dtype=dtype,
  )

  k1, k2 = jax.random.split(jax.random.key(123), 2)

  a = cast_to_representable(
      jax.random.uniform(
          k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[0].data_type,
  )
  b = cast_to_representable(
      jax.random.uniform(
          k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
      ),
      configs[1].data_type,
  )

  dn = ((2,), (0,))
  a_3d = shape_normalization(a, dn)
  b_3d = shape_normalization(b, dn)
  a_q, a_scales = quantize(a, configs[0])
  b_q, b_scales = quantize(b, configs[1])

  return a, b, a_q, b_q, a_scales, b_scales

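# Returns three identical MXFP8 block-scale configs, or skips the calling test
# when float8_e8m0fnu is not available in this build.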
def create_mxfp8_configs_if_available():
  if _dtypes.float8_e8m0fnu is None:
    raise unittest.SkipTest("float8_e8m0fnu is not available.")

  def _create_mxfp8_config():
    return BlockScaleConfig(
        mode='mxfp8',
        block_size=32,
        data_type=jnp.float8_e4m3fn,
        scale_type=jnp.float8_e8m0fnu,
        global_scale=None,
        infer_only=False
    )

  return [_create_mxfp8_config() for _ in range(3)]


@jtu.with_config(jax_legacy_prng_key="allow",
                 jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
  @parameterized.product(
      contract=[160, 96],
      lhs_non_contract=[240, 100],
      dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
      impl=['cudnn',],
  )
  def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
    # Check if float8_e8m0fnu is available
    configs = create_mxfp8_configs_if_available()
    batch, rhs_non_contract = 4, 256
    a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
        batch, lhs_non_contract, contract, rhs_non_contract,
        configs, dtype=dtype,
    )
    out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
                           preferred_element_type=dtype)
    out_ref = jnp.matmul(a.astype(jnp.float32),
                         jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
    self.assertArraysAllClose(
        out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
    )

  @parameterized.product(
      is_training=[True, False],
      output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
      impl=['cudnn',],
  )
  def testScaledDotGeneral(
      self, is_training, output_type, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")

    configs = create_mxfp8_configs_if_available()
    cast_to_representable = partial(
        quantize_dequantize,
        scale=jnp.ones((1,)),
        compute_dtype=jnp.float32,
    )
    k1, k2 = jax.random.split(jax.random.key(0), 2)
    a_shape = [2, 256, 96]
    b_shape = [2, 96, 160]
    dimension_numbers = (([2], [1]), ([0], [0]))
    a = cast_to_representable(
        jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
        configs[0].data_type,
    )
    b = cast_to_representable(
        jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
        configs[1].data_type,
    )

    scaled_dot_general_fn = partial(
        nn.scaled_dot_general, configs=configs
    )
    def fwd(a, b, is_ref=False):
      fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn
      y = fn(a, b, dimension_numbers,
             preferred_element_type=output_type)
      return jnp.sum(y)

    if is_training:
      j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))

      j_train_ref = jax.jit(
          jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
      )
      out, (x_grad, w_grad) = j_train(a, b)
      out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)

      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
      self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
    else:
      j_inference = jax.jit(fwd)
      j_inference_ref = jax.jit(partial(fwd, is_ref=True))
      out = j_inference(a, b)
      out_ref = j_inference_ref(a, b)
      self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)

  @parameterized.product(
      dtype=[jnp.bfloat16, jnp.float16],
      group_num=[1, 2, 4],
      use_vmap=[False, True],
      impl=['cudnn', 'xla'],
  )
  def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
    if impl == 'cudnn' and dtype == jnp.float32:
      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

    B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num
    keys = random.split(random.PRNGKey(0), 5)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
    K = random.normal(keys[1], (B, S, N // G, H), dtype)
    V = random.normal(keys[2], (B, S, N // G, H), dtype)
    grad = random.normal(keys[3], (B, T, N, H), dtype)
    bias, mask = None, None

    sdpa = nn.dot_product_attention
    sdpa_ref = partial(sdpa, implementation=None)
    sdpa_ans = partial(sdpa, implementation=impl)
    if use_vmap:
      sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0)

    # For testing purposes, we call the non-GQA version without vmap in the
    # reference code.
    K_ref = jnp.repeat(K, G, axis=2)
    V_ref = jnp.repeat(V, G, axis=2)
    out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask)
    out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask)

    dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3]
    dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3]
    dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3)
    dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3)

    if impl == 'cudnn':
      self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask))
      self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
    self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01)
    self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01)
    self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)

  @parameterized.product(
      mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
                 ('custom', 'padding'), ('bias', 'causal'),
                 ('causal', 'sliding_window')],
  )
  def testDotProductAttentionMask(self, mask_mode):
    if isinstance(mask_mode, str):
      mask_mode = (mask_mode,)
    min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
    if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

    dtype = jnp.bfloat16
    B, S, T, N, H = 2, 128, 128, 4, 32
    # Five keys are needed: Q, K, V, the output cotangent, and the bias.
    keys = random.split(random.PRNGKey(0), 5)
    Q = random.normal(keys[0], (B, T, N, H), dtype)
    K = random.normal(keys[1], (B, S, N, H), dtype)
    V = random.normal(keys[2], (B, S, N, H), dtype)
    grad = random.normal(keys[3], (B, T, N, H), dtype)
    bias, mask = None, None
    q_seqlen, kv_seqlen = None, None
    window_size = None

    is_causal = 'causal' in mask_mode
    if 'padding' in mask_mode:
      q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
      kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
    if 'custom' in mask_mode:
      # Use a generated causal mask as the custom mask.
      custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
      mask = custom_mask[None, None, :, :]
    if 'bias' in mask_mode:
      bias = random.normal(keys[4], (1, N, T, S), dtype)
    if 'sliding_window' in mask_mode:
      window_size = (3, 2) if is_causal else (3, 0)

    sdpa = nn.dot_product_attention
    sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
    sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')

    args = (Q, K, V, bias, mask)
    kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}

    # Convert the kwargs to positional args for the jax.vjp.
    fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
        local_window_size=window_size,
    )
    fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
        q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
        local_window_size=window_size,
    )
    out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
    out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
    dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
    dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]

    # Check if cudnn backend is called.
    self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
    self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))

    self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
    self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
    self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
    self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)
    self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02)

  @parameterized.product(
      batch_size=[1, 16],
      use_vmap=[False, True],
  )
  def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
    if not _is_required_cudnn_version_satisfied("8.0", 8904):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

    dtype = jnp.bfloat16
    B, S, N, H = batch_size, 128, 4, 32
    keys = random.split(random.PRNGKey(0), 2)
    x = random.normal(keys[0], (B, S, N, H), dtype)
    bias = random.normal(keys[1], (B, N, S, S), dtype=dtype)
    mask = jnp.ones((1, 1, S), dtype=jnp.bool_)

    def attention(x, bias, mask, impl):
      return jax.nn.dot_product_attention(
          query=x,
          key=x,
          value=x,
          bias=bias,
          mask=mask,
          is_causal=False,
          implementation=impl,
      )
    attn_ref = partial(attention, impl=None)
    attn_ans = partial(attention, impl='cudnn')
    if use_vmap:
      attn_batched_ref = jax.vmap(attn_ref, in_axes=(0, 0, None))
      attn_batched_ans = jax.vmap(attn_ans, in_axes=(0, 0, None))
    else:
      attn_batched_ref = attn_ref
      attn_batched_ans = attn_ans

    fwd_ref = jax.jit(attn_batched_ref)
    fwd_ans = jax.jit(attn_batched_ans)
    y_ref = fwd_ref(x, bias, mask)
    y_ans = fwd_ans(x, bias, mask)
    self.assertAllClose(y_ref, y_ans)

    @jax.jit
    def bwd_ref(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ref, x, bias, mask)
      return f_vjp(x)
    @jax.jit
    def bwd_ans(x, bias, mask):
      _, f_vjp = jax.vjp(attn_ans, x, bias, mask)
      return f_vjp(x)

    if batch_size != 1:
      with self.assertRaisesRegex(ValueError, _cudnn_dbias_error):
        _, dbias_ans, _ = bwd_ans(x, bias, mask)
    else:
      _, dbias_ref, _ = bwd_ref(x, bias, mask)
      _, dbias_ans, _ = bwd_ans(x, bias, mask)
      self.assertAllClose(dbias_ans, dbias_ref, rtol=0.1, atol=0.1)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testSoftplusGrad(self):
    check_grads(nn.softplus, (1e-8,), order=4,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSoftplusGradZero(self):
    check_grads(nn.softplus, (0.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSoftplusGradInf(self):
    self.assertAllClose(
        1., jax.grad(nn.softplus)(float('inf')))

  def testSoftplusGradNegInf(self):
    check_grads(nn.softplus, (-float('inf'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSoftplusGradNan(self):
    check_grads(nn.softplus, (float('nan'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
  def testSoftplusZero(self, dtype):
    self.assertEqual(jnp.log(dtype(2)), nn.softplus(dtype(0)))

  def testSparseplusGradZero(self):
    check_grads(nn.sparse_plus, (-2.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSparseplusGrad(self):
    check_grads(nn.sparse_plus, (0.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSparseplusAndSparseSigmoid(self):
    self.assertAllClose(
        jax.grad(nn.sparse_plus)(0.), nn.sparse_sigmoid(0.),
        check_dtypes=False)
    self.assertAllClose(
        jax.grad(nn.sparse_plus)(2.), nn.sparse_sigmoid(2.),
        check_dtypes=False)
    self.assertAllClose(
        jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.),
        check_dtypes=False)

  def testSquareplusGrad(self):
    check_grads(nn.squareplus, (1e-8,), order=4,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSquareplusGradZero(self):
    check_grads(nn.squareplus, (0.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSquareplusGradNegInf(self):
    check_grads(nn.squareplus, (-float('inf'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testSquareplusGradNan(self):
    check_grads(nn.squareplus, (float('nan'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  @parameterized.parameters([float] + jtu.dtypes.floating)
  def testSquareplusZero(self, dtype):
    self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4)))

  def testMishGrad(self):
    check_grads(nn.mish, (1e-8,), order=4,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradZero(self):
    check_grads(nn.mish, (0.,), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradNegInf(self):
    check_grads(nn.mish, (-float('inf'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  def testMishGradNan(self):
    check_grads(nn.mish, (float('nan'),), order=1,
                rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None)

  @parameterized.parameters([float] + jtu.dtypes.floating)
  def testMishZero(self, dtype):
    self.assertEqual(dtype(0), nn.mish(dtype(0)))

  def testReluGrad(self):
    rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
    check_grads(nn.relu, (1.,), order=3, rtol=rtol)
    check_grads(nn.relu, (-1.,), order=3, rtol=rtol)
    jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
    self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)

  def testReluGradAtZero(self):
    # https://dl.acm.org/doi/10.5555/3540261.3540297
    grad = jax.grad(nn.relu)(0.)
    self.assertEqual(grad, 0.)

  def testRelu6Grad(self):
    rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
    check_grads(nn.relu6, (1.,), order=3, rtol=rtol)
    check_grads(nn.relu6, (-1.,), order=3, rtol=rtol)
    self.assertAllClose(jax.grad(nn.relu6)(0.), 0., check_dtypes=False)
    self.assertAllClose(jax.grad(nn.relu6)(6.), 0., check_dtypes=False)

  def testSoftplusValue(self):
    val = nn.softplus(89.)
    self.assertAllClose(val, 89., check_dtypes=False)

  def testSparseplusValue(self):
    val = nn.sparse_plus(89.)
    self.assertAllClose(val, 89., check_dtypes=False)

  def testSparsesigmoidValue(self):
    self.assertAllClose(nn.sparse_sigmoid(-2.), 0., check_dtypes=False)
    self.assertAllClose(nn.sparse_sigmoid(2.), 1., check_dtypes=False)
    self.assertAllClose(nn.sparse_sigmoid(0.), .5, check_dtypes=False)

  def testSquareplusValue(self):
    val = nn.squareplus(1e3)
    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

  def testMishValue(self):
    val = nn.mish(1e3)
    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def testEluGrad(self):
    check_grads(nn.elu, (1e4,), order=4, eps=1.)

  def testEluValue(self):
    val = nn.elu(1e4)
    self.assertAllClose(val, 1e4, check_dtypes=False)

  def testGluValue(self):
    val = nn.glu(jnp.array([1.0, 0.0]), axis=0)
    self.assertAllClose(val, jnp.array([0.5]))

  @parameterized.parameters(False, True)
  def testGeluIntType(self, approximate):
    val_float = nn.gelu(jnp.array(-1.0), approximate=approximate)
    val_int = nn.gelu(jnp.array(-1), approximate=approximate)
    self.assertAllClose(val_float, val_int)

  @parameterized.parameters(False, True)
  def testGelu(self, approximate):
    def gelu_reference(x):
      return x * scipy.stats.norm.cdf(x)
    args_maker = lambda: [jnp.linspace(-12, 5, 10000, dtype=jnp.float32)]
    rtol = 2e-5
    atol = 1e-3 if approximate else 0
    self._CheckAgainstNumpy(
        gelu_reference,
        partial(nn.gelu, approximate=approximate),
        args_maker,
        check_dtypes=False,
        tol=0,
        rtol=rtol,
        atol=atol,
    )

  @parameterized.parameters(*itertools.product(
      (jnp.float32, jnp.bfloat16, jnp.float16),
      (partial(nn.gelu, approximate=False),
       partial(nn.gelu, approximate=True),
       nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish)))
  def testDtypeMatchesInput(self, dtype, fn):
    x = jnp.zeros((), dtype=dtype)
    out = fn(x)
    self.assertEqual(out.dtype, dtype)

  def testEluMemory(self):
    # see https://github.com/jax-ml/jax/pull/1640
    with jax.enable_checks(False):  # With checks we materialize the array
      jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,))))  # don't oom

  def testHardTanhMemory(self):
    # see https://github.com/jax-ml/jax/pull/1640
    with jax.enable_checks(False):  # With checks we materialize the array
      jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,))))  # don't oom

  @parameterized.parameters([nn.softmax, nn.log_softmax])
  def testSoftmaxEmptyArray(self, fn):
    x = jnp.array([], dtype=float)
    self.assertArraysEqual(fn(x), x)

  @parameterized.parameters([nn.softmax, nn.log_softmax])
  def testSoftmaxEmptyMask(self, fn):
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.zeros_like(x, dtype=bool)
    expected = jnp.full_like(x, 0.0 if fn is nn.softmax else -jnp.inf)
    self.assertArraysEqual(fn(x, where=m), expected)

  @parameterized.parameters([nn.softmax, nn.log_softmax])
  def testSoftmaxWhereMask(self, fn):
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])

    out = fn(x, where=m)
    self.assertAllClose(out[m], fn(x[m]))

    probs = out if fn is nn.softmax else jnp.exp(out)
    self.assertAllClose(probs.sum(), 1.0)

    # TODO(mattjj): include log_softmax in these extra tests if/when we add a
    # custom_jvp rule for it (since otherwise it doesn't pass the numerical
    # checks below).
    if fn is nn.softmax and config.softmax_custom_jvp.value:
      g_fun = lambda x: jnp.take(fn(x, where=m, initial=-jnp.inf),
                                 jnp.array([0, 2, 3]))
      jtu.check_grads(g_fun, (x,), order=2)

  @parameterized.parameters([nn.softmax, nn.log_softmax])
  def testSoftmaxWhereGrad(self, fn):
    # regression test for https://github.com/jax-ml/jax/issues/19490
    x = jnp.array([36., 10000.])
    mask = x < 1000

    f = lambda x, mask: fn(x, where=mask)[0]

    self.assertAllClose(jax.grad(f)(x, mask), jnp.zeros_like(x))

  def testSoftmaxGrad(self):
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    jtu.check_grads(nn.softmax, (x,), order=2, atol=5e-3)

  def testSoftmaxGradResiduals(self):
    if not config.softmax_custom_jvp.value:
      raise unittest.SkipTest("only applies when upgrade flag enabled")
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    res = ad_checkpoint.saved_residuals(nn.softmax, x)
    self.assertLen(res, 1)

  def testSoftmaxGradFlag(self):
    x = jnp.array([5.5, 1.3, -4.2, 0.9])

    with jax.softmax_custom_jvp(False):
      res = ad_checkpoint.saved_residuals(nn.softmax, x)
    self.assertLen(res, 3)
    self.assertEqual(sum(a.size for a, _ in res), 6)

    with jax.softmax_custom_jvp(True):
      res = ad_checkpoint.saved_residuals(nn.softmax, x)
    self.assertLen(res, 1)
    self.assertEqual(sum(a.size for a, _ in res), 4)

  def testStandardizeWhereMask(self):
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])
    x_filtered = jnp.take(x, jnp.array([0, 2, 3]))

    out_masked = jnp.take(nn.standardize(x, where=m), jnp.array([0, 2, 3]))
    out_filtered = nn.standardize(x_filtered)

    self.assertAllClose(out_masked, out_filtered)

  def testOneHot(self):
    actual = nn.one_hot(jnp.array([0, 1, 2]), 3)
    expected = jnp.array([[1., 0., 0.],
                          [0., 1., 0.],
                          [0., 0., 1.]])
    self.assertAllClose(actual, expected, check_dtypes=False)

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3)
    expected = jnp.array([[0., 1., 0.],
                          [0., 0., 1.],
                          [1., 0., 0.]])
    self.assertAllClose(actual, expected, check_dtypes=False)

  def testOneHotOutOfBound(self):
    actual = nn.one_hot(jnp.array([-1, 3]), 3)
    expected = jnp.array([[0., 0., 0.],
                          [0., 0., 0.]])
    self.assertAllClose(actual, expected, check_dtypes=False)

  def testOneHotNonArrayInput(self):
    actual = nn.one_hot([0, 1, 2], 3)
    expected = jnp.array([[1., 0., 0.],
                          [0., 1., 0.],
                          [0., 0., 1.]])
    self.assertAllClose(actual, expected, check_dtypes=False)

  def testOneHotCustomDtype(self):
    actual = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
    expected = jnp.array([[True, False, False],
                          [False, True, False],
                          [False, False, True]])
    self.assertAllClose(actual, expected)

  def testOneHotConcretizationError(self):
    # https://github.com/jax-ml/jax/issues/3654
    msg = r"in jax.nn.one_hot argument `num_classes`"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      jax.jit(nn.one_hot)(3, 5)

  def testOneHotAxis(self):
    expected = jnp.array([[0., 1., 0.],
                          [0., 0., 1.],
                          [1., 0., 0.]]).T

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
    self.assertAllClose(actual, expected, check_dtypes=False)

    actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
    self.assertAllClose(actual, expected, check_dtypes=False)

  def testOneHotNonInteger(self):
    with self.assertDeprecationWarnsOrRaises("jax-nn-one-hot-float-input",
                                             "jax.nn.one_hot input should be integer-typed"):
      nn.one_hot(jnp.array([1.0]), 3)

  def testTanhExists(self):
    nn.tanh  # doesn't crash

  def testCustomJVPLeak(self):
    # https://github.com/jax-ml/jax/issues/8171
    @jax.jit
    def fwd():
      a = jnp.array(1.)

      def f(hx, _):
        hx = jax.nn.sigmoid(hx + a)
        return hx, None

      hx = jnp.array(0.)
      jax.lax.scan(f, hx, None, length=2)

    with jax.checking_leaks():
      fwd()  # doesn't crash

  def testCustomJVPLeak2(self):
    # https://github.com/jax-ml/jax/issues/8171
    # The above test uses jax.nn.sigmoid, as in the original #8171, but that
    # function no longer actually has a custom_jvp! So we inline the old def.

    @jax.custom_jvp
    def sigmoid(x):
      one = jnp.float32(1)
      return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
    sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))

    @jax.jit
    def fwd():
      a = jnp.array(1., 'float32')

      def f(hx, _):
        hx = sigmoid(hx + a)
        return hx, None

      hx = jnp.array(0., 'float32')
      jax.lax.scan(f, hx, None, length=2)

    with jax.checking_leaks():
      fwd()  # doesn't crash


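# Scaffolding for the initializer tests below: each record names an
# initializer together with the shapes and dtypes it should be exercised with.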
InitializerRecord = collections.namedtuple(
    "InitializerRecord",
    ["name", "initializer", "shapes", "dtypes"])

ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)]

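# Filters ALL_SHAPES down to the ranks the initializer supports and bundles the
# result into an InitializerRecord.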
def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
  shapes = [shape for shape in ALL_SHAPES
            if min_dims <= len(shape) <= max_dims]
  return InitializerRecord(name, initializer, shapes, dtypes)

INITIALIZER_RECS = [
    initializer_record("uniform", nn.initializers.uniform, jtu.dtypes.floating, 1),
    initializer_record("normal", nn.initializers.normal, jtu.dtypes.inexact, 1),
    initializer_record("he_normal", nn.initializers.he_normal, jtu.dtypes.inexact),
    initializer_record("he_uniform", nn.initializers.he_uniform, jtu.dtypes.inexact),
    initializer_record("glorot_normal", nn.initializers.glorot_normal, jtu.dtypes.inexact),
    initializer_record("glorot_uniform", nn.initializers.glorot_uniform, jtu.dtypes.inexact),
    initializer_record("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact),
    initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
    initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
    initializer_record("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, 1),
    initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4),
    initializer_record(
        "variance_scaling_fan_geo_avg",
        partial(nn.initializers.variance_scaling, 1, "fan_geo_avg", "normal"),
        jtu.dtypes.floating,
    ),
]


@jtu.with_config(jax_legacy_prng_key="allow")
class NNInitializersTest(jtu.JaxTestCase):
  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
          [dict(initializer=rec.initializer())],
          shape=rec.shapes,
          dtype=rec.dtypes
      )
      for rec in INITIALIZER_RECS
  ))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
          [dict(initializer_provider=rec.initializer)],
          shape=rec.shapes,
          dtype=rec.dtypes
      )
      for rec in INITIALIZER_RECS
  ))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  def testVarianceScalingMultiAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=(0, 1), out_axis=(-2, -1))
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))

  def testVarianceScalingBatchAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=0, out_axis=(2, 3), batch_axis=1)
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))

  def testVarianceScalingError(self):
    rng = random.PRNGKey(0)
    shape = (5,)
    initializer = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal')

    with self.assertRaisesRegex(
        ValueError,
        "Can't compute input and output sizes of a 1"
        "-dimensional weights tensor. Must be at least 2D."
    ):
      initializer(rng, shape)

  def testAccidentalUpcasting(self):
    rng = random.PRNGKey(0)
    shape = (4, 4)
    scalar_param = jnp.array(1.0, dtype=jnp.float32)
    for init_fn in (nn.initializers.uniform(scalar_param, jnp.bfloat16),
                    nn.initializers.normal(scalar_param, jnp.bfloat16),
                    nn.initializers.truncated_normal(scalar_param, jnp.bfloat16),
                    ):
      sub_rng, rng = random.split(rng)
      val = init_fn(sub_rng, shape)
      self.assertEqual(val.dtype, jnp.bfloat16)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())