2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2019 The JAX Authors.
|
2019-09-21 01:04:26 -04:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
"""Tests for nn module."""
|
|
|
|
|
2019-10-03 12:01:21 -07:00
|
|
|
import collections
|
2020-10-02 09:48:07 -04:00
|
|
|
from functools import partial
|
2019-10-03 12:01:21 -07:00
|
|
|
import itertools
|
2023-04-19 18:11:35 -07:00
|
|
|
import unittest
|
2019-10-03 12:01:21 -07:00
|
|
|
|
2019-09-21 01:04:26 -04:00
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2020-10-02 09:48:07 -04:00
|
|
|
import scipy.stats
|
2019-09-21 01:04:26 -04:00
|
|
|
|
2023-10-11 08:45:30 -07:00
|
|
|
from jax._src import config
|
2023-02-14 23:00:40 -08:00
|
|
|
from jax._src import core
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2023-04-19 18:11:35 -07:00
|
|
|
from jax._src import ad_checkpoint
|
2024-07-08 06:15:20 -07:00
|
|
|
from jax._src.interpreters import mlir
|
|
|
|
from jax._src.lib import cuda_versions
|
2019-09-21 01:04:26 -04:00
|
|
|
from jax.test_util import check_grads
|
|
|
|
from jax import nn
|
|
|
|
from jax import random
|
2019-10-21 11:48:58 +00:00
|
|
|
import jax
|
2020-05-05 14:59:16 -04:00
|
|
|
import jax.numpy as jnp
|
2019-09-21 01:04:26 -04:00
|
|
|
|
2023-10-12 13:15:22 +01:00
|
|
|
# Load absl command-line flags into JAX's config before any test runs
# (NOTE(review): purpose inferred from the helper's name — confirm in
# jax._src.config if relying on specific flag behavior).
config.parse_flags_with_absl()
|
2019-09-21 01:04:26 -04:00
|
|
|
|
2024-08-26 17:32:38 +00:00
|
|
|
def _is_required_cudnn_version_satisfied(min_cudnn_version):
  """Reports whether the current GPU/cuDNN setup meets a minimum version.

  Args:
    min_cudnn_version: integer lower bound compared (with ``>=``) against
      ``cuda_versions.cudnn_get_version()``.

  Returns:
    ``True`` only when ``jtu.is_cuda_compute_capability_at_least("8.0")``
    holds, the ``cuda_versions`` module is available (not ``None``), and the
    reported cuDNN version is at least ``min_cudnn_version``; otherwise
    ``False``.
  """
  # Bail out early on devices below compute capability 8.0.
  if not jtu.is_cuda_compute_capability_at_least("8.0"):
    return False
  # cuda_versions can be None, in which case cuDNN cannot be queried.
  if cuda_versions is None:
    return False
  return cuda_versions.cudnn_get_version() >= min_cudnn_version
