Merge pull request #17760 from superbobry:array-any

PiperOrigin-RevId: 570400629
Author: jax authors
Date:   2023-10-03 08:50:07 -07:00
Commit: c3e73c67aa
12 changed files with 138 additions and 117 deletions

View File

@ -666,6 +666,7 @@ pytype_strict_library(
":core",
":effects",
":pretty_printer",
":typing",
":util",
],
)

View File

@ -33,12 +33,12 @@ from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.typing import Array
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
Array = Any
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -116,7 +116,7 @@ class RaggedAxis:
# For each axis, we store its index and the corresponding segment lengths.
# For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: tuple[tuple[int, Array], ...]
ragged_axes: tuple[tuple[int, Any], ...]
@property
def size(self):
@ -148,8 +148,10 @@ def _sorted_ragged_axis(stacked_axis, ragged_axes):
return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0])))
def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]]
) -> int | RaggedAxis:
ndim: int,
stacked_axis: int,
ragged_axes: list[tuple[int, Array | core.Var]],
) -> int | RaggedAxis:
if ragged_axes:
canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical)
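To make the ragged_axes representation described in the RaggedAxis comment concrete, here is a small illustrative sketch; the names lens1 and lens2 come from that comment, and the length values are invented for illustration:

import numpy as np

# A stack of 3 ragged elements, roughly f32[lens1.i, 7, lens2.i] per element.
# After stacking along axis 0, axes 1 and 3 of the stacked array are ragged.
lens1 = np.array([3, 5, 2])             # per-element sizes along axis 1
lens2 = np.array([4, 4, 1])             # per-element sizes along axis 3
stacked_axis = 0
ragged_axes = [(1, lens1), (3, lens2)]  # (axis index, segment lengths), sorted by axis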

View File

@ -70,7 +70,6 @@ Todos::
"""
from functools import partial
from typing import Any
import numpy as np
@ -88,9 +87,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo
Array = Any
from jax._src.typing import Array
def approx_max_k(operand: Array,

View File

@ -39,6 +39,7 @@ from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.state import utils as state_utils
from jax._src.state import types as state_types
from jax._src.typing import Array
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list, split_dict)
from jax._src.lax.control_flow import loops
@ -53,7 +54,6 @@ zip, unsafe_zip = safe_zip, zip
S = TypeVar('S')
T = TypeVar('T')
class Ref(Generic[T]): pass
Array = Any
ref_set = state_primitives.ref_set
ref_get = state_primitives.ref_get

View File

@ -52,6 +52,7 @@ from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache, merge_lists)
import numpy as np
@ -64,7 +65,6 @@ _map = safe_map
zip = safe_zip
T = TypeVar('T')
Array = Any
BooleanNumeric = Any # A bool, or a Boolean array.
### Helper functions

View File

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
from collections.abc import Sequence
from functools import partial
import operator
from typing import Any, NamedTuple, Optional, Union
from typing import NamedTuple, Optional, Union
import numpy as np
@ -28,14 +27,9 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, DTypeLike
_max = builtins.max
Array = Any
DType = Any
Shape = core.Shape
class ConvDimensionNumbers(NamedTuple):
"""Describes batch, spatial, and feature dimensions of a convolution.
@ -62,7 +56,7 @@ def conv_general_dilated(
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
feature_group_count: int = 1, batch_group_count: int = 1,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
@ -174,7 +168,7 @@ def conv_general_dilated(
def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
padding: str, precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
@ -204,7 +198,7 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
lhs_dilation: Optional[Sequence[int]],
rhs_dilation: Optional[Sequence[int]],
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper around `conv_general_dilated`.
Args:
@ -256,7 +250,7 @@ def _conv_transpose_padding(k, s, padding):
else:
pad_a = int(np.ceil(pad_len / 2))
elif padding == 'VALID':
pad_len = k + s - 2 + _max(k - s, 0)
pad_len = k + s - 2 + max(k - s, 0)
pad_a = k - 1
else:
raise ValueError('Padding mode must be `SAME` or `VALID`.')
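A quick numeric check of the VALID branch above, with illustrative values not taken from the diff:

k, s = 3, 2                            # kernel size and stride, chosen for illustration
pad_len = k + s - 2 + max(k - s, 0)    # 3 + 2 - 2 + 1 = 4
pad_a = k - 1                          # 2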
@ -277,7 +271,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
transpose_kernel: bool = False,
precision: lax.PrecisionLike = None,
preferred_element_type: Optional[DType] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Convenience wrapper for calculating the N-d convolution "transpose".
This function directly calculates a fractionally strided conv rather than
@ -343,7 +337,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
if transpose_kernel:
# flip spatial dims and swap input / output channel axes
rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
rhs = rhs.swapaxes(dn.rhs_spec[0], dn.rhs_spec[1])
return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
precision=precision,
preferred_element_type=preferred_element_type)
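A minimal usage sketch of the updated annotation, assuming the public jax.lax API: preferred_element_type accepts any DTypeLike, and with dimension_numbers left as None the 2-D defaults correspond to NCHW inputs and OIHW kernels.

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 3, 8, 8), dtype=jnp.bfloat16)   # batch, features, height, width
rhs = jnp.ones((4, 3, 3, 3), dtype=jnp.bfloat16)   # out features, in features, kh, kw
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    preferred_element_type=jnp.float32)            # any DTypeLike works here
print(out.shape, out.dtype)                        # (1, 4, 8, 8) float32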

View File

@ -14,7 +14,7 @@
from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, Union
import warnings
import numpy as np
@ -36,12 +36,11 @@ from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.typing import Array
map = util.safe_map
zip = util.safe_zip
Array = Any
def reduce_window(operand, init_value, computation: Callable,
window_dimensions: core.Shape, window_strides: Sequence[int],

View File

@ -28,15 +28,16 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike
from jax._src.ops.special import logsumexp as _logsumexp
Array = Any
# activations
@custom_jvp
@jax.jit
def relu(x: Array) -> Array:
def relu(x: ArrayLike) -> Array:
r"""Rectified linear unit activation function.
Computes the element-wise function:
@ -72,7 +73,7 @@ def relu(x: Array) -> Array:
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
@jax.jit
def softplus(x: Array) -> Array:
def softplus(x: ArrayLike) -> Array:
r"""Softplus activation function.
Computes the element-wise function
@ -86,7 +87,7 @@ def softplus(x: Array) -> Array:
return jnp.logaddexp(x, 0)
@jax.jit
def soft_sign(x: Array) -> Array:
def soft_sign(x: ArrayLike) -> Array:
r"""Soft-sign activation function.
Computes the element-wise function
@ -97,10 +98,12 @@ def soft_sign(x: Array) -> Array:
Args:
x : input array
"""
return x / (jnp.abs(x) + 1)
numpy_util.check_arraylike("soft_sign", x)
x_arr = jnp.asarray(x)
return x_arr / (jnp.abs(x_arr) + 1)
@jax.jit
def sigmoid(x: Array) -> Array:
def sigmoid(x: ArrayLike) -> Array:
r"""Sigmoid activation function.
Computes the element-wise function:
@ -121,7 +124,7 @@ def sigmoid(x: Array) -> Array:
return lax.logistic(x)
@jax.jit
def silu(x: Array) -> Array:
def silu(x: ArrayLike) -> Array:
r"""SiLU (a.k.a. swish) activation function.
Computes the element-wise function:
@ -140,12 +143,14 @@ def silu(x: Array) -> Array:
See also:
:func:`sigmoid`
"""
return x * sigmoid(x)
numpy_util.check_arraylike("silu", x)
x_arr = jnp.asarray(x)
return x_arr * sigmoid(x_arr)
swish = silu
@jax.jit
def log_sigmoid(x: Array) -> Array:
def log_sigmoid(x: ArrayLike) -> Array:
r"""Log-sigmoid activation function.
Computes the element-wise function:
@ -162,10 +167,12 @@ def log_sigmoid(x: Array) -> Array:
See also:
:func:`sigmoid`
"""
return -softplus(-x)
numpy_util.check_arraylike("log_sigmoid", x)
x_arr = jnp.asarray(x)
return -softplus(-x_arr)
@jax.jit
def elu(x: Array, alpha: Array = 1.0) -> Array:
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
r"""Exponential linear unit activation function.
Computes the element-wise function:
@ -186,11 +193,14 @@ def elu(x: Array, alpha: Array = 1.0) -> Array:
See also:
:func:`selu`
"""
safe_x = jnp.where(x > 0, 0., x)
return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
numpy_util.check_arraylike("elu", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr > 0,
x_arr,
alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))
@jax.jit
def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
r"""Leaky rectified linear unit activation function.
Computes the element-wise function:
@ -213,10 +223,12 @@ def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
See also:
:func:`relu`
"""
return jnp.where(x >= 0, x, negative_slope * x)
numpy_util.check_arraylike("leaky_relu", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)
@jax.jit
def hard_tanh(x: Array) -> Array:
def hard_tanh(x: ArrayLike) -> Array:
r"""Hard :math:`\mathrm{tanh}` activation function.
Computes the element-wise function:
@ -234,10 +246,12 @@ def hard_tanh(x: Array) -> Array:
Returns:
An array.
"""
return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))
numpy_util.check_arraylike("hard_tanh", x)
x_arr = jnp.asarray(x)
return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))
@jax.jit
def celu(x: Array, alpha: Array = 1.0) -> Array:
def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
r"""Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
@ -262,7 +276,7 @@ def celu(x: Array, alpha: Array = 1.0) -> Array:
return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
@jax.jit
def selu(x: Array) -> Array:
def selu(x: ArrayLike) -> Array:
r"""Scaled exponential linear unit activation.
Computes the element-wise function:
@ -295,7 +309,7 @@ def selu(x: Array) -> Array:
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
def gelu(x: Array, approximate: bool = True) -> Array:
def gelu(x: ArrayLike, approximate: bool = True) -> Array:
r"""Gaussian error linear unit activation function.
If ``approximate=False``, computes the element-wise function:
@ -317,20 +331,18 @@ def gelu(x: Array, approximate: bool = True) -> Array:
x : input array
approximate: whether to use the approximate or exact formulation.
"""
# Promote to nearest float-like dtype.
x = x.astype(dtypes.to_inexact_dtype(x.dtype))
[x_arr] = numpy_util.promote_args_inexact("gelu", x)
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
return x * cdf
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype)
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))
return x_arr * cdf
else:
sqrt_2 = np.sqrt(2).astype(x.dtype)
return jnp.array(x * (lax.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype)
sqrt_2 = np.sqrt(2).astype(x_arr.dtype)
return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype)
@partial(jax.jit, static_argnames=("axis",))
def glu(x: Array, axis: int = -1) -> Array:
def glu(x: ArrayLike, axis: int = -1) -> Array:
r"""Gated linear unit activation function.
Computes the function:
@ -353,9 +365,11 @@ def glu(x: Array, axis: int = -1) -> Array:
See also:
:func:`sigmoid`
"""
size = x.shape[axis]
numpy_util.check_arraylike("glu", x)
x_arr = jnp.asarray(x)
size = x_arr.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = jnp.split(x, 2, axis)
x1, x2 = jnp.split(x_arr, 2, axis)
return x1 * sigmoid(x2)
# other functions
@ -364,10 +378,10 @@ logsumexp = _logsumexp
@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
def log_softmax(x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
where: Optional[ArrayLike] = None,
initial: Optional[ArrayLike] = None) -> Array:
r"""Log-Softmax function.
Computes the logarithm of the :code:`softmax` function, which rescales
@ -391,8 +405,10 @@ def log_softmax(x: Array,
See also:
:func:`softmax`
"""
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
shifted = x - lax.stop_gradient(x_max)
numpy_util.check_arraylike("log_softmax", x)
x_arr = jnp.asarray(x)
x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True)
shifted = x_arr - lax.stop_gradient(x_max)
shifted_logsumexp = jnp.log(
jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
result = shifted - shifted_logsumexp
@ -403,10 +419,10 @@ def log_softmax(x: Array,
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
def softmax(x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
where: Optional[ArrayLike] = None,
initial: Optional[ArrayLike] = None) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
@ -431,17 +447,20 @@ def softmax(x: Array,
:func:`log_softmax`
"""
if jax.config.jax_softmax_custom_jvp:
return _softmax(x, axis, where, initial)
# mypy is confused by the `functools.partial` application in the definition
# of `_softmax` and incorrectly concludes that `_softmax` returns
# `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`.
return _softmax(x, axis, where, initial) # type: ignore[return-value]
else:
return _softmax_deprecated(x, axis, where, initial)
# TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _softmax(
x,
x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[Array] = None,
initial: Optional[Array] = None) -> Array:
where: Optional[ArrayLike] = None,
initial: Optional[ArrayLike] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
unnormalized = jnp.exp(x - x_max)
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
@ -455,7 +474,11 @@ def _softmax_jvp(axis, primals, tangents):
y = _softmax(x, axis, where, initial)
return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))
def _softmax_deprecated(x, axis, where, initial):
def _softmax_deprecated(
x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
where: Optional[ArrayLike] = None,
initial: Optional[ArrayLike] = None) -> Array:
x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
@ -465,13 +488,15 @@ def _softmax_deprecated(x, axis, where, initial):
@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
def standardize(x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
where: Optional[Array] = None) -> Array:
mean: Optional[ArrayLike] = None,
variance: Optional[ArrayLike] = None,
epsilon: ArrayLike = 1e-5,
where: Optional[ArrayLike] = None) -> Array:
r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
numpy_util.check_arraylike("standardize", x)
numpy_util.check_arraylike_or_none("standardize", mean, variance, where)
if mean is None:
mean = jnp.mean(x, axis, keepdims=True, where=where)
if variance is None:
@ -481,43 +506,45 @@ def standardize(x: Array,
# when used in neural network normalization layers
variance = jnp.mean(
jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
return (x - mean) * lax.rsqrt(variance + epsilon)
return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon)
def normalize(x: Array,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[Array] = None,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
where: Optional[Array] = None) -> Array:
def normalize(x: ArrayLike,
axis: Optional[Union[int, tuple[int, ...]]] = -1,
mean: Optional[ArrayLike] = None,
variance: Optional[ArrayLike] = None,
epsilon: ArrayLike = 1e-5,
where: Optional[ArrayLike] = None) -> Array:
r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
return standardize(x, axis, mean, variance, epsilon, where)
# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Array, num_classes: int, *,
def _one_hot(x: Any, num_classes: int, *,
dtype: Any, axis: Union[int, AxisName]) -> Array:
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
dtype = dtypes.canonicalize_dtype(dtype)
x = jnp.asarray(x)
x_arr = jnp.asarray(x)
try:
output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1)
except TypeError:
axis_size = lax.psum(1, axis)
if num_classes != axis_size:
raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x == axis_idx, dtype=dtype)
return jnp.asarray(x_arr == axis_idx, dtype=dtype)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
lhs = lax.expand_dims(x_arr, (axis,))
rhs_shape = [1] * x_arr.ndim
rhs_shape.insert(output_pos_axis, num_classes)
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis)
return jnp.asarray(lhs == rhs, dtype=dtype)
def one_hot(x: Array, num_classes: int, *,
# TODO(slebedev): Change the type of `x` to `ArrayLike`.
def one_hot(x: Any, num_classes: int, *,
dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
"""One-hot encodes the given indices.
@ -550,7 +577,7 @@ def one_hot(x: Array, num_classes: int, *,
@jax.custom_jvp
@jax.jit
def relu6(x: Array) -> Array:
def relu6(x: ArrayLike) -> Array:
r"""Rectified Linear Unit 6 activation function.
Computes the element-wise function
@ -582,7 +609,7 @@ relu6.defjvps(lambda g, ans, x:
lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0)))
@jax.jit
def hard_sigmoid(x: Array) -> Array:
def hard_sigmoid(x: ArrayLike) -> Array:
r"""Hard Sigmoid activation function.
Computes the element-wise function
@ -602,7 +629,7 @@ def hard_sigmoid(x: Array) -> Array:
return relu6(x + 3.) / 6.
@jax.jit
def hard_silu(x: Array) -> Array:
def hard_silu(x: ArrayLike) -> Array:
r"""Hard SiLU (swish) activation function
Computes the element-wise function
@ -622,6 +649,8 @@ def hard_silu(x: Array) -> Array:
See also:
:func:`hard_sigmoid`
"""
return x * hard_sigmoid(x)
numpy_util.check_arraylike("hard_silu", x)
x_arr = jnp.asarray(x)
return x_arr * hard_sigmoid(x_arr)
hard_swish = hard_silu
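A small sketch of what the relaxed ArrayLike signatures accept, assuming the public jax.nn API; the values in the comments are illustrative:

import numpy as np
import jax
import jax.numpy as jnp

x = np.linspace(-3.0, 3.0, 7)             # a NumPy array is a valid ArrayLike input
y = jax.nn.relu(x)                        # returns a jax.Array
p = jax.nn.softmax(jnp.array([1.0, 2.0, 3.0]))
h = jax.nn.one_hot(np.array([0, 1, 2]), 3)
print(y.dtype, float(p.sum()), h.shape)   # float dtype per jax_enable_x64, 1.0, (3, 3)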

View File

@ -23,18 +23,17 @@ from typing import Any, Literal, Protocol, Union
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from jax._src import core
from jax._src import dtypes
from jax._src.typing import Array, ArrayLike
from jax._src.util import set_module
export = set_module('jax.nn.initializers')
KeyArray = jax.Array
Array = Any
KeyArray = Array
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
@ -48,7 +47,7 @@ class Initializer(Protocol):
def __call__(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
...
raise NotImplementedError
@export
def zeros(key: KeyArray,
@ -82,7 +81,7 @@ def ones(key: KeyArray,
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
@export
def constant(value: Array,
def constant(value: ArrayLike,
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
"""Builds an initializer that returns arrays full of a constant ``value``.
@ -240,7 +239,7 @@ def _complex_uniform(key: KeyArray,
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)
def _complex_truncated_normal(key: KeyArray, upper: Array,
def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
shape: Union[Sequence[int], core.NamedShape],
dtype: DTypeLikeInexact) -> Array:
"""

View File

@ -16,7 +16,7 @@
from collections.abc import Sequence
import sys
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, Union
import warnings
import numpy as np
@ -31,9 +31,9 @@ from jax._src.lax import lax as lax_internal
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy.util import check_arraylike, promote_dtypes
from jax._src.typing import Array, ArrayLike
Array = Any
if sys.version_info >= (3, 10):
from types import EllipsisType
SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
@ -154,8 +154,8 @@ def _get_identity(op, dtype):
def _segment_update(name: str,
data: Array,
segment_ids: Array,
data: ArrayLike,
segment_ids: ArrayLike,
scatter_op: Callable,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
@ -195,8 +195,8 @@ def _segment_update(name: str,
return reducer(out, axis=0).astype(dtype)
def segment_sum(data: Array,
segment_ids: Array,
def segment_sum(data: ArrayLike,
segment_ids: ArrayLike,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
@ -250,8 +250,8 @@ def segment_sum(data: Array,
indices_are_sorted, unique_indices, bucket_size, reductions.sum, mode=mode)
def segment_prod(data: Array,
segment_ids: Array,
def segment_prod(data: ArrayLike,
segment_ids: ArrayLike,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
@ -306,8 +306,8 @@ def segment_prod(data: Array,
indices_are_sorted, unique_indices, bucket_size, reductions.prod, mode=mode)
def segment_max(data: Array,
segment_ids: Array,
def segment_max(data: ArrayLike,
segment_ids: ArrayLike,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
@ -361,8 +361,8 @@ def segment_max(data: Array,
indices_are_sorted, unique_indices, bucket_size, reductions.max, mode=mode)
def segment_min(data: Array,
segment_ids: Array,
def segment_min(data: ArrayLike,
segment_ids: ArrayLike,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
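A minimal usage sketch for the relaxed segment-op signatures, assuming the public jax.ops API; data and segment_ids may be any ArrayLike, such as NumPy arrays:

import numpy as np
import jax

data = np.arange(6.0)                        # ArrayLike data
segment_ids = np.array([0, 0, 1, 1, 2, 2])   # ArrayLike segment ids
out = jax.ops.segment_sum(data, segment_ids, num_segments=3)
print(out)                                   # [1. 5. 9.]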

View File

@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from typing import Any, Callable, NamedTuple, Optional, Union
from typing import Callable, NamedTuple, Optional, Union
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax._src.scipy.optimize.line_search import line_search
from jax._src.typing import Array
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
Array = Any
class LBFGSResults(NamedTuple):
"""Results from L-BFGS optimization

View File

@ -23,14 +23,13 @@ from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.util import safe_map, safe_zip
from jax._src.typing import Array
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Array = Any
_ref_effect_color = pp.Color.GREEN
class RefEffect(effects.JaxprInputEffect):