Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #17760 from superbobry:array-any
PiperOrigin-RevId: 570400629
Commit: c3e73c67aa
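The common thread in the hunks below is replacing per-module `Array = Any` aliases with the shared annotations from `jax._src.typing`, and widening user-facing signatures from `Array` to `ArrayLike`. A minimal sketch of that pattern, written against the public `jax.typing` aliases rather than the internal `jax._src.typing` module used in the diff (the `scale` function is made up for illustration):

import jax.numpy as jnp
from jax import Array             # concrete array type returned by JAX functions
from jax.typing import ArrayLike  # anything convertible to an array: Array, np.ndarray, scalars

def scale(x: ArrayLike, factor: ArrayLike = 2.0) -> Array:
    # Convert the ArrayLike input once, then work with a concrete Array.
    x_arr = jnp.asarray(x)
    return x_arr * factor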
@@ -666,6 +666,7 @@ pytype_strict_library(
        ":core",
        ":effects",
        ":pretty_printer",
+       ":typing",
        ":util",
    ],
)

@@ -33,12 +33,12 @@ from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
                                register_pytree_node)
+from jax._src.typing import Array
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
                           canonicalize_axis, moveaxis, as_hashable_function,
                           curry, memoize, weakref_lru_cache)


-Array = Any
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@@ -116,7 +116,7 @@ class RaggedAxis:
  # For each axis, we store its index and the corresponding segment lengths.
  # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
  # would be represented with ragged_axes = [(1, lens1), (3, lens2)]
-  ragged_axes: tuple[tuple[int, Array], ...]
+  ragged_axes: tuple[tuple[int, Any], ...]

  @property
  def size(self):

@@ -148,8 +148,10 @@ def _sorted_ragged_axis(stacked_axis, ragged_axes):
  return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0])))

def make_batch_axis(
-    ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]]
-) -> int | RaggedAxis:
+    ndim: int,
+    stacked_axis: int,
+    ragged_axes: list[tuple[int, Array | core.Var]],
+) -> int | RaggedAxis:
  if ragged_axes:
    canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
    return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical)

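The new `make_batch_axis` signature accepts ragged axis sizes that may be arrays or abstract `core.Var`s. A standalone toy version of the canonicalization shown above (the helper names and the fallthrough branch are illustrative, not jax internals):

from typing import NamedTuple, Union

class RaggedAxis(NamedTuple):
    stacked_axis: int
    ragged_axes: tuple  # (axis, segment_lengths) pairs, sorted by axis

def canonicalize_axis(axis: int, ndim: int) -> int:
    # Map a possibly-negative axis into the range [0, ndim).
    return axis + ndim if axis < 0 else axis

def make_batch_axis(ndim, stacked_axis, ragged_axes) -> Union[int, RaggedAxis]:
    if ragged_axes:
        canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
        return RaggedAxis(canonicalize_axis(stacked_axis, ndim),
                          tuple(sorted(canonical, key=lambda p: p[0])))
    return stacked_axis  # assumed fallthrough when there are no ragged axes

# A negative stacked axis and unsorted ragged axes come back canonicalized:
print(make_batch_axis(3, -3, [(2, [3, 1, 4]), (1, [2, 2, 2])]))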
@@ -70,7 +70,6 @@ Todos::
"""

from functools import partial
-from typing import Any

import numpy as np

@@ -88,9 +87,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo

-Array = Any
+from jax._src.typing import Array


def approx_max_k(operand: Array,

@@ -39,6 +39,7 @@ from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
+from jax._src.typing import Array
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                           split_list, split_dict)
from jax._src.lax.control_flow import loops

@@ -53,7 +54,6 @@ zip, unsafe_zip = safe_zip, zip
S = TypeVar('S')
T = TypeVar('T')
class Ref(Generic[T]): pass
-Array = Any

ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get

@@ -52,6 +52,7 @@ from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
+from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
                           unzip2, weakref_lru_cache, merge_lists)
import numpy as np

@@ -64,7 +65,6 @@ _map = safe_map
zip = safe_zip

T = TypeVar('T')
-Array = Any
BooleanNumeric = Any  # A bool, or a Boolean array.

### Helper functions

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import builtins
from collections.abc import Sequence
from functools import partial
import operator
-from typing import Any, NamedTuple, Optional, Union
+from typing import NamedTuple, Optional, Union

import numpy as np

@@ -28,14 +27,9 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
+from jax._src.typing import Array, DTypeLike


-_max = builtins.max

-Array = Any
-DType = Any
-Shape = core.Shape

class ConvDimensionNumbers(NamedTuple):
  """Describes batch, spatial, and feature dimensions of a convolution.

@@ -62,7 +56,7 @@ def conv_general_dilated(
    dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
    feature_group_count: int = 1, batch_group_count: int = 1,
    precision: lax.PrecisionLike = None,
-    preferred_element_type: Optional[DType] = None) -> Array:
+    preferred_element_type: Optional[DTypeLike] = None) -> Array:
  """General n-dimensional convolution operator, with optional dilation.

  Wraps XLA's `Conv

@@ -174,7 +168,7 @@ def conv_general_dilated(

def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
         padding: str, precision: lax.PrecisionLike = None,
-         preferred_element_type: Optional[DType] = None) -> Array:
+         preferred_element_type: Optional[DTypeLike] = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Args:

@@ -204,7 +198,7 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
                              lhs_dilation: Optional[Sequence[int]],
                              rhs_dilation: Optional[Sequence[int]],
                              precision: lax.PrecisionLike = None,
-                              preferred_element_type: Optional[DType] = None) -> Array:
+                              preferred_element_type: Optional[DTypeLike] = None) -> Array:
  """Convenience wrapper around `conv_general_dilated`.

  Args:

@@ -256,7 +250,7 @@ def _conv_transpose_padding(k, s, padding):
    else:
      pad_a = int(np.ceil(pad_len / 2))
  elif padding == 'VALID':
-    pad_len = k + s - 2 + _max(k - s, 0)
+    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
  else:
    raise ValueError('Padding mode must be `SAME` or `VALID`.')

@@ -277,7 +271,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
                   dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
                   transpose_kernel: bool = False,
                   precision: lax.PrecisionLike = None,
-                   preferred_element_type: Optional[DType] = None) -> Array:
+                   preferred_element_type: Optional[DTypeLike] = None) -> Array:
  """Convenience wrapper for calculating the N-d convolution "transpose".

  This function directly calculates a fractionally strided conv rather than

@@ -343,7 +337,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
  if transpose_kernel:
    # flip spatial dims and swap input / output channel axes
    rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
-    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
+    rhs = rhs.swapaxes(dn.rhs_spec[0], dn.rhs_spec[1])
  return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
                              precision=precision,
                              preferred_element_type=preferred_element_type)

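After this change, `preferred_element_type` in the conv wrappers is annotated as `DTypeLike`, so anything NumPy accepts as a dtype can be passed. A small sketch against the public `jax.lax.conv_general_dilated` (shapes and values are arbitrary):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 8, 8), dtype=jnp.bfloat16)   # NCHW input
rhs = jnp.ones((4, 1, 3, 3), dtype=jnp.bfloat16)   # OIHW kernel
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME",
    preferred_element_type=jnp.float32)            # a DTypeLike value
print(out.dtype)  # float32, as requested via preferred_element_type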
@@ -14,7 +14,7 @@

from collections.abc import Sequence
from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional, Union
import warnings

import numpy as np

@@ -36,12 +36,11 @@ from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
+from jax._src.typing import Array

map = util.safe_map
zip = util.safe_zip

-Array = Any


def reduce_window(operand, init_value, computation: Callable,
                  window_dimensions: core.Shape, window_strides: Sequence[int],

@@ -28,15 +28,16 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
+from jax._src.numpy import util as numpy_util
+from jax._src.typing import Array, ArrayLike
from jax._src.ops.special import logsumexp as _logsumexp

-Array = Any

# activations

@custom_jvp
@jax.jit
-def relu(x: Array) -> Array:
+def relu(x: ArrayLike) -> Array:
  r"""Rectified linear unit activation function.

  Computes the element-wise function:

@@ -72,7 +73,7 @@ def relu(x: Array) -> Array:
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

@jax.jit
-def softplus(x: Array) -> Array:
+def softplus(x: ArrayLike) -> Array:
  r"""Softplus activation function.

  Computes the element-wise function

@@ -86,7 +87,7 @@ def softplus(x: Array) -> Array:
  return jnp.logaddexp(x, 0)

@jax.jit
-def soft_sign(x: Array) -> Array:
+def soft_sign(x: ArrayLike) -> Array:
  r"""Soft-sign activation function.

  Computes the element-wise function

@@ -97,10 +98,12 @@ def soft_sign(x: Array) -> Array:
  Args:
    x : input array
  """
-  return x / (jnp.abs(x) + 1)
+  numpy_util.check_arraylike("soft_sign", x)
+  x_arr = jnp.asarray(x)
+  return x_arr / (jnp.abs(x_arr) + 1)

@jax.jit
-def sigmoid(x: Array) -> Array:
+def sigmoid(x: ArrayLike) -> Array:
  r"""Sigmoid activation function.

  Computes the element-wise function:

@@ -121,7 +124,7 @@ def sigmoid(x: Array) -> Array:
  return lax.logistic(x)

@jax.jit
-def silu(x: Array) -> Array:
+def silu(x: ArrayLike) -> Array:
  r"""SiLU (a.k.a. swish) activation function.

  Computes the element-wise function:

@@ -140,12 +143,14 @@ def silu(x: Array) -> Array:
  See also:
    :func:`sigmoid`
  """
-  return x * sigmoid(x)
+  numpy_util.check_arraylike("silu", x)
+  x_arr = jnp.asarray(x)
+  return x_arr * sigmoid(x_arr)

swish = silu

@jax.jit
-def log_sigmoid(x: Array) -> Array:
+def log_sigmoid(x: ArrayLike) -> Array:
  r"""Log-sigmoid activation function.

  Computes the element-wise function:

@@ -162,10 +167,12 @@ def log_sigmoid(x: Array) -> Array:
  See also:
    :func:`sigmoid`
  """
-  return -softplus(-x)
+  numpy_util.check_arraylike("log_sigmoid", x)
+  x_arr = jnp.asarray(x)
+  return -softplus(-x_arr)

@jax.jit
-def elu(x: Array, alpha: Array = 1.0) -> Array:
+def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
  r"""Exponential linear unit activation function.

  Computes the element-wise function:

@@ -186,11 +193,14 @@ def elu(x: Array, alpha: Array = 1.0) -> Array:
  See also:
    :func:`selu`
  """
-  safe_x = jnp.where(x > 0, 0., x)
-  return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
+  numpy_util.check_arraylike("elu", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr > 0,
+                   x_arr,
+                   alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))

@jax.jit
-def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
+def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation function.

  Computes the element-wise function:

@@ -213,10 +223,12 @@ def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
  See also:
    :func:`relu`
  """
-  return jnp.where(x >= 0, x, negative_slope * x)
+  numpy_util.check_arraylike("leaky_relu", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)

@jax.jit
-def hard_tanh(x: Array) -> Array:
+def hard_tanh(x: ArrayLike) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

@@ -234,10 +246,12 @@ def hard_tanh(x: Array) -> Array:
  Returns:
    An array.
  """
-  return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))
+  numpy_util.check_arraylike("hard_tanh", x)
+  x_arr = jnp.asarray(x)
+  return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))

@jax.jit
-def celu(x: Array, alpha: Array = 1.0) -> Array:
+def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Computes the element-wise function:

@@ -262,7 +276,7 @@ def celu(x: Array, alpha: Array = 1.0) -> Array:
  return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)

@jax.jit
-def selu(x: Array) -> Array:
+def selu(x: ArrayLike) -> Array:
  r"""Scaled exponential linear unit activation.

  Computes the element-wise function:

@@ -295,7 +309,7 @@ def selu(x: Array) -> Array:

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
-def gelu(x: Array, approximate: bool = True) -> Array:
+def gelu(x: ArrayLike, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

@@ -317,20 +331,18 @@ def gelu(x: Array, approximate: bool = True) -> Array:
    x : input array
    approximate: whether to use the approximate or exact formulation.
  """

-  # Promote to nearest float-like dtype.
-  x = x.astype(dtypes.to_inexact_dtype(x.dtype))
+  [x_arr] = numpy_util.promote_args_inexact("gelu", x)

  if approximate:
-    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
-    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
-    return x * cdf
+    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype)
+    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))
+    return x_arr * cdf
  else:
-    sqrt_2 = np.sqrt(2).astype(x.dtype)
-    return jnp.array(x * (lax.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype)
+    sqrt_2 = np.sqrt(2).astype(x_arr.dtype)
+    return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype)

@partial(jax.jit, static_argnames=("axis",))
-def glu(x: Array, axis: int = -1) -> Array:
+def glu(x: ArrayLike, axis: int = -1) -> Array:
  r"""Gated linear unit activation function.

  Computes the function:

@@ -353,9 +365,11 @@ def glu(x: Array, axis: int = -1) -> Array:
  See also:
    :func:`sigmoid`
  """
-  size = x.shape[axis]
+  numpy_util.check_arraylike("glu", x)
+  x_arr = jnp.asarray(x)
+  size = x_arr.shape[axis]
  assert size % 2 == 0, "axis size must be divisible by 2"
-  x1, x2 = jnp.split(x, 2, axis)
+  x1, x2 = jnp.split(x_arr, 2, axis)
  return x1 * sigmoid(x2)

# other functions

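The activation rewrites above all follow one pattern: validate that the argument is array-like, convert it once with `jnp.asarray`, then compute on the converted value. A standalone sketch of that pattern using only public APIs (the `check_arraylike` helper in the diff is internal to jax, so plain `jnp.asarray` stands in for it here, and `my_soft_sign` is a made-up name):

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

@jax.jit
def my_soft_sign(x: ArrayLike) -> Array:
    # Accept scalars, NumPy arrays, or jax Arrays; normalize to a jax Array first.
    x_arr = jnp.asarray(x)
    return x_arr / (jnp.abs(x_arr) + 1)

print(my_soft_sign(-2.0))             # scalar input
print(my_soft_sign(jnp.arange(3.0)))  # Array input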
@@ -364,10 +378,10 @@ logsumexp = _logsumexp


@partial(jax.jit, static_argnames=("axis",))
-def log_softmax(x: Array,
+def log_softmax(x: ArrayLike,
                axis: Optional[Union[int, tuple[int, ...]]] = -1,
-                where: Optional[Array] = None,
-                initial: Optional[Array] = None) -> Array:
+                where: Optional[ArrayLike] = None,
+                initial: Optional[ArrayLike] = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales

@@ -391,8 +405,10 @@ def log_softmax(x: Array,
  See also:
    :func:`softmax`
  """
-  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
-  shifted = x - lax.stop_gradient(x_max)
+  numpy_util.check_arraylike("log_softmax", x)
+  x_arr = jnp.asarray(x)
+  x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True)
+  shifted = x_arr - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  result = shifted - shifted_logsumexp

@@ -403,10 +419,10 @@ def log_softmax(x: Array,

# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
-def softmax(x: Array,
+def softmax(x: ArrayLike,
            axis: Optional[Union[int, tuple[int, ...]]] = -1,
-            where: Optional[Array] = None,
-            initial: Optional[Array] = None) -> Array:
+            where: Optional[ArrayLike] = None,
+            initial: Optional[ArrayLike] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`

@@ -431,17 +447,20 @@ def softmax(x: Array,
    :func:`log_softmax`
  """
  if jax.config.jax_softmax_custom_jvp:
-    return _softmax(x, axis, where, initial)
+    # mypy is confused by the `functools.partial` application in the definition
+    # of `_softmax` and incorrectly concludes that `_softmax` returns
+    # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
+    return _softmax(x, axis, where, initial)  # type: ignore[return-value]
  else:
    return _softmax_deprecated(x, axis, where, initial)

# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
-    x,
+    x: ArrayLike,
    axis: Optional[Union[int, tuple[int, ...]]] = -1,
-    where: Optional[Array] = None,
-    initial: Optional[Array] = None) -> Array:
+    where: Optional[ArrayLike] = None,
+    initial: Optional[ArrayLike] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - x_max)
  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)

@@ -455,7 +474,11 @@ def _softmax_jvp(axis, primals, tangents):
  y = _softmax(x, axis, where, initial)
  return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))

-def _softmax_deprecated(x, axis, where, initial):
+def _softmax_deprecated(
+    x: ArrayLike,
+    axis: Optional[Union[int, tuple[int, ...]]] = -1,
+    where: Optional[ArrayLike] = None,
+    initial: Optional[ArrayLike] = None) -> Array:
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)

@@ -465,13 +488,15 @@ def _softmax_deprecated(x, axis, where, initial):


@partial(jax.jit, static_argnames=("axis",))
-def standardize(x: Array,
+def standardize(x: ArrayLike,
              axis: Optional[Union[int, tuple[int, ...]]] = -1,
-              mean: Optional[Array] = None,
-              variance: Optional[Array] = None,
-              epsilon: Array = 1e-5,
-              where: Optional[Array] = None) -> Array:
+              mean: Optional[ArrayLike] = None,
+              variance: Optional[ArrayLike] = None,
+              epsilon: ArrayLike = 1e-5,
+              where: Optional[ArrayLike] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
+  numpy_util.check_arraylike("standardize", x)
+  numpy_util.check_arraylike_or_none("standardize", mean, variance, where)
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True, where=where)
  if variance is None:

@@ -481,43 +506,45 @@ def standardize(x: Array,
  # when used in neural network normalization layers
  variance = jnp.mean(
      jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
-  return (x - mean) * lax.rsqrt(variance + epsilon)
+  return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)

-def normalize(x: Array,
-              axis: Optional[Union[int, tuple[int, ...]]] = -1,
-              mean: Optional[Array] = None,
-              variance: Optional[Array] = None,
-              epsilon: Array = 1e-5,
-              where: Optional[Array] = None) -> Array:
+def normalize(x: ArrayLike,
+              axis: Optional[Union[int, tuple[int, ...]]] = -1,
+              mean: Optional[ArrayLike] = None,
+              variance: Optional[ArrayLike] = None,
+              epsilon: ArrayLike = 1e-5,
+              where: Optional[ArrayLike] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
  return standardize(x, axis, mean, variance, epsilon, where)

+# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
-def _one_hot(x: Array, num_classes: int, *,
+def _one_hot(x: Any, num_classes: int, *,
             dtype: Any, axis: Union[int, AxisName]) -> Array:
  num_classes = core.concrete_dim_or_error(
      num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
-  x = jnp.asarray(x)
+  x_arr = jnp.asarray(x)
  try:
-    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
+    output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1)
  except TypeError:
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
-    return jnp.asarray(x == axis_idx, dtype=dtype)
+    return jnp.asarray(x_arr == axis_idx, dtype=dtype)
  axis = operator.index(axis)  # type: ignore[arg-type]
-  lhs = lax.expand_dims(x, (axis,))
-  rhs_shape = [1] * x.ndim
+  lhs = lax.expand_dims(x_arr, (axis,))
+  rhs_shape = [1] * x_arr.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
-  rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
+  rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis)
  return jnp.asarray(lhs == rhs, dtype=dtype)

-def one_hot(x: Array, num_classes: int, *,
+# TODO(slebedev): Change the type of `x` to `ArrayLike`.
+def one_hot(x: Any, num_classes: int, *,
            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

@@ -550,7 +577,7 @@ def one_hot(x: Array, num_classes: int, *,

@jax.custom_jvp
@jax.jit
-def relu6(x: Array) -> Array:
+def relu6(x: ArrayLike) -> Array:
  r"""Rectified Linear Unit 6 activation function.

  Computes the element-wise function

@@ -582,7 +609,7 @@ relu6.defjvps(lambda g, ans, x:
              lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))

@jax.jit
-def hard_sigmoid(x: Array) -> Array:
+def hard_sigmoid(x: ArrayLike) -> Array:
  r"""Hard Sigmoid activation function.

  Computes the element-wise function

@@ -602,7 +629,7 @@ def hard_sigmoid(x: Array) -> Array:
  return relu6(x + 3.) / 6.

@jax.jit
-def hard_silu(x: Array) -> Array:
+def hard_silu(x: ArrayLike) -> Array:
  r"""Hard SiLU (swish) activation function

  Computes the element-wise function

@@ -622,6 +649,8 @@ def hard_silu(x: Array) -> Array:
  See also:
    :func:`hard_sigmoid`
  """
-  return x * hard_sigmoid(x)
+  numpy_util.check_arraylike("hard_silu", x)
+  x_arr = jnp.asarray(x)
+  return x_arr * hard_sigmoid(x_arr)

hard_swish = hard_silu

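With the `ArrayLike` signatures in place, the public `jax.nn` activations accept plain Python scalars and NumPy arrays as well as jax Arrays. A brief usage check (printed values in the comments are approximate):

import numpy as np
import jax

print(jax.nn.relu(-1.5))                           # 0.0 (scalar input)
print(jax.nn.sigmoid(np.array([-1.0, 0.0, 1.0])))  # ~[0.269, 0.5, 0.731]
print(jax.nn.one_hot(np.array([0, 2]), 3))         # 2x3 one-hot matrix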
@@ -23,18 +23,17 @@ from typing import Any, Literal, Protocol, Union

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import core
from jax._src import dtypes
+from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module

export = set_module('jax.nn.initializers')

-KeyArray = jax.Array
-Array = Any
+KeyArray = Array
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any

@@ -48,7 +47,7 @@ class Initializer(Protocol):
  def __call__(key: KeyArray,
               shape: core.Shape,
               dtype: DTypeLikeInexact = jnp.float_) -> Array:
-    ...
+    raise NotImplementedError

@export
def zeros(key: KeyArray,

@@ -82,7 +81,7 @@ def ones(key: KeyArray,
  return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

@export
-def constant(value: Array,
+def constant(value: ArrayLike,
             dtype: DTypeLikeInexact = jnp.float_
             ) -> Initializer:
  """Builds an initializer that returns arrays full of a constant ``value``.

@@ -240,7 +239,7 @@ def _complex_uniform(key: KeyArray,
  theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
  return r * jnp.exp(1j * theta)

-def _complex_truncated_normal(key: KeyArray, upper: Array,
+def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
                              shape: Union[Sequence[int], core.NamedShape],
                              dtype: DTypeLikeInexact) -> Array:
  """

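`constant` now takes an `ArrayLike` value, so a Python scalar works directly. A short sketch with the public initializer API:

import jax
import jax.numpy as jnp

init = jax.nn.initializers.constant(2.0)               # value: ArrayLike
w = init(jax.random.PRNGKey(0), (2, 3), jnp.float32)   # key is accepted but unused here
print(w)  # a (2, 3) array filled with 2.0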
@@ -16,7 +16,7 @@

from collections.abc import Sequence
import sys
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional, Union
import warnings

import numpy as np

@@ -31,9 +31,9 @@ from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy.util import check_arraylike, promote_dtypes
+from jax._src.typing import Array, ArrayLike


-Array = Any
if sys.version_info >= (3, 10):
  from types import EllipsisType
  SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]

@@ -154,8 +154,8 @@ def _get_identity(op, dtype):


def _segment_update(name: str,
-                    data: Array,
-                    segment_ids: Array,
+                    data: ArrayLike,
+                    segment_ids: ArrayLike,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,

@@ -195,8 +195,8 @@ def _segment_update(name: str,
  return reducer(out, axis=0).astype(dtype)


-def segment_sum(data: Array,
-                segment_ids: Array,
+def segment_sum(data: ArrayLike,
+                segment_ids: ArrayLike,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,

@@ -250,8 +250,8 @@ def segment_sum(data: Array,
      indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode)


-def segment_prod(data: Array,
-                 segment_ids: Array,
+def segment_prod(data: ArrayLike,
+                 segment_ids: ArrayLike,
                 num_segments: Optional[int] = None,
                 indices_are_sorted: bool = False,
                 unique_indices: bool = False,

@@ -306,8 +306,8 @@ def segment_prod(data: Array,
      indices_are_sorted, unique_indices, bucket_size, reductions.prod, mode=mode)


-def segment_max(data: Array,
-                segment_ids: Array,
+def segment_max(data: ArrayLike,
+                segment_ids: ArrayLike,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,

@@ -361,8 +361,8 @@ def segment_max(data: Array,
      indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode)


-def segment_min(data: Array,
-                segment_ids: Array,
+def segment_min(data: ArrayLike,
+                segment_ids: ArrayLike,
                num_segments: Optional[int] = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,

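The segment reductions now advertise `ArrayLike` for `data` and `segment_ids`, so NumPy inputs are fine. For example, with the public `jax.ops.segment_sum`:

import numpy as np
import jax

data = np.array([1.0, 2.0, 3.0, 4.0])
segment_ids = np.array([0, 0, 1, 1])
print(jax.ops.segment_sum(data, segment_ids, num_segments=2))  # [3. 7.]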
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
-from typing import Any, Callable, NamedTuple, Optional, Union
+from typing import Callable, NamedTuple, Optional, Union
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax._src.scipy.optimize.line_search import line_search
+from jax._src.typing import Array


_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)


-Array = Any

class LBFGSResults(NamedTuple):
  """Results from L-BFGS optimization

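This module backs the L-BFGS path of `jax.scipy.optimize.minimize`. A quick usage sketch of the public entry point, shown here with the stable "BFGS" method, which shares the same interface:

import jax.numpy as jnp
from jax.scipy.optimize import minimize

def rosenbrock(x):
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

result = minimize(rosenbrock, jnp.zeros(2), method="BFGS")
print(result.x)  # approaches [1., 1.]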
@@ -23,14 +23,13 @@ from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.util import safe_map, safe_zip
+from jax._src.typing import Array

## JAX utilities

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

-Array = Any

_ref_effect_color = pp.Color.GREEN

class RefEffect(effects.JaxprInputEffect):