|
2020-05-01 18:00:38 +01:00
|
|
|
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
def _check_cudnn_backend(fn, *args, **kwargs):
|
|
|
|
lowered = jax.jit(fn).lower(*args, **kwargs)
|
|
|
|
hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
|
|
|
return '__cudnn$fmha' in hlo
|
2024-07-08 06:15:20 -07:00
|
|
|
|
2024-09-18 21:23:16 +00:00
|
|
|
# Error-message fragment for the cuDNN bias-gradient restriction.
# NOTE(review): it is not referenced in the visible part of this file;
# presumably matched against errors raised when a bias gradient is requested
# from the cuDNN attention path — check its use sites before editing the text.
_cudnn_dbias_error = 'cuDNN only supports bias gradient'
|
|
|
|
|
2024-07-08 06:15:20 -07:00
|
|
|
@jtu.with_config(jax_legacy_prng_key="allow",
|
|
|
|
jax_numpy_dtype_promotion="standard")
|
2019-10-03 12:01:21 -07:00
|
|
|
class NNFunctionsTest(jtu.JaxTestCase):
|
2024-07-08 06:15:20 -07:00
|
|
|
@parameterized.product(
    dtype=[jnp.bfloat16, jnp.float16],
    group_num=[1, 2, 4],
    use_vmap=[False, True],
    impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
  """Compares dot_product_attention (GQA, optional vmap) with a reference.

  The reference uses ``implementation=None`` on non-grouped inputs, where each
  K/V head is repeated ``group_num`` times. Its K/V gradients are summed back
  into the grouped layout before being compared with the implementation
  under test.
  """
  use_cudnn = impl == 'cudnn'
  if use_cudnn:
    if not _is_required_cudnn_version_satisfied(8904):
      raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
    if dtype == jnp.float32:
      raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")

  batch, kv_len, q_len, num_heads, head_dim = 2, 128, 128, 4, 32
  num_kv_heads = num_heads // group_num
  q_shape = (batch, q_len, num_heads, head_dim)
  kv_shape = (batch, kv_len, num_kv_heads, head_dim)

  key_q, key_k, key_v, key_grad, _ = random.split(random.PRNGKey(0), 5)
  q = random.normal(key_q, q_shape, dtype)
  k = random.normal(key_k, kv_shape, dtype)
  v = random.normal(key_v, kv_shape, dtype)
  d_out = random.normal(key_grad, q_shape, dtype)
  bias = mask = None

  attention = nn.dot_product_attention
  reference_fn = partial(attention, implementation=None)
  tested_fn = partial(attention, implementation=impl)
  if use_vmap:
    tested_fn = jax.vmap(tested_fn, in_axes=(0, 0, 0, None, None), out_axes=0)

  # The reference path uses neither GQA nor vmap: K/V heads are repeated so
  # that every query head has a dedicated K/V head.
  k_full = jnp.repeat(k, group_num, axis=2)
  v_full = jnp.repeat(v, group_num, axis=2)
  out_ref, vjp_ref = jax.vjp(reference_fn, q, k_full, v_full, bias, mask)
  out_ans, vjp_ans = jax.vjp(tested_fn, q, k, v, bias, mask)

  dq_ref, dk_ref, dv_ref = vjp_ref(d_out)[:3]
  dq_ans, dk_ans, dv_ans = vjp_ans(d_out)[:3]
  # Sum the gradients of the repeated heads back into their shared K/V head.
  grouped_shape = (batch, kv_len, num_kv_heads, group_num, head_dim)
  dk_ref = dk_ref.reshape(grouped_shape).sum(axis=3)
  dv_ref = dv_ref.reshape(grouped_shape).sum(axis=3)

  if use_cudnn:
    self.assertTrue(_check_cudnn_backend(tested_fn, q, k, v, bias, mask))
    self.assertTrue(_check_cudnn_backend(vjp_ans, d_out))

  self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
  for expected, actual in ((dq_ref, dq_ans), (dk_ref, dk_ans),
                           (dv_ref, dv_ans)):
    self.assertAllClose(expected, actual, rtol=.01, atol=.01)
|
2024-07-08 06:15:20 -07:00
|
|
|
|
|
|
|
@parameterized.product(
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'),
|
2024-08-26 17:32:38 +00:00
|
|
|
('custom', 'padding'), ('bias', 'causal'),
|
|
|
|
('causal', 'sliding_window')],
|
2024-07-08 06:15:20 -07:00
|
|
|
)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
def testDotProductAttentionMask(self, mask_mode):
|
|
|
|
if isinstance(mask_mode, str):
|
|
|
|
mask_mode = (mask_mode,)
|
2024-08-26 17:32:38 +00:00
|
|
|
min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
|
|
|
|
if not _is_required_cudnn_version_satisfied(min_cudnn_version):
|
|
|
|
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
|
2024-07-08 06:15:20 -07:00
|
|
|
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
dtype = jnp.bfloat16
|
|
|
|
B, S, T, N, H = 2, 128, 128, 4, 32
|
|
|
|
keys = random.split(random.PRNGKey(0), 4)
|
2024-07-08 06:15:20 -07:00
|
|
|
Q = random.normal(keys[0], (B, T, N, H), dtype)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
K = random.normal(keys[1], (B, S, N, H), dtype)
|
|
|
|
V = random.normal(keys[2], (B, S, N, H), dtype)
|
2024-07-08 06:15:20 -07:00
|
|
|
grad = random.normal(keys[3], (B, T, N, H), dtype)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
bias, mask = None, None
|
|
|
|
q_seqlen, kv_seqlen = None, None
|
2024-08-26 17:32:38 +00:00
|
|
|
window_size = None
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
|
|
|
|
is_causal = 'causal' in mask_mode
|
|
|
|
if 'padding' in mask_mode:
|
2024-07-18 17:03:49 +00:00
|
|
|
q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32)
|
|
|
|
kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
if 'custom' in mask_mode:
|
|
|
|
# Use a generated causal mask as the custom mask.
|
|
|
|
custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
|
|
|
|
mask = custom_mask[None, None, :, :]
|
|
|
|
if 'bias' in mask_mode:
|
|
|
|
bias = random.normal(keys[4], (1, N, T, S), dtype)
|
2024-08-26 17:32:38 +00:00
|
|
|
if 'sliding_window' in mask_mode:
|
|
|
|
window_size = (3, 2) if is_causal else (3, 0)
|
2024-07-08 06:15:20 -07:00
|
|
|
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
sdpa = nn.dot_product_attention
|
2024-07-08 06:15:20 -07:00
|
|
|
sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn')
|
|
|
|
|
|
|
|
args = (Q, K, V, bias, mask)
|
|
|
|
kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen}
|
|
|
|
|
|
|
|
# Convert the kargs to positional args for the jax.vjp.
|
2024-07-18 17:03:49 +00:00
|
|
|
fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref(
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
|
2024-08-26 17:32:38 +00:00
|
|
|
local_window_size=window_size,
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
)
|
|
|
|
fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans(
|
|
|
|
q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs,
|
2024-08-26 17:32:38 +00:00
|
|
|
local_window_size=window_size,
|
2024-07-18 17:03:49 +00:00
|
|
|
)
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen)
|
|
|
|
out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen)
|
2024-07-18 17:03:49 +00:00
|
|
|
dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4]
|
|
|
|
dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4]
|
2024-07-08 06:15:20 -07:00
|
|
|
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
# Check if cudnn backend is called.
|
|
|
|
self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs))
|
|
|
|
self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad))
|
2024-07-08 06:15:20 -07:00
|
|
|
|
PR #23223: [NVIDIA] Reduce number of tests for `jax.nn.dot_product_attention`
Imported from GitHub PR https://github.com/google/jax/pull/23223
While adding the new mask mode, `sliding_window`, I noticed that the number of tests has become quite large. Currently, every time we introduce a new option, it requires testing all possible combinations with existing options, which makes the number of tests increase exponentially. For example, we already have 864 parameterized tests for this API. This PR aims to address this issue by reducing the number of tests through grouping.
For the new tests, we categorize them as follows:
1. **Non-mask tests:** These verify the basic functionality of the API, including data types, `vmap`, groups, etc.
2. **Mask tests:** These cover different masking scenarios, such as causal, padding, or other commonly used combinations.
Additionally, we will no longer maintain separate tests for inference and training.
Copybara import of the project:
--
dd2ca197431429ae59f9af506d481cbb40ae14e5 by kaixih <kaixih@nvidia.com>:
Reduce attn tests
Merging this change closes #23223
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/23223 from kaixih:reduce_attn_tests dd2ca197431429ae59f9af506d481cbb40ae14e5
PiperOrigin-RevId: 669364738
2024-08-30 10:11:19 -07:00
|
|
|
self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01)
|
2024-10-21 17:00:04 +00:00
|
|
|
self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02)
|
2024-07-15 22:07:08 +00:00
|
|
|
self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02)
|
2024-10-21 17:00:04 +00:00
|
|
|
self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01)
|
|
|
|
self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02)
|
2024-07-08 06:15:20 -07:00
|
|
|
|
2024-09-18 21:23:16 +00:00
|
|
|
@parameterized.product(
    batch_size=[1, 16],
    use_vmap=[False, True],
)
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
  """Compares cuDNN and default attention outputs and bias gradients.

  The forward pass is optionally vmapped over the leading axis of x and bias.
  For the backward pass, the cuDNN path is expected to raise a ValueError
  matching ``_cudnn_dbias_error`` when ``batch_size != 1``; otherwise its
  bias gradient must match the reference implementation.
  """
  if not _is_required_cudnn_version_satisfied(8904):
    raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

  dtype = jnp.bfloat16
  B, S, N, H = batch_size, 128, 4, 32
  keys = random.split(random.PRNGKey(0), 2)
  x = random.normal(keys[0], (B, S, N, H), dtype)
  bias = random.normal(keys[1], (B, N, S, S), dtype=dtype)
  mask = jnp.ones((1, 1, S), dtype=jnp.bool_)

  def attention(x, bias, mask, impl):
    # Self-attention: the same tensor serves as query, key and value.
    return jax.nn.dot_product_attention(
        query=x, key=x, value=x, bias=bias, mask=mask,
        is_causal=False, implementation=impl)

  attn_ref = partial(attention, impl=None)
  attn_ans = partial(attention, impl='cudnn')

  # Forward check; under vmap, x and bias are mapped while the mask is shared.
  batcher = (partial(jax.vmap, in_axes=(0, 0, None)) if use_vmap
             else (lambda f: f))
  y_ref = jax.jit(batcher(attn_ref))(x, bias, mask)
  y_ans = jax.jit(batcher(attn_ans))(x, bias, mask)
  self.assertAllClose(y_ref, y_ans)

  def make_bwd(attn):
    # VJP through the un-vmapped function, using x itself as the cotangent.
    @jax.jit
    def bwd(x, bias, mask):
      _, f_vjp = jax.vjp(attn, x, bias, mask)
      return f_vjp(x)
    return bwd

  bwd_ref = make_bwd(attn_ref)
  bwd_ans = make_bwd(attn_ans)

  if batch_size != 1:
    with self.assertRaisesRegex(ValueError, _cudnn_dbias_error):
      bwd_ans(x, bias, mask)
  else:
    _, dbias_ref, _ = bwd_ref(x, bias, mask)
    _, dbias_ans, _ = bwd_ans(x, bias, mask)
    self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02)
