mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
rollback of https://github.com/google/jax/pull/9596
Why? Shape annotations are inaccurate and cause pytype failures PiperOrigin-RevId: 465337386
This commit is contained in:
parent
89ce078bf1
commit
d52017aa78
@ -18,10 +18,9 @@ used in Keras and Sonnet.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Any, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
@ -30,25 +29,9 @@ from jax import core
|
||||
from jax._src.util import prod
|
||||
from jax._src import dtypes
|
||||
|
||||
KeyArray = random.KeyArray
|
||||
Array = Any
|
||||
# TODO: Import or define these to match
|
||||
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
|
||||
DTypeLikeFloat = Any
|
||||
DTypeLikeComplex = Any
|
||||
DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex
|
||||
RealNumeric = Any # Scalar jnp array or float
|
||||
DType = Any
|
||||
|
||||
class Initializer(Protocol):
|
||||
@staticmethod
|
||||
def __call__(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Array:
|
||||
...
|
||||
|
||||
def zeros(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Array:
|
||||
def zeros(key, shape, dtype: DType = jnp.float_):
|
||||
"""An initializer that returns a constant array full of zeros.
|
||||
|
||||
The ``key`` argument is ignored.
|
||||
@ -60,9 +43,7 @@ def zeros(key: KeyArray,
|
||||
"""
|
||||
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def ones(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Array:
|
||||
def ones(key, shape, dtype: DType = jnp.float_):
|
||||
"""An initializer that returns a constant array full of ones.
|
||||
|
||||
The ``key`` argument is ignored.
|
||||
@ -75,9 +56,7 @@ def ones(key: KeyArray,
|
||||
"""
|
||||
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
def constant(value: Array,
|
||||
dtype: DTypeLikeInexact = jnp.float_
|
||||
) -> Initializer:
|
||||
def constant(value, dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds an initializer that returns arrays full of a constant ``value``.
|
||||
|
||||
Args:
|
||||
@ -90,15 +69,12 @@ def constant(value: Array,
|
||||
DeviceArray([[-7., -7., -7.],
|
||||
[-7., -7., -7.]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
return jnp.full(shape, value, dtype=dtype)
|
||||
return init
|
||||
|
||||
def uniform(scale: RealNumeric = 1e-2,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds an initializer that returns real uniformly-distributed random arrays.
|
||||
|
||||
Args:
|
||||
@ -115,15 +91,12 @@ def uniform(scale: RealNumeric = 1e-2,
|
||||
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
return random.uniform(key, shape, dtype) * scale
|
||||
return init
|
||||
|
||||
def normal(stddev: RealNumeric = 1e-2,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds an initializer that returns real normally-distributed random arrays.
|
||||
|
||||
Args:
|
||||
@ -140,18 +113,13 @@ def normal(stddev: RealNumeric = 1e-2,
|
||||
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
return random.normal(key, shape, dtype) * stddev
|
||||
return init
|
||||
|
||||
def _compute_fans(shape: core.NamedShape,
|
||||
in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Union[int, Sequence[int]] = ()
|
||||
) -> Tuple[Array, Array]:
|
||||
def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
|
||||
batch_axis=()):
|
||||
"""
|
||||
Compute effective input and output sizes for a linear or convolutional layer.
|
||||
|
||||
@ -175,9 +143,7 @@ def _compute_fans(shape: core.NamedShape,
|
||||
fan_out = out_size * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
def _complex_uniform(key: KeyArray,
|
||||
shape: core.NamedShape,
|
||||
dtype: DTypeLikeInexact) -> Array:
|
||||
def _complex_uniform(key, shape, dtype):
|
||||
"""
|
||||
Sample uniform random values within a disk on the complex plane,
|
||||
with zero mean and unit variance.
|
||||
@ -189,33 +155,24 @@ def _complex_uniform(key: KeyArray,
|
||||
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
|
||||
return r * jnp.exp(1j * theta)
|
||||
|
||||
def _complex_truncated_normal(key: KeyArray, upper: Array,
|
||||
shape: core.NamedShape,
|
||||
dtype: DTypeLikeInexact) -> Array:
|
||||
def _complex_truncated_normal(key, upper, shape, dtype):
|
||||
"""
|
||||
Sample random values from a centered normal distribution on the complex plane,
|
||||
whose modulus is truncated to `upper`, and the variance before the truncation
|
||||
is one.
|
||||
whose modulus is truncated to `upper`, and the variance before the truncation is one.
|
||||
"""
|
||||
key_r, key_theta = random.split(key)
|
||||
real_dtype = np.array(0, dtype).real.dtype
|
||||
dtype = dtypes._to_complex_dtype(real_dtype)
|
||||
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
|
||||
* random.uniform(key_r, shape, real_dtype).astype(dtype))
|
||||
t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)
|
||||
r = jnp.sqrt(-jnp.log(1 - t))
|
||||
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
|
||||
return r * jnp.exp(1j * theta)
|
||||
|
||||
def variance_scaling(
|
||||
scale: RealNumeric,
|
||||
mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]],
|
||||
distribution: Union[Literal["truncated_normal"], Literal["normal"],
|
||||
Literal["uniform"]],
|
||||
in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_
|
||||
) -> Initializer:
|
||||
def variance_scaling(scale, mode: str, distribution: str,
|
||||
in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
r"""
|
||||
Initializer that adapts its scale to the shape of the weights tensor.
|
||||
|
||||
@ -257,12 +214,10 @@ def variance_scaling(
|
||||
dtype: the dtype of the weights.
|
||||
"""
|
||||
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
named_shape = core.as_named_shape(shape)
|
||||
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
|
||||
shape = core.as_named_shape(shape)
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
|
||||
if mode == "fan_in": denominator = fan_in
|
||||
elif mode == "fan_out": denominator = fan_out
|
||||
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
|
||||
@ -275,18 +230,18 @@ def variance_scaling(
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
# constant is stddev of standard normal truncated to (-2, 2)
|
||||
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
|
||||
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
|
||||
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
|
||||
else:
|
||||
# constant is stddev of complex standard normal truncated to 2
|
||||
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
|
||||
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
|
||||
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
|
||||
elif distribution == "normal":
|
||||
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
|
||||
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
|
||||
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
|
||||
else:
|
||||
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
|
||||
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
|
||||
else:
|
||||
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")
|
||||
|
||||
@ -295,7 +250,7 @@ def variance_scaling(
|
||||
def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a Glorot uniform initializer (aka Xavier uniform initializer).
|
||||
|
||||
A `Glorot uniform initializer`_ is a specialization of
|
||||
@ -333,7 +288,7 @@ xavier_uniform = glorot_uniform
|
||||
def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a Glorot normal initializer (aka Xavier normal initializer).
|
||||
|
||||
A `Glorot normal initializer`_ is a specialization of
|
||||
@ -370,7 +325,7 @@ xavier_normal = glorot_normal
|
||||
def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a Lecun uniform initializer.
|
||||
|
||||
A `Lecun uniform initializer`_ is a specialization of
|
||||
@ -405,7 +360,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a Lecun normal initializer.
|
||||
|
||||
A `Lecun normal initializer`_ is a specialization of
|
||||
@ -441,7 +396,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a He uniform initializer (aka Kaiming uniform initializer).
|
||||
|
||||
A `He uniform initializer`_ is a specialization of
|
||||
@ -479,7 +434,7 @@ kaiming_uniform = he_uniform
|
||||
def he_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
out_axis: Union[int, Sequence[int]] = -1,
|
||||
batch_axis: Sequence[int] = (),
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
dtype: DType = jnp.float_) -> Callable:
|
||||
"""Builds a He normal initializer (aka Kaiming normal initializer).
|
||||
|
||||
A `He normal initializer`_ is a specialization of
|
||||
@ -514,9 +469,7 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
kaiming_normal = he_normal
|
||||
|
||||
|
||||
def orthogonal(scale: RealNumeric = 1.0,
|
||||
column_axis: int = -1,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
|
||||
"""
|
||||
Builds an initializer that returns uniformly distributed orthogonal matrices.
|
||||
|
||||
@ -539,9 +492,7 @@ def orthogonal(scale: RealNumeric = 1.0,
|
||||
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if len(shape) < 2:
|
||||
raise ValueError("orthogonal initializer requires at least a 2D shape")
|
||||
@ -558,10 +509,7 @@ def orthogonal(scale: RealNumeric = 1.0,
|
||||
return init
|
||||
|
||||
|
||||
def delta_orthogonal(
|
||||
scale: RealNumeric = 1.0,
|
||||
column_axis: int = -1,
|
||||
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
|
||||
def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
|
||||
"""
|
||||
Builds an initializer for delta orthogonal kernels.
|
||||
|
||||
@ -594,9 +542,7 @@ def delta_orthogonal(
|
||||
|
||||
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
def init(key, shape, dtype=dtype):
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if len(shape) not in [3, 4, 5]:
|
||||
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D "
|
||||
|
@ -19,7 +19,6 @@ used in Keras and Sonnet.
|
||||
|
||||
from jax._src.nn.initializers import (
|
||||
constant as constant,
|
||||
Initializer as Initializer,
|
||||
delta_orthogonal as delta_orthogonal,
|
||||
glorot_normal as glorot_normal,
|
||||
glorot_uniform as glorot_uniform,
|
||||
|
Loading…
x
Reference in New Issue
Block a user