Merge pull request #6108 from NeilGirdhar:annotate

PiperOrigin-RevId: 365063131
This commit is contained in:
jax authors 2021-03-25 10:03:31 -07:00
commit b8812b2a5d
2 changed files with 145 additions and 53 deletions

View File

@ -15,11 +15,27 @@
# Helpers for indexed updates.
import sys
from typing import Any, Optional, Sequence, Tuple, Union
import numpy as np
from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src import util
# Type aliases used by the annotations in this module.

# Stand-in for array-like values. It is `Any`, so it documents intent in
# signatures without constraining what callers may pass.
Array = Any

# `EllipsisType` was added in Python 3.10 and lives in the `types` module, not
# in `typing`. Importing it from `typing` raises ImportError on precisely the
# interpreter versions this branch is selected for.
if sys.version_info >= (3, 10):
  from types import EllipsisType
  SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
else:
  # Older interpreters expose no public name for type(Ellipsis), so it is
  # simply left out of the union here.
  SingleIndex = Union[None, int, slice, Sequence[int], Array]

# A full index is either one component or a tuple of components.
Index = Union[SingleIndex, Tuple[SingleIndex, ...]]

# Python and NumPy scalar number types.
Scalar = Union[complex, float, int, np.number]

# Either an array-like value or a scalar.
Numeric = Union[Array, Scalar]
def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
unique_indices, normalize_indices=True):
"""Helper for indexed updates.
@ -102,7 +118,11 @@ class _Indexable(object):
index = _Indexable()
def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
def index_add(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] += y`.
Returns the value of `x` that would result from the
@ -145,7 +165,11 @@ def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)
def index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False):
def index_mul(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] *= y`.
Returns the value of `x` that would result from the
@ -188,7 +212,11 @@ def index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False):
indices_are_sorted, unique_indices)
def index_min(x, idx, y, indices_are_sorted=False, unique_indices=False):
def index_min(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.
Returns the value of `x` that would result from the
@ -228,7 +256,11 @@ def index_min(x, idx, y, indices_are_sorted=False, unique_indices=False):
return _scatter_update(
x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)
def index_max(x, idx, y, indices_are_sorted=False, unique_indices=False):
def index_max(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`.
Returns the value of `x` that would result from the
@ -268,7 +300,11 @@ def index_max(x, idx, y, indices_are_sorted=False, unique_indices=False):
return _scatter_update(
x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)
def index_update(x, idx, y, indices_are_sorted=False, unique_indices=False):
def index_update(x: Array,
idx: Index,
y: Numeric,
indices_are_sorted: bool = False,
unique_indices: bool = False) -> Array:
"""Pure equivalent of :code:`x[idx] = y`.
Returns the value of `x` that would result from the
@ -309,12 +345,13 @@ def index_update(x, idx, y, indices_are_sorted=False, unique_indices=False):
return _scatter_update(
x, idx, y, lax.scatter, indices_are_sorted, unique_indices)
def segment_sum(data,
segment_ids,
num_segments=None,
indices_are_sorted=False,
unique_indices=False,
bucket_size=None): # TODO(zhangqiaorjc): use non-None default.
def segment_sum(data: Array,
segment_ids: Array,
num_segments: Optional[int] = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
# TODO(zhangqiaorjc): use non-None default for bucket_size.
bucket_size: Optional[int] = None) -> Array:
"""Computes the sum within segments of an array.
Similar to TensorFlow's segment_sum:

View File

@ -14,7 +14,7 @@
from functools import partial
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union
import warnings
import numpy as np
@ -36,6 +36,15 @@ from jax.interpreters import xla
from jax._src.util import prod
# Loose annotation aliases. Every name below currently resolves to `Any`, so
# they record intent in signatures without constraining callers.
Array = Any
# Alias of Array; the name marks parameters as real-valued, but nothing here
# enforces it.
RealArray = Array
# Alias of Array; the name marks parameters as integer-valued, but nothing
# here enforces it.
IntegerArray = Array
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
# Until then both dtype aliases are plain `Any`.
DTypeLikeInt = Any
DTypeLikeFloat = Any
_UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
@ -265,7 +274,7 @@ def _split(key, num) -> jnp.ndarray:
return lax.reshape(threefry_2x32(key, counts), (num, 2))
def fold_in(key, data):
def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
"""Folds in data to a PRNG key to form a new PRNG key.
Args:
@ -349,9 +358,9 @@ def _check_shape(name, shape: Union[Sequence[int], NamedShape], *param_shapes):
def uniform(key: jnp.ndarray,
shape: Union[Sequence[int], NamedShape] = (),
dtype: np.dtype = dtypes.float_,
minval: Union[float, jnp.ndarray] = 0.,
maxval: Union[float, jnp.ndarray] = 1.) -> jnp.ndarray:
dtype: DTypeLikeFloat = dtypes.float_,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> jnp.ndarray:
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
@ -407,9 +416,9 @@ def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
def randint(key: jnp.ndarray,
shape: Sequence[int],
minval: Union[int, jnp.ndarray],
maxval: Union[int, jnp.ndarray],
dtype: np.dtype = dtypes.int_):
minval: IntegerArray,
maxval: IntegerArray,
dtype: DTypeLikeInt = dtypes.int_):
"""Sample uniform random values in [minval, maxval) with given shape/dtype.
Args:
@ -488,7 +497,7 @@ def shuffle(key: jnp.ndarray, x: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
return _shuffle(key, x, axis) # type: ignore
def permutation(key, x):
def permutation(key: jnp.ndarray, x: IntegerArray) -> jnp.ndarray:
"""
Permute elements of an array along its first axis or return a permuted range.
@ -506,11 +515,12 @@ def permutation(key, x):
# scalar case, must be a concrete integer
if not np.issubdtype(lax.dtype(x), np.integer):
raise TypeError("x must be an integer or at least 1-dimensional")
x = int(x)
x = int(x) # type: ignore[assignment]
return _shuffle(key, jnp.arange(x), 0)
elif np.ndim(x) == 1:
return _shuffle(key, x, 0)
else:
assert isinstance(x, jnp.ndarray)
ind = _shuffle(key, jnp.arange(x.shape[0]), 0) # type: ignore[attribute-error]
return x[ind]
@ -543,7 +553,11 @@ def _shuffle(key, x, axis) -> jnp.ndarray:
return x
def choice(key, a, shape=(), replace=True, p=None):
def choice(key: jnp.ndarray,
a: IntegerArray,
shape: Sequence[int] = (),
replace: bool = True,
p=None) -> jnp.ndarray:
"""Generates a random sample from a given 1-D array.
Args:
@ -572,7 +586,7 @@ def choice(key, a, shape=(), replace=True, p=None):
a = int(a)
else:
a = _asarray(a)
n_inputs = a if np.ndim(a) == 0 else len(a)
n_inputs = int(a) if np.ndim(a) == 0 else len(a) # type: ignore[arg-type]
n_draws = prod(shape)
if n_draws == 0:
return jnp.zeros(shape, dtype=lax.dtype(a))
@ -584,7 +598,7 @@ def choice(key, a, shape=(), replace=True, p=None):
if p is None:
if replace:
ind = randint(key, shape, 0, n_inputs)
result = ind if np.ndim(a) == 0 else a[ind]
result = ind if np.ndim(a) == 0 else a[ind] # type: ignore[index]
else:
result = permutation(key, a)[:n_draws]
else:
@ -594,18 +608,18 @@ def choice(key, a, shape=(), replace=True, p=None):
p_cuml = jnp.cumsum(p)
r = p_cuml[-1] * (1 - uniform(key, shape))
ind = jnp.searchsorted(p_cuml, r)
result = ind if np.ndim(a) == 0 else a[ind]
result = ind if np.ndim(a) == 0 else a[ind] # type: ignore[index]
else:
# Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
g = -gumbel(key, (n_inputs,)) - jnp.log(p)
ind = jnp.argsort(g)[:n_draws]
result = ind if np.ndim(a) == 0 else a[ind]
result = ind if np.ndim(a) == 0 else a[ind] # type: ignore[index]
return result.reshape(shape)
def normal(key: jnp.ndarray,
shape: Union[Sequence[int], NamedShape] = (),
dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample standard normal random values with given shape and float dtype.
Args:
@ -648,10 +662,10 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:
def multivariate_normal(key: jnp.ndarray,
mean: jnp.ndarray,
cov: jnp.ndarray,
mean: RealArray,
cov: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: np.dtype = dtypes.float_,
dtype: DTypeLikeFloat = dtypes.float_,
method: str = 'cholesky') -> jnp.ndarray:
"""Sample multivariate normal random values with given mean and covariance.
@ -716,10 +730,10 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> jnp.ndarray:
def truncated_normal(key: jnp.ndarray,
lower: Union[float, jnp.ndarray],
upper: Union[float, jnp.ndarray],
lower: RealArray,
upper: RealArray,
shape: Optional[Union[Sequence[int], NamedShape]] = None,
dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample truncated standard normal random values with given shape and dtype.
Args:
@ -773,7 +787,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:
def bernoulli(key: jnp.ndarray,
p: jnp.ndarray = np.float32(0.5), # type: ignore[assignment]
p: RealArray = np.float32(0.5),
shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray:
"""Sample Bernoulli random values with given shape and mean.
@ -810,10 +824,10 @@ def _bernoulli(key, p, shape) -> jnp.ndarray:
def beta(key: jnp.ndarray,
a: Union[float, jnp.ndarray],
b: Union[float, jnp.ndarray],
a: RealArray,
b: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Beta random values with given shape and float dtype.
Args:
@ -856,7 +870,9 @@ def _beta(key, a, b, shape, dtype):
return gamma_a / (gamma_a + gamma_b)
def cauchy(key, shape=(), dtype=dtypes.float_):
def cauchy(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Cauchy random values with given shape and float dtype.
Args:
@ -884,7 +900,10 @@ def _cauchy(key, shape, dtype):
return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))
def dirichlet(key, alpha, shape=None, dtype=dtypes.float_):
def dirichlet(key: jnp.ndarray,
alpha: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Dirichlet random values with given shape and float dtype.
Args:
@ -928,7 +947,9 @@ def _dirichlet(key, alpha, shape, dtype):
return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)
def exponential(key, shape=(), dtype=dtypes.float_):
def exponential(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Exponential random values with given shape and float dtype.
Args:
@ -1056,7 +1077,10 @@ xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
multiple_results=False)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
def gamma(key, a, shape=None, dtype=dtypes.float_):
def gamma(key: jnp.ndarray,
a: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Gamma random values with given shape and float dtype.
Args:
@ -1181,7 +1205,10 @@ def _poisson(key, lam, shape, dtype):
return lax.select(lam == 0, jnp.zeros_like(result), result)
def poisson(key, lam, shape=(), dtype=dtypes.int_):
def poisson(key: jnp.ndarray,
lam: RealArray,
shape: Sequence[int] = (),
dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
"""Sample Poisson random values with given shape and integer dtype.
Args:
@ -1203,7 +1230,9 @@ def poisson(key, lam, shape=(), dtype=dtypes.int_):
return _poisson(key, lam, shape, dtype)
def gumbel(key, shape=(), dtype=dtypes.float_):
def gumbel(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Gumbel random values with given shape and float dtype.
Args:
@ -1230,7 +1259,10 @@ def _gumbel(key, shape, dtype):
uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))
def categorical(key, logits, axis=-1, shape=None):
def categorical(key: jnp.ndarray,
logits: jnp.ndarray,
axis: int = -1,
shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
"""Sample random values from categorical distributions.
Args:
@ -1254,13 +1286,16 @@ def categorical(key, logits, axis=-1, shape=None):
if shape is None:
shape = batch_shape
else:
shape = tuple(shape)
_check_shape("categorical", shape, batch_shape)
sample_shape = shape[:len(shape)-len(batch_shape)]
return jnp.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)
def laplace(key, shape=(), dtype=dtypes.float_):
def laplace(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Laplace random values with given shape and float dtype.
Args:
@ -1288,7 +1323,9 @@ def _laplace(key, shape, dtype):
return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))
def logistic(key, shape=(), dtype=dtypes.float_):
def logistic(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample logistic random values with given shape and float dtype.
Args:
@ -1315,12 +1352,15 @@ def _logistic(key, shape, dtype):
return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
def pareto(key, b, shape=None, dtype=dtypes.float_):
def pareto(key: jnp.ndarray,
b: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Pareto random values with given shape and float dtype.
Args:
key: a PRNGKey used as the random key.
a: a float or array of floats broadcast-compatible with ``shape``
b: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``b``. The default (None)
@ -1352,7 +1392,10 @@ def _pareto(key, b, shape, dtype):
return lax.exp(e / b)
def t(key, df, shape=(), dtype=dtypes.float_):
def t(key: jnp.ndarray,
df: RealArray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample Student's t random values with given shape and float dtype.
Args:
@ -1392,7 +1435,9 @@ def _t(key, df, shape, dtype):
return n * jnp.sqrt(half_df / g)
def rademacher(key, shape, dtype=dtypes.int_):
def rademacher(key: jnp.ndarray,
shape: Sequence[int],
dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
"""Sample from a Rademacher distribution.
Args:
@ -1416,7 +1461,9 @@ def _rademacher(key, shape, dtype):
return (2 * bernoulli_samples - 1).astype(dtype)
def maxwell(key, shape=(), dtype=dtypes.float_):
def maxwell(key: jnp.ndarray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample from a one sided Maxwell distribution.
The scipy counterpart is `scipy.stats.maxwell`.
@ -1447,7 +1494,11 @@ def _maxwell(key, shape, dtype):
return jnp.linalg.norm(norm_rvs, axis=-1)
def double_sided_maxwell(key, loc, scale, shape=(), dtype=dtypes.float_):
def double_sided_maxwell(key: jnp.ndarray,
loc: RealArray,
scale: RealArray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample from a double sided Maxwell distribution.
Samples using:
@ -1488,7 +1539,11 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype):
return random_sign * maxwell_rvs * scale + loc
def weibull_min(key, scale, concentration, shape=(), dtype=dtypes.float_):
def weibull_min(key: jnp.ndarray,
scale: RealArray,
concentration: RealArray,
shape: Sequence[int] = (),
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample from a Weibull distribution.
The scipy counterpart is `scipy.stats.weibull_min`.