|
2024-09-18 21:23:16 +00:00
|
|
|
|
2020-02-05 17:35:46 +01:00
|
|
|
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSoftplusGrad(self):
  """Checks softplus derivatives up to fourth order just above zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.softplus, (1e-8,), order=4, rtol=tpu_rtol)
|
2019-10-03 12:01:21 -07:00
|
|
|
|
2020-04-13 09:44:13 -07:00
|
|
|
def testSoftplusGradZero(self):
  """First-order gradient check of softplus at exactly zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.softplus, (0.,), order=1, rtol=tpu_rtol)
|
2020-04-13 09:44:13 -07:00
|
|
|
|
|
|
|
def testSoftplusGradInf(self):
  """The derivative of softplus at +inf must be exactly one."""
  slope_at_inf = jax.grad(nn.softplus)(float('inf'))
  self.assertAllClose(1., slope_at_inf)
|
2020-04-13 09:44:13 -07:00
|
|
|
|
|
|
|
def testSoftplusGradNegInf(self):
  """First-order gradient check of softplus at -inf."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.softplus, (-float('inf'),), order=1, rtol=tpu_rtol)
|
2020-04-13 09:44:13 -07:00
|
|
|
|
|
|
|
def testSoftplusGradNan(self):
  """First-order gradient check of softplus at NaN."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.softplus, (float('nan'),), order=1, rtol=tpu_rtol)
