# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared neural network activations and other functions."""

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import operator
import math
import numpy as np
from typing import Any, List, Literal
import warnings

import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
    dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
    scaled_matmul_wrapper as cudnn_scaled_matmul,
    scaled_dot_general_wrapper as cudnn_scaled_dot_general,
    BlockScaleConfig)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.ops.special import logsumexp as _logsumexp


# activations
@jax.jit
def identity(x: ArrayLike) -> Array:
  r"""Identity activation function.

  Returns the argument unmodified.

  Args:
    x : input array

  Returns:
    The argument `x` unmodified.

  Examples:
    >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
    Array([-2. , -1. , -0.5,  0. ,  0.5,  1. ,  2. ], dtype=float32)

  """
  numpy_util.check_arraylike("identity", x)
  return jnp.asarray(x)

@custom_jvp
@jax.jit
def relu(x: ArrayLike) -> Array:
  r"""Rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  except under differentiation, we take:

  .. math::
    \nabla \mathrm{relu}(0) = 0

  For more information see
  `Numerical influence of ReLU’(0) on backpropagation
  <https://dl.acm.org/doi/10.5555/3540261.3540297>`_.

  Args:
    x : input array

  Returns:
    An array.

  Examples:
    >>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
    Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

  See also:
    :func:`relu6`

  """
  return jnp.maximum(x, 0)
# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

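# Editorial note (illustrative, not part of the upstream module): the custom
# JVP registered above pins the derivative of relu at exactly zero to 0, so
# differentiating through the kink is well defined:
#
#   >>> float(jax.grad(jax.nn.relu)(0.0))
#   0.0
#   >>> float(jax.grad(jax.nn.relu)(2.0))
#   1.0
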
@jax.jit
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
  r"""Squareplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}

  as described in https://arxiv.org/abs/2112.11687.

  Args:
    x : input array
    b : smoothness parameter
  """
  numpy_util.check_arraylike("squareplus", x)
  numpy_util.check_arraylike("squareplus", b)
  x = jnp.asarray(x)
  b = jnp.asarray(b)
  y = x + jnp.sqrt(jnp.square(x) + b)
  return y / 2

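# Editorial note (illustrative): squareplus is a smooth upper bound on relu
# that equals sqrt(b) / 2 at x = 0 and approaches max(x, 0) for large |x|.
# With the default b = 4:
#
#   >>> float(jax.nn.squareplus(jnp.float32(0.0)))  # (0 + sqrt(0 + 4)) / 2
#   1.0
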
@jax.jit
def softplus(x: ArrayLike) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x : input array
  """
  return jnp.logaddexp(x, 0)

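# Editorial note (illustrative): implementing softplus via logaddexp keeps it
# overflow-safe, since log(1 + e^x) ~ x for large x, whereas the naive form
# overflows through the intermediate exponential:
#
#   >>> float(jax.nn.softplus(jnp.float32(1000.0)))
#   1000.0
#   >>> float(jnp.log(1.0 + jnp.exp(jnp.float32(1000.0))))  # naive form
#   inf
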
@jax.jit
def sparse_plus(x: ArrayLike) -> Array:
  r"""Sparse plus function.

  Computes the function:

  .. math::

    \mathrm{sparse\_plus}(x) = \begin{cases}
      0, & x \leq -1\\
      \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
      x, & 1 \leq x
    \end{cases}

  This is the twin function of the softplus activation ensuring a zero output
  for inputs less than -1 and a linear output for inputs greater than 1,
  while remaining smooth, convex, and monotonic by an adequate definition
  between -1 and 1.

  Args:
    x: input (float)
  """
  numpy_util.check_arraylike("sparse_plus", x)
  x = jnp.asarray(x)
  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))

@jax.jit
def soft_sign(x: ArrayLike) -> Array:
  r"""Soft-sign activation function.

  Computes the element-wise function

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x : input array
  """
  numpy_util.check_arraylike("soft_sign", x)
  x_arr = jnp.asarray(x)
  return x_arr / (jnp.abs(x_arr) + 1)

@partial(jax.jit, inline=True)
def sigmoid(x: ArrayLike) -> Array:
  r"""Sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`log_sigmoid`

  """
  return lax.logistic(x)

@jax.jit
def sparse_sigmoid(x: ArrayLike) -> Array:
  r"""Sparse sigmoid activation function.

  Computes the function:

  .. math::

    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
      0, & x \leq -1\\
      \frac{1}{2}(x+1), & -1 < x < 1 \\
      1, & 1 \leq x
    \end{cases}

  This is the twin function of the ``sigmoid`` activation ensuring a zero output
  for inputs less than -1, a 1 output for inputs greater than 1, and a linear
  output for inputs between -1 and 1. It is the derivative of ``sparse_plus``.

  For more information, see `Learning with Fenchel-Young Losses (section 6.2)
  <https://arxiv.org/abs/1901.02324>`_.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  return 0.5 * jnp.clip(x + 1.0, 0.0, 2.0)

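# Editorial note (illustrative): sparse_sigmoid is the derivative of
# sparse_plus, so the two agree under autodiff on the quadratic region:
#
#   >>> float(jax.grad(jax.nn.sparse_plus)(0.5))
#   0.75
#   >>> float(jax.nn.sparse_sigmoid(0.5))
#   0.75
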
@jax.jit
def silu(x: ArrayLike) -> Array:
  r"""SiLU (aka swish) activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  :func:`swish` and :func:`silu` are both aliases for the same function.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  numpy_util.check_arraylike("silu", x)
  x_arr = jnp.asarray(x)
  return x_arr * sigmoid(x_arr)

swish = silu

@jax.jit
def mish(x: ArrayLike) -> Array:
  r"""Mish activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{mish}(x) = x \cdot \mathrm{tanh}(\mathrm{softplus}(x))

  For more information, see
  `Mish: A Self Regularized Non-Monotonic Activation Function
  <https://arxiv.org/abs/1908.08681>`_.

  Args:
    x : input array

  Returns:
    An array.
  """
  numpy_util.check_arraylike("mish", x)
  x_arr = jnp.asarray(x)
  return x_arr * jnp.tanh(softplus(x_arr))

@jax.jit
def log_sigmoid(x: ArrayLike) -> Array:
  r"""Log-sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  numpy_util.check_arraylike("log_sigmoid", x)
  x_arr = jnp.asarray(x)
  return -softplus(-x_arr)

@jax.jit
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
  r"""Exponential linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x : input array
    alpha : scalar or array of alpha values (default: 1.0)

  Returns:
    An array.

  See also:
    :func:`selu`
  """
  numpy_util.check_arraylike("elu", x)
  x_arr = jnp.asarray(x)
  return jnp.where(x_arr > 0,
                   x_arr,
                   alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))

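# Editorial note (illustrative): the nested jnp.where above is the usual
# "double where" trick: expm1 only ever sees the clamped negative branch, so
# its derivative cannot overflow for large positive inputs even though both
# branches of the outer where are evaluated:
#
#   >>> float(jax.grad(jax.nn.elu)(100.0))  # no inf/nan leaking from exp(100.)
#   1.0
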
@jax.jit
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  where :math:`\alpha` = :code:`negative_slope`.

  Args:
    x : input array
    negative_slope : array or scalar specifying the negative slope (default: 0.01)

  Returns:
    An array.

  See also:
    :func:`relu`
  """
  numpy_util.check_arraylike("leaky_relu", x)
  x_arr = jnp.asarray(x)
  return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)

@jax.jit
def hard_tanh(x: ArrayLike) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array

  Returns:
    An array.
  """
  numpy_util.check_arraylike("hard_tanh", x)
  x_arr = jnp.asarray(x)
  return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))

@jax.jit
def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  For more information, see
  `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/abs/1704.07483>`_.

  Args:
    x : input array
    alpha : array or scalar (default: 1.0)

  Returns:
    An array.
  """
  return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)

@jax.jit
def selu(x: ArrayLike) -> Array:
  r"""Scaled exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  For more information, see
  `Self-Normalizing Neural Networks
  <https://arxiv.org/abs/1706.02515>`_.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`elu`
  """
  alpha = 1.6732632423543772848170429916717
  scale = 1.0507009873554804934193349852946
  return scale * elu(x, alpha)

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
def gelu(x: ArrayLike, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left(
      \frac{-x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x: input array
    approximate: whether to use the approximate or exact formulation.
  """
  [x_arr] = numpy_util.promote_args_inexact("gelu", x)

  if approximate:
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))
    return x_arr * cdf
  else:
    sqrt_half = np.sqrt(0.5).astype(x_arr.dtype)
    return jnp.array(
        0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype
    )

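# Editorial note (illustrative): the tanh-based approximation tracks the exact
# erfc-based formulation to within a few thousandths over typical activation
# ranges:
#
#   >>> x = jnp.linspace(-3.0, 3.0, 7)
#   >>> bool(jnp.allclose(jax.nn.gelu(x, approximate=True),
#   ...                   jax.nn.gelu(x, approximate=False), atol=1e-2))
#   True
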
@partial(jax.jit, static_argnames=("axis",))
def glu(x: ArrayLike, axis: int = -1) -> Array:
  r"""Gated linear unit activation function.

  Computes the function:

  .. math::
    \mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
      \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
        \right)

  where the array is split into two along ``axis``. The size of the ``axis``
  dimension must be divisible by two.

  Args:
    x : input array
    axis: the axis along which the split should be computed (default: -1)

  Returns:
    An array.

  See also:
    :func:`sigmoid`
  """
  numpy_util.check_arraylike("glu", x)
  x_arr = jnp.asarray(x)
  size = x_arr.shape[axis]
  assert size % 2 == 0, "axis size must be divisible by 2"
  x1, x2 = jnp.split(x_arr, 2, axis)
  return x1 * sigmoid(x2)

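# Editorial note (illustrative): glu halves the chosen axis; the first half is
# gated by a sigmoid of the second half, so the output shape shrinks along
# ``axis``:
#
#   >>> jax.nn.glu(jnp.ones((2, 6)), axis=-1).shape
#   (2, 3)
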
# other functions

logsumexp = _logsumexp


@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                where: ArrayLike | None = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.
    where: Elements to include in the :code:`log_softmax`.

  Returns:
    An array.

  Note:
    If any input values are ``+inf``, the result will be all ``NaN``: this reflects the
    fact that ``inf / inf`` is not well-defined in the context of floating-point math.

  See also:
    :func:`softmax`
  """
  numpy_util.check_arraylike("log_softmax", x)
  x_arr = jnp.asarray(x)
  x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True)
  x_safe = x_arr if where is None else jnp.where(where, x_arr, -jnp.inf)
  shifted = x_safe - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  result = shifted - shifted_logsumexp
  if where is not None:
    return jnp.where(where, result, -jnp.inf)
  return result


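# Editorial note (illustrative): with ``where``, excluded entries are ignored
# by the reduction and reported as -inf, while the included entries still form
# a proper (log-)distribution:
#
#   >>> out = jax.nn.log_softmax(jnp.array([1.0, 2.0, 3.0]),
#   ...                          where=jnp.array([True, True, False]))
#   >>> bool(jnp.isneginf(out[-1]))
#   True
#   >>> bool(jnp.isclose(jnp.exp(out[:2]).sum(), 1.0))
#   True
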
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("axis",))
def softmax(x: ArrayLike,
            axis: int | tuple[int, ...] | None = -1,
            where: ArrayLike | None = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    where: Elements to include in the :code:`softmax`.

  Returns:
    An array.

  Note:
    If any input values are ``+inf``, the result will be all ``NaN``: this reflects the
    fact that ``inf / inf`` is not well-defined in the context of floating-point math.

  See also:
    :func:`log_softmax`
  """
  if config.softmax_custom_jvp.value:
    # mypy is confused by the `functools.partial` application in the definition
    # of `_softmax` and incorrectly concludes that `_softmax` returns
    # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
    return _softmax(x, axis, where)
  else:
    return _softmax_deprecated(x, axis, where)

# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
    x: ArrayLike,
    axis: int | tuple[int, ...] | None = -1,
    where: ArrayLike | None = None,
    initial: ArrayLike | None = -jnp.inf) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  x_safe = x if where is None else jnp.where(where, x, initial)
  unnormalized = jnp.exp(x_safe - x_max)
  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
  if where is not None:
    result = jnp.where(where, result, 0)
  return result

@_softmax.defjvp
def _softmax_jvp(axis, primals, tangents):
  (x, where, initial), (x_dot, _, _) = primals, tangents
  y = _softmax(x, axis, where, initial)
  return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))

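# Editorial note on the custom JVP above: with y = softmax(x), the directional
# derivative along x_dot is
#
#   dy_i = y_i * (x_dot_i - sum_j y_j * x_dot_j)
#
# which is exactly ``y * (x_dot - (y * x_dot).sum(...))``. Because the rule is
# expressed in terms of y, the max-subtraction inside `_softmax` needs no
# `stop_gradient`; the deprecated path below has no custom rule and keeps the
# `stop_gradient` so that autodiff of the shift stays numerically benign.
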
def _softmax_deprecated(
    x: ArrayLike,
    axis: int | tuple[int, ...] | None = -1,
    where: ArrayLike | None = None,
    initial: ArrayLike | None = -jnp.inf) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  x_safe = x if where is None else jnp.where(where, x, initial)
  unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
  if where is not None:
    result = jnp.where(where, result, 0)
  return result


@partial(jax.jit, static_argnames=("axis",))
def standardize(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                mean: ArrayLike | None = None,
                variance: ArrayLike | None = None,
                epsilon: ArrayLike = 1e-5,
                where: ArrayLike | None = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
  numpy_util.check_arraylike("standardize", x)
  numpy_util.check_arraylike_or_none("standardize", mean, variance, where)
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True, where=where)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(
        jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
  return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)

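# Editorial usage sketch (illustrative): standardizing along the last axis
# gives approximately zero mean and unit variance per row:
#
#   >>> out = jax.nn.standardize(jnp.array([[1.0, 2.0, 3.0, 4.0]]), axis=-1)
#   >>> bool(jnp.allclose(out.mean(axis=-1), 0.0, atol=1e-6))
#   True
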
# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
             dtype: Any, axis: int | AxisName) -> Array:
  num_classes = core.concrete_dim_or_error(
      num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    axis_size = lax.axis_size(axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)  # type: ignore[arg-type]
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
  rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape)))  # pytype: disable=attribute-error
  rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
                             out_sharding=rhs_sharding)
  return (lhs == rhs).astype(dtype)

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
def one_hot(x: Any, num_classes: int, *,
            dtype: Any = jnp.float_, axis: int | AxisName = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    Array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    Array([[0., 0., 0.],
           [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default :obj:`jnp.float_`).
    axis: the axis or axes along which the function should be
      computed.
  """
  num_classes = core.concrete_dim_or_error(
      num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  x_arr = jnp.asarray(x)
  if not jnp.isdtype(x_arr.dtype, "integral"):
    # Deprecated 2024-12-18
    deprecations.warn(
      'jax-nn-one-hot-float-input',
      f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
      stacklevel=1)
  return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)


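# Editorial usage sketch (illustrative): ``axis`` controls where the new class
# dimension is inserted, and ``dtype`` the element type of the result:
#
#   >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 4).shape
#   (3, 4)
#   >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 4, axis=0).shape
#   (4, 3)
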
@jax.custom_jvp
@jax.jit
def relu6(x: ArrayLike) -> Array:
  r"""Rectified Linear Unit 6 activation function.

  Computes the element-wise function

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  except under differentiation, we take:

  .. math::
    \nabla \mathrm{relu6}(0) = 0

  and

  .. math::
    \nabla \mathrm{relu6}(6) = 0

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`relu`
  """
  return jnp.minimum(jnp.maximum(x, 0), 6.)
relu6.defjvps(lambda g, ans, x:
              lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))

@jax.jit
def hard_sigmoid(x: ArrayLike) -> Array:
  r"""Hard Sigmoid activation function.

  Computes the element-wise function

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`relu6`
  """
  return relu6(x + 3.) / 6.

@jax.jit
def hard_silu(x: ArrayLike) -> Array:
  r"""Hard SiLU (swish) activation function.

  Computes the element-wise function

  .. math::
    \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

  Both :func:`hard_silu` and :func:`hard_swish` are aliases for the same
  function.

  Args:
    x : input array

  Returns:
    An array.

  See also:
    :func:`hard_sigmoid`
  """
  numpy_util.check_arraylike("hard_silu", x)
  x_arr = jnp.asarray(x)
  return x_arr * hard_sigmoid(x_arr)

hard_swish = hard_silu

def _get_large_negative(dtype):
  dtype_max = jnp.finfo(dtype).max
  return jnp.asarray(-0.7 * dtype_max, dtype=dtype)

def _get_causal_mask(T, S):
  mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_))
  return mask[None, None, :, :]

def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]):
  query_pos = jnp.array(range(T))
  key_pos = jnp.array(range(S))
  left_window, right_window = local_window_size
  left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window
  right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window
  return jnp.logical_and(right_mask, left_mask)[None, None, :, :]

def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen):
  q_mask = True
  kv_mask = True
  if q_seqlen is not None:
    q_indices = jnp.arange(0, T)[None, :, None]
    q_mask = q_indices < q_seqlen[:, None, None]
  if kv_seqlen is not None:
    kv_indices = jnp.arange(0, S)[None, None, :]
    kv_mask = kv_indices < kv_seqlen[:, None, None]
  mask = jnp.logical_and(q_mask, kv_mask)
  return mask[:, None, :, :]

def _get_padding_mask_encoded(T, q_seqlen):
  q_indices = jnp.arange(0, T)[None, :]
  mask = q_indices < q_seqlen[:, None]
  return mask[:, :, None, None]

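# Editorial note (illustrative): the helpers above produce boolean masks that
# broadcast against logits of shape (B, N, T, S) and are intersected in
# _apply_masks below. For T = S = 4, local_window_size = (1, 1) lets query
# position 2 attend to keys {1, 2, 3}; intersecting with the causal mask
# further drops key 3:
#
#   >>> causal = _get_causal_mask(4, 4)          # shape (1, 1, 4, 4)
#   >>> window = _get_window_mask(4, 4, (1, 1))  # shape (1, 1, 4, 4)
#   >>> combined = jnp.logical_and(causal, window)
#   >>> [int(i) for i in jnp.nonzero(combined[0, 0, 2])[0]]
#   [1, 2]
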
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
                 local_window_size):
  if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
    return logits

  combined_mask = jnp.ones_like(logits, dtype=jnp.bool_)
  if mask is not None:
    assert mask.dtype == jnp.bool_
    combined_mask = jnp.logical_and(combined_mask, mask)

  T, S = logits.shape[2], logits.shape[3]

  if is_causal:
    mask = _get_causal_mask(T, S)
    combined_mask = jnp.logical_and(combined_mask, mask)

  if local_window_size is not None:
    mask = _get_window_mask(T, S, local_window_size)
    combined_mask = jnp.logical_and(combined_mask, mask)

  if q_seqlen is not None or kv_seqlen is not None:
    mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen)
    combined_mask = jnp.logical_and(combined_mask, mask)

  large_negative_number = _get_large_negative(logits.dtype)
  padded_logits = jnp.where(combined_mask, logits, large_negative_number)
  return padded_logits

def _dot_product_attention_core(query, key, value, bias, mask, is_causal,
                                scale, q_seqlen, kv_seqlen, local_window_size):
  logits_dtype = jnp.promote_types(query.dtype, jnp.float32)

  # If the query and logits dtypes are different, then the default precision
  # can use inconsistent types in the backwards pass
  # (see https://github.com/jax-ml/jax/issues/24047).
  if query.dtype == jnp.bfloat16:
    precision = jax.lax.DotAlgorithmPreset.BF16_BF16_F32
  elif query.dtype == jnp.float16:
    precision = jax.lax.DotAlgorithmPreset.F16_F16_F32
  # TODO(sbodenstein): Implement this fix for all dtypes.
  else:
    precision = None

  # Explicit precision will fail on platforms that don't support it. For example,
  # some GPUs do not support BF16_BF16_F32, and TPU does not support F16_F16_F32.
  # Use the default precision as a fallback in these cases.
  try:
    logits = jnp.einsum(
        "BTNH,BSNH->BNTS",
        query,
        key,
        precision=precision,
        preferred_element_type=logits_dtype,
    )
  except:  # pylint: disable=bare-except
    logits = jnp.einsum(
        "BTNH,BSNH->BNTS",
        query,
        key,
        precision=None,
        preferred_element_type=logits_dtype,
    )

  logits *= jnp.array(scale, dtype=logits.dtype)

  if bias is not None:
    logits = (logits + bias).astype(logits.dtype)

  padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
                               local_window_size)

  # Softmax is always carried out in fp32.
  padded_logits = padded_logits.astype(jnp.float32)
  probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)

  encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
  if q_seqlen is not None and kv_seqlen is not None:
    mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen)
    encoded *= mask.astype(encoded.dtype)
  return encoded

def _dot_product_attention_xla(
    query: Array,
    key: Array,
    value: Array,
    bias: Array | None,
    mask: Array | None,
    is_causal: bool,
    scale: float,
    q_seqlen: Array | None,
    kv_seqlen: Array | None,
    local_window_size: tuple[int, int] | None):

  B, T, N, H = query.shape
  _, S, K, _ = key.shape
  G = N // K

  query = jnp.reshape(query, (B, T, K, G, H))
  def _reshape_to_grouped(t):
    if t is not None:
      tB, tN, tT, tS = t.shape
      if tN == 1:
        t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
      else:
        assert tN == N
        t = jnp.reshape(t, (tB, K, G, tT, tS))
    return t
  bias = _reshape_to_grouped(bias)
  mask = _reshape_to_grouped(mask)
  vmapped_fn = jax.vmap(
      _dot_product_attention_core,
      in_axes=(3, None, None, 2, 2, None, None, None, None, None),
      out_axes=3,
  )
  encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale,
                       q_seqlen, kv_seqlen, local_window_size)
  encoded = jnp.reshape(encoded, (B, T, N, H))
  return encoded

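# Editorial note: the reshape in _dot_product_attention_xla implements grouped
# query attention by splitting the N query heads into K groups of G = N // K
# and vmapping the core attention over the group axis, so each group shares a
# single key/value head. For example (shapes assumed for illustration only), a
# query of shape (B, T, 8, H) with K = 2 key/value heads becomes
# (B, T, 2, 4, H), while key and value stay (B, S, 2, H).
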
def bias_fwd_rule(a, query_head_num):
  return bias_fwd_p.bind(a, query_head_num), a
def bias_bwd_rule(query_head_num, res, g):
  a = res
  if a.shape[0] > 1 or a.shape[-3] != query_head_num:
    raise ValueError("cuDNN only supports bias gradient when the batch size is "
                     f"1 and the head number matches the query, but got "
                     f"B={a.shape[0]}, N={a.shape[-3]}.")
  return (bias_bwd_p.bind(g, a, query_head_num),)

# This function uses two custom primitives, `bias_fwd` and `bias_bwd`, to work
# around a cuDNN issue where bias gradients are only supported when the batch
# size is 1 and the number of heads matches the query.
# TODO(kaixih@nvidia): Remove this workaround once cuDNN resolves the issue.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def check_valid_bias_batch(x, query_head_num):
  output, _ = bias_fwd_rule(x, query_head_num)
  return output
check_valid_bias_batch.defvjp(bias_fwd_rule, bias_bwd_rule)

bias_fwd_p = core.Primitive('bias_fwd')
bias_fwd_p.multiple_results = False
bias_bwd_p = core.Primitive('bias_bwd')
bias_bwd_p.multiple_results = False

def bias_fwd_impl(a, query_head_num):
  return a
def bias_bwd_impl(g, a, query_head_num):
  return g
bias_fwd_p.def_impl(bias_fwd_impl)
bias_bwd_p.def_impl(bias_bwd_impl)

def bias_fwd_abstract_eval(a, query_head_num):
  return core.ShapedArray(a.shape, a.dtype)
def bias_bwd_abstract_eval(g, a, query_head_num):
  return core.ShapedArray(g.shape, g.dtype)
bias_fwd_p.def_abstract_eval(bias_fwd_abstract_eval)
bias_bwd_p.def_abstract_eval(bias_bwd_abstract_eval)

def bias_fwd_lowering(ctx, a, query_head_num):
  return [a]
def bias_bwd_lowering(ctx, g, a, query_head_num):
  return [g]
mlir.register_lowering(bias_fwd_p, bias_fwd_lowering)
mlir.register_lowering(bias_bwd_p, bias_bwd_lowering)

def bias_fwd_batch_rule(batched_args, batch_dims):
  x, query_head_num = batched_args
  a = batch_dims[0]
  output, _ = bias_fwd_rule(x, query_head_num)
  return output, a
def bias_bwd_batch_rule(batched_args, batch_dims):
  g, x, query_head_num = batched_args
  b = batch_dims[0]
  *Bs, _, _, _ = x.shape
  B = math.prod(Bs)
  x = jnp.reshape(x, (B,) + x.shape[-3:])
  output, = bias_bwd_rule(query_head_num, x, g)
  return output, b
batching.primitive_batchers[bias_fwd_p] = bias_fwd_batch_rule
batching.primitive_batchers[bias_bwd_p] = bias_bwd_batch_rule

def dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike | None = None,
    mask: ArrayLike | None = None,
    *,
    scale: float | None = None,
    is_causal: bool = False,
    query_seq_lengths: ArrayLike | None = None,
    key_value_seq_lengths: ArrayLike | None = None,
    local_window_size: int | tuple[int, int] | None = None,
    implementation: Literal['xla', 'cudnn'] | None = None) -> Array:
  r"""Scaled dot product attention function.

  Computes the attention function on Query, Key, and Value tensors:

  .. math::

    \mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V

  Below, :code:`logits` refers to the output of :math:`QK^T` and
  :code:`probs` refers to the output of :math:`\mathrm{softmax}`.

  Throughout this function, we utilize the following uppercase letters to
  represent array shapes::

    B = batch size
    S = length of the key/value (source)
    T = length of the query (target)
    N = number of attention heads
    H = dimensions of each attention head
    K = number of key/value heads
    G = number of groups, which equals N // K

  Args:
    query: query array; shape :code:`(BTNH|TNH)`
    key: key array; shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed
      attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise,
      grouped query attention (GQA https://arxiv.org/abs/2305.13245) is
      performed if `N` is a multiple of `K`, and multi-query attention (MQA
      https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case
      of GQA).
    value: value array, should have the same shape as the `key` array.
    bias: optional, bias array to be added to logits; The shape must be 4D and
      be broadcastable to :code:`(BNTS|NTS)`.
    mask: optional, mask array used to filter out logits. It is a boolean mask
      where `True` indicates the element should take part in attention. For an
      additive mask, users should pass it to `bias`. The shape must be 4D and be
      broadcastable to :code:`(BNTS|NTS)`.
    scale: scale for the logits. If None, the scale will be set to 1 divided by
      the square root of query's head dimension (i.e. H).
    is_causal: If true, causal attention will be applied. Note, some
      implementations like `xla` will generate a mask tensor and apply it to the
      logits to mask out the non-causal parts of the attention matrix, but other
      implementations like `cudnn` will avoid computing the non-causal regions,
      providing speedups.
    query_seq_lengths: `int32` array of sequence lengths for query; shape
      :code:`(B)`
    key_value_seq_lengths: `int32` array of sequence lengths for key and value;
      shape :code:`(B)`
    local_window_size: Window sizes that restrict self-attention to each
      token's local window. If set, this specifies the (left_window_size,
      right_window_size) for each token. E.g., if local_window_size == (3, 2)
      and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend
      to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as
      a symmetric window (window_size, window_size).
    implementation: A string to control which implementation backend to use.
      Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults
      to `None`, which will automatically select the best available backend.
      Note, `cudnn` supports only a subset of shapes/dtypes, and an exception
      will be thrown if it's not supported.

  Returns:
    An array of the attention output with the same shape as :code:`query`.
  """
  output_shape = jnp.asarray(query).shape
  def _ensure_4d(t):
    t = jnp.asarray(t)
    dims_to_add = 4 - t.ndim
    if dims_to_add > 0:
      return jnp.expand_dims(t, axis=tuple(range(dims_to_add)))
    return t

  query_arr = _ensure_4d(query)
  key_arr = _ensure_4d(key)
  value_arr = _ensure_4d(value)
  bias = _ensure_4d(bias) if bias is not None else None
  mask = _ensure_4d(mask) if mask is not None else None
  if query_seq_lengths is not None:
    query_seq_lengths = jnp.asarray(query_seq_lengths)
  if key_value_seq_lengths is not None:
    key_value_seq_lengths = jnp.asarray(key_value_seq_lengths)
  if isinstance(local_window_size, int):
    local_window_size = (local_window_size, local_window_size)

  def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
                             dtype: DType | None, name: str) -> None:
    if t is None:
      return
    if t.ndim != len(shape):
      raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}")
    if dtype is not None and t.dtype != dtype:
      raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}")
    for i in range(t.ndim):
      if shape[i] != -1 and t.shape[i] != shape[i]:
        raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")

  B, S, K, H = key_arr.shape
  _check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value')
  _check_shape_and_dtype(query_arr, [B, -1, -1, H], key_arr.dtype, 'query')
  _check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask')
  _check_shape_and_dtype(bias, [-1] * 4, None, 'bias')
  _check_shape_and_dtype(query_seq_lengths, [B], jnp.int32,
                         'query_seq_lengths')
  _check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32,
                         'key_value_seq_lengths')
  if query_arr.shape[-2] % K != 0:
    raise ValueError(f"The number of query heads must be a multiple of "
                     f"key/value heads, but got {query_arr.shape[-2]} vs {K}")

  scale_val = (1.0 / np.sqrt(H)) if scale is None else scale

  match implementation:
    case 'xla':
      out = _dot_product_attention_xla(
          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
          scale=scale_val, q_seqlen=query_seq_lengths,
          kv_seqlen=key_value_seq_lengths,
          local_window_size=local_window_size,
      )
    case 'cudnn':
      if bias is not None:
        bias = check_valid_bias_batch(bias, query_arr.shape[-2])
        bias = jnp.asarray(bias)
      use_padding = (
          query_seq_lengths is not None or key_value_seq_lengths is not None
      )
      if use_padding:
        if query_seq_lengths is None:
          T = query_arr.shape[1]
          query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32)
        if key_value_seq_lengths is None:
          key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32)

      mask_type = MaskType.NO_MASK
      if use_padding and is_causal:
        mask_type = MaskType.PADDING_CAUSAL
      elif is_causal:
        mask_type = MaskType.CAUSAL
      elif use_padding:
        mask_type = MaskType.PADDING
      # CuDNN supports only the left window with an exclusive boundary when
      # causal mask is enabled.
      sliding_window = None
      if local_window_size is not None:
        l_window, r_window = local_window_size
        if r_window == 0 or mask_type == MaskType.CAUSAL:
          sliding_window = l_window + 1
        else:
          raise ValueError(f"cuDNN doesn't support right window: {r_window} "
                           "when causal mask is not used.")

      out = cudnn_dot_product_attention(
          query_arr, key_arr, value_arr, bias, mask, query_seq_lengths,
          key_value_seq_lengths, scale=scale_val, mask_type=mask_type,
          sliding_window_length=sliding_window,
      )
    case None:
      # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select
      # best backend.
      out = _dot_product_attention_xla(
          query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal,
          scale=scale_val, q_seqlen=query_seq_lengths,
          kv_seqlen=key_value_seq_lengths,
          local_window_size=local_window_size,
      )
    case _:
      raise ValueError(f"Unsupported implementation option: {implementation}")

  return jnp.reshape(out, output_shape)

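# Editorial usage sketch (illustrative values only): a minimal causal
# self-attention call through the public API; the output always has the same
# shape as the query:
#
#   >>> B, T, N, H = 2, 8, 4, 16
#   >>> q = jax.random.normal(jax.random.PRNGKey(0), (B, T, N, H))
#   >>> k = jax.random.normal(jax.random.PRNGKey(1), (B, T, N, H))
#   >>> v = jax.random.normal(jax.random.PRNGKey(2), (B, T, N, H))
#   >>> jax.nn.dot_product_attention(q, k, v, is_causal=True).shape
#   (2, 8, 4, 16)
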
def scaled_matmul(
    lhs: Array,
    rhs: Array,
    lhs_scales: Array,
    rhs_scales: Array,
    preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
  r"""Scaled matrix multiplication function.

  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
  The last dim is the contracting dim, and block size is inferred.

  Mathematically, this operation is equivalent to::

    a_block_size = a.shape[-1] // a_scales.shape[-1]
    b_block_size = b.shape[-1] // b_scales.shape[-1]
    a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
    b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
    jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

  Args:
    lhs (Array): Operand a, shape (B, M, K).
    rhs (Array): Operand b, shape (B, N, K).
    lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
    rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
    preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.

  Returns:
    Array of shape (B, M, N).

  Notes:
    - We currently do not support user-defined `precision` for customizing the
      compute data type. It is fixed to `jnp.float32`.
    - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
    - To use cuDNN with Nvidia Blackwell GPUs, inputs must match::

        # mxfp8
        a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
        a_scales, b_scales: jnp.float8_e8m0fnu
        block_size: 32
        # nvfp4
        a, b: jnp.float4_e2m1fn
        a_scales, b_scales: jnp.float8_e4m3fn
        block_size: 16

  Examples:

    Basic case:

    >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
    >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
    >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
    >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
    >>> scaled_matmul(a, b, a_scales, b_scales)  # doctest: +SKIP
    Array([[[8.]]], dtype=float32)

    Using fused cuDNN call on Blackwell GPUs:

    >>> dtype = jnp.float8_e4m3fn
    >>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype)
    >>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype)
    >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
    >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
    >>> scaled_matmul(a, b, a_scales, b_scales)  # doctest: +SKIP
  """
  a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales
  if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
    raise ValueError(
        "scaled_matmul requires all inputs to be 3-dimensional arrays"
    )

  B_a, M_a, K_a = a.shape
  B_b, N_b, K_b = b.shape
  if K_a != K_b or B_a != B_b:
    raise ValueError(
        "scaled_matmul requires inputs a and b to have matching batch (B) "
        f"and contract (K) dimensions, but got shapes {a.shape} and "
        f"{b.shape}"
    )

  B_as, M_as, K_as = a_scales.shape
  B_bs, N_bs, K_bs = b_scales.shape
  if K_as != K_bs or B_as != B_bs:
    raise ValueError(
        "scaled_matmul requires scales to have matching batch (B) and "
        f"contract (K) dimensions, but got shapes {a_scales.shape} and "
        f"{b_scales.shape}"
    )

  if M_as != M_a or N_bs != N_b:
    raise ValueError(
        "scaled_matmul requires scales to match non-contract dimensions of "
        f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: "
        f"{a_scales.shape}, b_scales: {b_scales.shape}"
    )

  preferred_element_type = dtypes.canonicalize_dtype(
      np.dtype(preferred_element_type)
  )
  out = cudnn_scaled_matmul(
      a,
      b,
      a_scales,
      b_scales,
      preferred_element_type=preferred_element_type,
  )
  return out

def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
                                  global_scale: Array | None = None):
  r"""Get quantization configs for scaled_dot_general.

  Create quantization configs for the `jax.nn.scaled_dot_general`.

  See Also:
    - :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
  """

  if mode == 'nvfp4':
    one = jnp.ones((1,), dtype=jnp.float32)
    return BlockScaleConfig(
        mode='nvfp4',
        block_size=16,
        data_type=jnp.float4_e2m1fn,
        scale_type=jnp.float8_e4m3fn,
        global_scale=one if global_scale is None else global_scale,
        infer_only=False
    )
  elif mode == 'mxfp8':
    return BlockScaleConfig(
        mode='mxfp8',
        block_size=32,
        data_type=jnp.float8_e4m3fn,
        scale_type=jnp.float8_e8m0fnu,
        global_scale=None,
        infer_only=False
    )
  else:
    raise ValueError(f"Unsupported mode: {mode}")

def scaled_dot_general(
    lhs, rhs,
    dimension_numbers,
    preferred_element_type=jnp.float32,
    configs: List[BlockScaleConfig] | None = None,
    implementation: Literal['cudnn'] | None = None,
    ):
  r"""Scaled dot general operation.

  Performs a generalized dot product with block-scaled quantization on the
  lhs and rhs inputs. This operation extends `lax.dot_general` to support
  user-defined scaling configurations.

  Essentially, the operation follows::

    a, a_scales = quantize(lhs, configs[0])
    b, b_scales = quantize(rhs, configs[1])
    c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)

  Args:
    lhs (ArrayLike): Input array.
    rhs (ArrayLike): Input array.
    dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
      the contraction and batch dimensions:
      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
    preferred_element_type (DTypeLike, optional): Output data type of the dot
      product. Defaults to `jnp.float32`. Other valid types include
      `jnp.bfloat16` and `jnp.float16`.
    configs (list of BlockScaleConfig, optional): Scaling configurations for
      lhs, rhs, and gradients. Users can obtain valid configurations via
      `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
      are supported. If `None`, falls back to `lax.dot_general`.
    implementation: str
      (Deprecated) Backend selector, now ignored. The system chooses the backend
      automatically. Scheduled for removal in future releases.

  Returns:
    Array: The resulting tensor, with batch dimensions first, followed by
    non-contracting/non-batch dimensions of lhs, and then those of rhs.

  See Also:
    - :func:`jax.nn.scaled_matmul`: Scaled matmul function.
    - :func:`jax.lax.dot_general`: General dot product operator.

  Notes:
    - Unlike `nn.scaled_matmul`, which assumes quantized low-precision
      inputs with explicit scaling factors, this operator takes high-precision
      inputs, applies quantization internally, and handles the backward pass.

  Examples:

    Creating config for mxfp8:

    >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3

    Creating config for nvfp4:

    >>> global_scale = jnp.array([0.5], jnp.float32)
    >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3

    Using scaled_dot_general with the configs:

    >>> import functools
    >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
    >>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64))
    >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64))
    >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))  # doctest: +SKIP
  """
  if implementation is not None:
    warnings.warn("Backend selector, now ignored. The system chooses the "
                  "backend automatically.", DeprecationWarning)

  if configs is None:
    return lax.dot_general(lhs, rhs, dimension_numbers,
                           preferred_element_type=preferred_element_type)

  out = cudnn_scaled_dot_general(
      lhs, rhs, dimension_numbers,
      preferred_element_type=preferred_element_type,
      configs=configs
  )

  return out