|
2020-04-13 09:44:13 -07:00
|
|
|
|
2020-12-08 13:03:30 -08:00
|
|
|
@parameterized.parameters([int, float, *jtu.dtypes.floating, *jtu.dtypes.integer])
def testSoftplusZero(self, dtype):
  """softplus(0) must equal log(2) exactly for each scalar dtype."""
  zero, two = dtype(0), dtype(2)
  self.assertEqual(jnp.log(two), nn.softplus(zero))
|
2020-04-13 09:44:13 -07:00
|
|
|
|
2024-03-15 04:39:48 -07:00
|
|
|
def testSparseplusGradZero(self):
  """First-order gradient check of sparse_plus at -2."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.sparse_plus, (-2.,), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testSparseplusGrad(self):
  """First-order gradient check of sparse_plus at zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.sparse_plus, (0.,), order=1, rtol=tpu_rtol)
|
|
|
|
|
2024-04-09 03:09:20 -07:00
|
|
|
def testSparseplusAndSparseSigmoid(self):
  """The gradient of sparse_plus must agree with sparse_sigmoid pointwise."""
  d_sparse_plus = jax.grad(nn.sparse_plus)
  for point in (0., 2., -2.):
    self.assertAllClose(
        d_sparse_plus(point), nn.sparse_sigmoid(point), check_dtypes=False)
|
|
|
|
|
2023-11-14 23:52:41 -05:00
|
|
|
def testSquareplusGrad(self):
  """Checks squareplus derivatives up to fourth order just above zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (1e-8,), order=4, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testSquareplusGradZero(self):
  """First-order gradient check of squareplus at exactly zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (0.,), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testSquareplusGradNegInf(self):
  """First-order gradient check of squareplus at -inf."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (-float('inf'),), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testSquareplusGradNan(self):
  """First-order gradient check of squareplus at NaN."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.squareplus, (float('nan'),), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
@parameterized.parameters([float, *jtu.dtypes.floating])
def testSquareplusZero(self, dtype):
  """squareplus(0, 4) must equal exactly 1 in every floating dtype."""
  x, b = dtype(0), dtype(4)
  self.assertEqual(dtype(1), nn.squareplus(x, b))
|
|
|
|
|
2024-04-03 16:37:07 -04:00
|
|
|
def testMishGrad(self):
  """Checks mish derivatives up to fourth order just above zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.mish, (1e-8,), order=4, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testMishGradZero(self):
  """First-order gradient check of mish at exactly zero."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.mish, (0.,), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testMishGradNegInf(self):
  """First-order gradient check of mish at -inf."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.mish, (-float('inf'),), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
def testMishGradNan(self):
  """First-order gradient check of mish at NaN."""
  tpu_rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  check_grads(nn.mish, (float('nan'),), order=1, rtol=tpu_rtol)
|
|
|
|
|
|
|
|
@parameterized.parameters([float, *jtu.dtypes.floating])
def testMishZero(self, dtype):
  """mish(0) must be exactly 0 in every floating dtype."""
  zero = dtype(0)
  self.assertEqual(zero, nn.mish(zero))
|
|
|
|
|
2020-03-03 16:27:53 -08:00
|
|
|
def testReluGrad(self):
  """relu gradient checks at +/-1, plus a size check on the grad jaxpr at 0."""
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  for point in (1., -1.):
    check_grads(nn.relu, (point,), order=3, rtol=rtol)
  grad_jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
  self.assertGreaterEqual(len(grad_jaxpr.jaxpr.eqns), 2)
|
2020-03-03 16:27:53 -08:00
|
|
|
|
2023-03-07 17:30:17 -08:00
|
|
|
def testRelu6Grad(self):
  """relu6 gradient checks at +/-1 and an explicit zero gradient at 0 and 6."""
  rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
  for point in (1., -1.):
    check_grads(nn.relu6, (point,), order=3, rtol=rtol)
  relu6_grad = jax.grad(nn.relu6)
  for point in (0., 6.):
    self.assertAllClose(relu6_grad(point), 0., check_dtypes=False)
|
|
|
|
|
2019-09-21 01:04:26 -04:00
|
|
|
def testSoftplusValue(self):
  """softplus(89) should be numerically indistinguishable from 89."""
  big = 89.
  self.assertAllClose(nn.softplus(big), big, check_dtypes=False)
|
2019-10-03 12:01:21 -07:00
|
|
|
|
2024-03-15 04:39:48 -07:00
|
|
|
def testSparseplusValue(self):
  """sparse_plus(89) should equal 89."""
  big = 89.
  self.assertAllClose(nn.sparse_plus(big), big, check_dtypes=False)
|
|
|
|
|
2024-04-09 03:09:20 -07:00
|
|
|
def testSparsesigmoidValue(self):
  """sparse_sigmoid is 0 at -2, 1 at 2 and one half at 0."""
  for point, expected in ((-2., 0.), (2., 1.), (0., .5)):
    self.assertAllClose(nn.sparse_sigmoid(point), expected, check_dtypes=False)
|
|
|
|
|
2023-11-14 23:52:41 -05:00
|
|
|
def testSquareplusValue(self):
  """squareplus(1e3) should be within 1e-3 of 1e3."""
  big = 1e3
  self.assertAllClose(nn.squareplus(big), big, check_dtypes=False, atol=1e-3)
|
|
|
|
|
2024-04-03 16:37:07 -04:00
|
|
|
def testMishValue(self):
  """mish(1e3) should be within 1e-3 of 1e3."""
  big = 1e3
  self.assertAllClose(nn.mish(big), big, check_dtypes=False, atol=1e-3)
|
|
|
|
|
2020-02-05 17:35:46 +01:00
|
|
|
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testEluGrad(self):
  """Fourth-order gradient check of elu at a large positive input."""
  big = 1e4
  check_grads(nn.elu, (big,), order=4, eps=1.)
|
2019-10-21 11:48:58 +00:00
|
|
|
|
|
|
|
def testEluValue(self):
  """elu(1e4) should equal 1e4."""
  big = 1e4
  self.assertAllClose(nn.elu(big), big, check_dtypes=False)
|
2020-06-02 17:37:20 -07:00
|
|
|
|
2020-05-02 19:33:10 -07:00
|
|
|
def testGluValue(self):
  """glu([1.0, 0.0], axis=0) should evaluate to [0.5]."""
  inputs = jnp.array([1.0, 0.0])
  expected = jnp.array([0.5])
  self.assertAllClose(nn.glu(inputs, axis=0), expected)
|
2019-10-21 11:48:58 +00:00
|
|
|
|
2022-08-11 15:46:34 +01:00
|
|
|
@parameterized.parameters(False, True)
def testGeluIntType(self, approximate):
  """gelu of an integer input must match gelu of the equivalent float."""
  gelu = partial(nn.gelu, approximate=approximate)
  self.assertAllClose(gelu(jnp.array(-1.0)), gelu(jnp.array(-1)))
|
|
|
|
|
2020-10-02 09:48:07 -04:00
|
|
|
@parameterized.parameters(False, True)
def testGelu(self, approximate):
  """Compares nn.gelu against x * scipy.stats.norm.cdf(x) on a dense grid."""
  gelu_reference = lambda x: x * scipy.stats.norm.cdf(x)
  args_maker = lambda: [jnp.linspace(-12, 5, 10000, dtype=jnp.float32)]
  # The approximate variant gets an absolute tolerance of 1e-3; exact gets 0.
  tolerances = dict(tol=0, rtol=2e-5, atol=1e-3 if approximate else 0)
  self._CheckAgainstNumpy(
      gelu_reference,
      partial(nn.gelu, approximate=approximate),
      args_maker,
      check_dtypes=False,
      **tolerances,
  )
|
2020-10-02 09:48:07 -04:00
|
|
|
|
2020-02-19 06:04:20 +00:00
|
|
|
@parameterized.parameters(*itertools.product(
    [jnp.float32, jnp.bfloat16, jnp.float16],
    [partial(nn.gelu, approximate=False),
     partial(nn.gelu, approximate=True),
     nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus,
     nn.mish],
))
def testDtypeMatchesInput(self, dtype, fn):
  """Each activation must return the same dtype as its scalar input."""
  result = fn(jnp.zeros((), dtype=dtype))
  self.assertEqual(result.dtype, dtype)
|
|
|
|
|
2021-03-10 11:34:42 -05:00
|
|
|
def testEluMemory(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# see https://github.com/jax-ml/jax/pull/1640
|
2021-03-19 13:49:38 -07:00
|
|
|
with jax.enable_checks(False): # With checks we materialize the array
|
2021-03-10 11:34:42 -05:00
|
|
|
jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom
|
|
|
|
|
|
|
|
def testHardTanhMemory(self):
|
2024-09-20 07:51:48 -07:00
|
|
|
# see https://github.com/jax-ml/jax/pull/1640
|
2021-03-19 13:49:38 -07:00
|
|
|
with jax.enable_checks(False): # With checks we materialize the array
|
2021-03-10 11:34:42 -05:00
|
|
|
jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
|
|
|
|
|
2024-04-10 10:23:21 -07:00
|
|
|
@parameterized.parameters([nn.softmax, nn.log_softmax])
def testSoftmaxEmptyArray(self, fn):
  """fn applied to an empty float array returns an equal empty array."""
  empty = jnp.array([], dtype=float)
  self.assertArraysEqual(fn(empty), empty)
|
|
|
|
|
|
|
|
@parameterized.parameters([nn.softmax, nn.log_softmax])
def testSoftmaxEmptyMask(self, fn):
  """With every entry masked out, softmax gives 0 and log_softmax gives -inf."""
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
  nothing_kept = jnp.zeros_like(x, dtype=bool)
  fill = 0.0 if fn is nn.softmax else -jnp.inf
  self.assertArraysEqual(fn(x, where=nothing_kept), jnp.full_like(x, fill))
|
|
|
|
|
2021-12-21 13:59:30 +01:00
|
|
|
@parameterized.parameters([nn.softmax, nn.log_softmax])
def testSoftmaxWhereMask(self, fn):
  """Masked fn matches fn on the kept entries, and its probabilities sum to 1."""
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
  keep = jnp.array([True, False, True, True])

  out = fn(x, where=keep)
  self.assertAllClose(out[keep], fn(x[keep]))

  # log_softmax outputs log-probabilities; exponentiate before summing.
  probs = jnp.exp(out) if fn is not nn.softmax else out
  self.assertAllClose(probs.sum(), 1.0)

  # TODO(mattjj): include log_softmax in these extra tests if/when we add a
  # custom_jvp rule for it (since otherwise it doesn't pass the numerical
  # checks below).
  if fn is nn.softmax and config.softmax_custom_jvp.value:
    kept_idx = jnp.array([0, 2, 3])
    g_fun = lambda x: jnp.take(fn(x, where=keep, initial=-jnp.inf), kept_idx)
    jtu.check_grads(g_fun, (x,), order=2)
|
|
|
|
|
2024-01-25 10:04:44 -08:00
|
|
|
@parameterized.parameters([nn.softmax, nn.log_softmax])
def testSoftmaxWhereGrad(self, fn):
  """With only the first entry unmasked, grad of fn(x, where=mask)[0] is zero.

  Regression test for https://github.com/jax-ml/jax/issues/19490.
  """
  x = jnp.array([36., 10000.])
  mask = x < 1000
  first_output = lambda x, mask: fn(x, where=mask)[0]
  self.assertAllClose(jax.grad(first_output)(x, mask), jnp.zeros_like(x))
|
|
|
|
|
2023-04-19 18:11:35 -07:00
|
|
|
def testSoftmaxGrad(self):
  """Second-order numerical gradient check of nn.softmax."""
  logits = jnp.array([5.5, 1.3, -4.2, 0.9])
  jtu.check_grads(nn.softmax, (logits,), order=2, atol=5e-3)
|
2023-04-19 18:11:35 -07:00
|
|
|
|
|
|
|
def testSoftmaxGradResiduals(self):
  """When the softmax_custom_jvp flag is on, softmax saves one residual."""
  if not config.softmax_custom_jvp.value:
    raise unittest.SkipTest("only applies when upgrade flag enabled")
  logits = jnp.array([5.5, 1.3, -4.2, 0.9])
  residuals = ad_checkpoint.saved_residuals(nn.softmax, logits)
  self.assertLen(residuals, 1)
|
|
|
|
|
|
|
|
def testSoftmaxGradFlag(self):
  """The softmax_custom_jvp context changes which residuals softmax saves.

  Disabled: 3 residuals totalling 6 elements. Enabled: 1 residual of 4.
  """
  logits = jnp.array([5.5, 1.3, -4.2, 0.9])
  expectations = ((False, 3, 6), (True, 1, 4))
  for enabled, num_residuals, total_size in expectations:
    with jax.softmax_custom_jvp(enabled):
      residuals = ad_checkpoint.saved_residuals(nn.softmax, logits)
    self.assertLen(residuals, num_residuals)
    self.assertEqual(sum(a.size for a, _ in residuals), total_size)
|
|
|
|
|
2022-03-23 20:39:39 +00:00
|
|
|
def testStandardizeWhereMask(self):
  """standardize(x, where=m) on the kept entries equals standardizing them alone."""
  x = jnp.array([5.5, 1.3, -4.2, 0.9])
  keep = jnp.array([True, False, True, True])
  kept_idx = jnp.array([0, 2, 3])

  expected = nn.standardize(jnp.take(x, kept_idx))
  actual = jnp.take(nn.standardize(x, where=keep), kept_idx)
  self.assertAllClose(actual, expected)
|
|
|
|
|
2020-02-15 18:32:00 +00:00
|
|
|
def testOneHot(self):
  """one_hot places a single 1 per row at the label's column."""
  cases = [
      ([0, 1, 2], [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]),
      ([1, 2, 0], [[0., 1., 0.], [0., 0., 1.], [1., 0., 0.]]),
  ]
  for labels, rows in cases:
    self.assertAllClose(nn.one_hot(jnp.array(labels), 3), jnp.array(rows),
                        check_dtypes=False)
|
2020-02-15 18:32:00 +00:00
|
|
|
|
|
|
|
def testOneHotOutOfBound(self):
  """Labels -1 and 3 (with 3 classes) encode as all-zero rows."""
  encoded = nn.one_hot(jnp.array([-1, 3]), 3)
  self.assertAllClose(encoded, jnp.zeros((2, 3)), check_dtypes=False)
|
2020-02-15 18:32:00 +00:00
|
|
|
|
|
|
|
def testOneHotNonArrayInput(self):
  """one_hot accepts a plain Python list of labels."""
  encoded = nn.one_hot([0, 1, 2], 3)
  self.assertAllClose(encoded, jnp.eye(3), check_dtypes=False)
|
2020-02-15 18:32:00 +00:00
|
|
|
|
|
|
|
def testOneHotCustomDtype(self):
  """one_hot honors an explicit boolean output dtype."""
  identity_mask = jnp.array([[True, False, False],
                             [False, True, False],
                             [False, False, True]])
  encoded = nn.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
  self.assertAllClose(encoded, identity_mask)
|
2020-02-15 18:32:00 +00:00
|
|
|
|
2020-07-03 20:54:25 -07:00
|
|
|
def testOneHotConcretizationError(self):
  """A traced num_classes raises a ConcretizationTypeError naming the argument.

  https://github.com/jax-ml/jax/issues/3654
  """
  jitted_one_hot = jax.jit(nn.one_hot)
  expected_msg = r"in jax.nn.one_hot argument `num_classes`"
  with self.assertRaisesRegex(core.ConcretizationTypeError, expected_msg):
    jitted_one_hot(3, 5)
|
|
|
|
|
2021-02-02 14:06:42 +00:00
|
|
|
def testOneHotAxis(self):
  """axis=0 and axis=-2 both place the class dimension first for 1-D labels."""
  expected = jnp.array([[0., 1., 0.],
                        [0., 0., 1.],
                        [1., 0., 0.]]).T
  labels = jnp.array([1, 2, 0])
  for axis in (0, -2):
    self.assertAllClose(nn.one_hot(labels, 3, axis=axis), expected,
                        check_dtypes=False)
|
2020-07-03 20:54:25 -07:00
|
|
|
|
2021-04-29 08:25:26 -07:00
|
|
|
def testTanhExists(self):
  """jax.nn must expose a `tanh` attribute."""
  _ = nn.tanh  # the lookup itself is the test; it must not raise
|
|
|
|
|
2021-12-14 12:57:52 -08:00
|
|
|
def testCustomJVPLeak(self):
  """A jitted scan over jax.nn.sigmoid runs cleanly under jax.checking_leaks().

  Regression test for https://github.com/jax-ml/jax/issues/8171.
  """
  @jax.jit
  def fwd():
    offset = jnp.array(1.)

    def step(carry, _):
      return jax.nn.sigmoid(carry + offset), None

    jax.lax.scan(step, jnp.array(0.), None, length=2)

  with jax.checking_leaks():
    fwd()  # doesn't crash
|
|
|
|
|
2022-09-23 11:24:13 -07:00
|
|
|
def testCustomJVPLeak2(self):
  """Same leak check as testCustomJVPLeak, with an explicit custom_jvp sigmoid.

  https://github.com/jax-ml/jax/issues/8171. The original report used
  jax.nn.sigmoid, which no longer has a custom_jvp, so the old definition is
  inlined here.
  """
  @jax.custom_jvp
  def sigmoid(x):
    one = jnp.float32(1)
    denom = jax.lax.add(one, jax.lax.exp(jax.lax.neg(x)))
    return jax.lax.div(one, denom)

  sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))

  @jax.jit
  def fwd():
    offset = jnp.array(1., 'float32')

    def step(carry, _):
      return sigmoid(carry + offset), None

    jax.lax.scan(step, jnp.array(0., 'float32'), None, length=2)

  with jax.checking_leaks():
    fwd()  # doesn't crash
|
|
|
|
|
|
|
|
|
2019-10-03 12:01:21 -07:00
|
|
|
InitializerRecord = collections.namedtuple(
|
|
|
|
"InitializerRecord",
|
2021-07-31 19:26:53 +02:00
|
|
|
["name", "initializer", "shapes", "dtypes"])
|
2019-10-03 12:01:21 -07:00
|
|
|
|
|
|
|
ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)]
|
|
|
|
|
2021-07-31 19:26:53 +02:00
|
|
|
def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
|
2019-10-03 12:01:21 -07:00
|
|
|
shapes = [shape for shape in ALL_SHAPES
|
|
|
|
if min_dims <= len(shape) <= max_dims]
|
2021-07-31 19:26:53 +02:00
|
|
|
return InitializerRecord(name, initializer, shapes, dtypes)
|
2019-10-03 12:01:21 -07:00
|
|
|
|
|
|
|
# Every initializer exercised by NNInitializersTest, as
# (name, factory, dtypes, optional (min_dims[, max_dims]) rank bounds).
INITIALIZER_RECS = [
    initializer_record(name, factory, dtypes, *rank_bounds)
    for name, factory, dtypes, rank_bounds in (
        ("uniform", nn.initializers.uniform, jtu.dtypes.floating, (1,)),
        ("normal", nn.initializers.normal, jtu.dtypes.inexact, (1,)),
        ("he_normal", nn.initializers.he_normal, jtu.dtypes.inexact, ()),
        ("he_uniform", nn.initializers.he_uniform, jtu.dtypes.inexact, ()),
        ("glorot_normal", nn.initializers.glorot_normal, jtu.dtypes.inexact, ()),
        ("glorot_uniform", nn.initializers.glorot_uniform, jtu.dtypes.inexact, ()),
        ("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact, ()),
        ("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact, ()),
        ("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, (2, 2)),
        ("truncated_normal", nn.initializers.truncated_normal, jtu.dtypes.floating, (1,)),
        ("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, (4, 4)),
    )
]
|
|
|
|
|
2022-02-14 09:22:05 -08:00
|
|
|
|
2023-08-25 14:11:19 -07:00
|
|
|
@jtu.with_config(jax_legacy_prng_key="allow")
class NNInitializersTest(jtu.JaxTestCase):
  """Shape, dtype and configuration tests for jax.nn.initializers."""

  def _assertShapeAndDtype(self, val, shape, dtype):
    # The output must have the requested shape and the canonicalized dtype.
    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
          [{"initializer": rec.initializer()}],
          shape=rec.shapes,
          dtype=rec.dtypes,
      )
      for rec in INITIALIZER_RECS
  ))
  def testInitializer(self, initializer, shape, dtype):
    """Default-constructed initializers honor (key, shape, dtype) arguments."""
    val = initializer(random.PRNGKey(0), shape, dtype)
    self._assertShapeAndDtype(val, shape, dtype)

  @parameterized.parameters(itertools.chain.from_iterable(
      jtu.sample_product_testcases(
          [{"initializer_provider": rec.initializer}],
          shape=rec.shapes,
          dtype=rec.dtypes,
      )
      for rec in INITIALIZER_RECS
  ))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    """A dtype passed to the factory is used when the initializer is called."""
    initializer = initializer_provider(dtype=dtype)
    val = initializer(random.PRNGKey(0), shape)
    self._assertShapeAndDtype(val, shape, dtype)

  def testVarianceScalingMultiAxis(self):
    """variance_scaling accepts tuples for in_axis and out_axis."""
    shape = (2, 3, 4, 5)
    init = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=(0, 1), out_axis=(-2, -1))
    self.assertEqual(shape, jnp.shape(init(random.PRNGKey(0), shape)))

  def testVarianceScalingBatchAxis(self):
    """variance_scaling accepts an explicit batch_axis."""
    shape = (2, 3, 4, 5)
    init = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal',
        in_axis=0, out_axis=(2, 3), batch_axis=1)
    self.assertEqual(shape, jnp.shape(init(random.PRNGKey(0), shape)))

  def testVarianceScalingError(self):
    """variance_scaling rejects a 1-D shape with a descriptive ValueError."""
    key = random.PRNGKey(0)
    init = nn.initializers.variance_scaling(
        scale=1.0, mode='fan_avg', distribution='truncated_normal')
    expected_msg = ("Can't compute input and output sizes of a 1"
                    "-dimensional weights tensor. Must be at least 2D.")
    with self.assertRaisesRegex(ValueError, expected_msg):
      init(key, (5,))

  def testAccidentalUpcasting(self):
    """A float32 scalar hyper-parameter must not upcast bfloat16 outputs."""
    key = random.PRNGKey(0)
    shape = (4, 4)
    scalar_param = jnp.array(1.0, dtype=jnp.float32)
    init_fns = [
        nn.initializers.uniform(scalar_param, jnp.bfloat16),
        nn.initializers.normal(scalar_param, jnp.bfloat16),
        nn.initializers.truncated_normal(scalar_param, jnp.bfloat16),
    ]
    for init_fn in init_fns:
      sub_key, key = random.split(key)
      self.assertEqual(init_fn(sub_key, shape).dtype, jnp.bfloat16)
|
2019-10-21 11:48:58 +00:00
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Run this module's tests with JAX's custom test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|