Merge pull request #6108 from NeilGirdhar:annotate
PiperOrigin-RevId: 365063131
commit b8812b2a5d
@@ -15,11 +15,27 @@
# Helpers for indexed updates.

+import sys
+from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np

from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src import util


+Array = Any
+if sys.version_info >= (3, 10):
+  from typing import EllipsisType
+  SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
+else:
+  SingleIndex = Union[None, int, slice, Sequence[int], Array]
+Index = Union[SingleIndex, Tuple[SingleIndex, ...]]
+Scalar = Union[complex, float, int, np.number]
+Numeric = Union[Array, Scalar]


def _scatter_update(x, idx, y, scatter_op, indices_are_sorted,
                    unique_indices, normalize_indices=True):
  """Helper for indexed updates.
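The aliases above exist only for static type checking: Array is Any for now, and Index/Numeric spell out what NumPy-style indexers and update values may be. A minimal sketch of how downstream code might use them; the wrapper function here is hypothetical, not part of this change, and it uses the era-appropriate jax.ops.index / jax.ops.index_mul API:

    from typing import Any, Sequence, Tuple, Union
    import jax
    import jax.numpy as jnp

    Array = Any
    SingleIndex = Union[None, int, slice, Sequence[int], Array]
    Index = Union[SingleIndex, Tuple[SingleIndex, ...]]
    Numeric = Union[Array, complex, float, int]

    def scale_at(x: Array, idx: Index, factor: Numeric) -> Array:
      # Hypothetical typed wrapper; the annotations only aid static checkers
      # and change nothing at runtime.
      return jax.ops.index_mul(x, idx, factor)

    x = jnp.ones((3, 4))
    y = scale_at(x, jax.ops.index[0, :], 5.0)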
@@ -102,7 +118,11 @@ class _Indexable(object):
index = _Indexable()


-def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
+def index_add(x: Array,
+              idx: Index,
+              y: Numeric,
+              indices_are_sorted: bool = False,
+              unique_indices: bool = False) -> Array:
  """Pure equivalent of :code:`x[idx] += y`.

  Returns the value of `x` that would result from the
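For reference, the usual call pattern for this function at the time of the change, a minimal sketch with illustrative shapes and values:

    import jax
    import jax.numpy as jnp

    x = jnp.zeros((3, 4))
    # Pure, out-of-place equivalent of x[1, :] += 6.0; x itself is left unchanged.
    y = jax.ops.index_add(x, jax.ops.index[1, :], 6.0)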
@@ -145,7 +165,11 @@ def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
    x, idx, y, lax.scatter_add, indices_are_sorted, unique_indices)


-def index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False):
+def index_mul(x: Array,
+              idx: Index,
+              y: Numeric,
+              indices_are_sorted: bool = False,
+              unique_indices: bool = False) -> Array:
  """Pure equivalent of :code:`x[idx] *= y`.

  Returns the value of `x` that would result from the
@@ -188,7 +212,11 @@ def index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False):
                         indices_are_sorted, unique_indices)


-def index_min(x, idx, y, indices_are_sorted=False, unique_indices=False):
+def index_min(x: Array,
+              idx: Index,
+              y: Numeric,
+              indices_are_sorted: bool = False,
+              unique_indices: bool = False) -> Array:
  """Pure equivalent of :code:`x[idx] = minimum(x[idx], y)`.

  Returns the value of `x` that would result from the
@@ -228,7 +256,11 @@ def index_min(x, idx, y, indices_are_sorted=False, unique_indices=False):
  return _scatter_update(
    x, idx, y, lax.scatter_min, indices_are_sorted, unique_indices)


-def index_max(x, idx, y, indices_are_sorted=False, unique_indices=False):
+def index_max(x: Array,
+              idx: Index,
+              y: Numeric,
+              indices_are_sorted: bool = False,
+              unique_indices: bool = False) -> Array:
  """Pure equivalent of :code:`x[idx] = maximum(x[idx], y)`.

  Returns the value of `x` that would result from the
@@ -268,7 +300,11 @@ def index_max(x, idx, y, indices_are_sorted=False, unique_indices=False):
  return _scatter_update(
    x, idx, y, lax.scatter_max, indices_are_sorted, unique_indices)


-def index_update(x, idx, y, indices_are_sorted=False, unique_indices=False):
+def index_update(x: Array,
+                 idx: Index,
+                 y: Numeric,
+                 indices_are_sorted: bool = False,
+                 unique_indices: bool = False) -> Array:
  """Pure equivalent of :code:`x[idx] = y`.

  Returns the value of `x` that would result from the
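Again for reference, a minimal usage sketch; the point of the "pure" phrasing is that the input array is never modified in place:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((3, 3))
    y = jax.ops.index_update(x, jax.ops.index[0, 0], 7.0)
    # x still contains all ones; y differs only at position (0, 0).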
@@ -309,12 +345,13 @@ def index_update(x, idx, y, indices_are_sorted=False, unique_indices=False):
  return _scatter_update(
    x, idx, y, lax.scatter, indices_are_sorted, unique_indices)


-def segment_sum(data,
-                segment_ids,
-                num_segments=None,
-                indices_are_sorted=False,
-                unique_indices=False,
-                bucket_size=None):  # TODO(zhangqiaorjc): use non-None default.
+def segment_sum(data: Array,
+                segment_ids: Array,
+                num_segments: Optional[int] = None,
+                indices_are_sorted: bool = False,
+                unique_indices: bool = False,
+                # TODO(zhangqiaorjc): use non-None default for bucket_size.
+                bucket_size: Optional[int] = None) -> Array:
  """Computes the sum within segments of an array.

  Similar to TensorFlow's segment_sum:
@@ -14,7 +14,7 @@


from functools import partial
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union
import warnings

import numpy as np
@@ -36,6 +36,15 @@ from jax.interpreters import xla
from jax._src.util import prod


+Array = Any
+RealArray = Array
+IntegerArray = Array
+# TODO: Import or define these to match
+# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
+DTypeLikeInt = Any
+DTypeLikeFloat = Any


_UINT_DTYPES = {8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}
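The hunk numbering restarts here; the surrounding names (threefry_2x32, _UINT_DTYPES, jax._src.util.prod) indicate the second annotated file is JAX's random sampling module, most likely jax/_src/random.py at this point in history. As in the first file, the new names are placeholder aliases for static checking only: RealArray and IntegerArray are just Any for now, and DTypeLikeFloat/DTypeLikeInt stand in for a NumPy-style DTypeLike. A sketch of the intent, using a hypothetical annotated helper that is not part of the diff:

    from typing import Any, Sequence
    import jax
    import jax.numpy as jnp

    RealArray = Any          # scalar, NumPy array, or JAX array of real values
    DTypeLikeFloat = Any     # np.float32, jnp.float32, 'float32', ...

    def scaled_normal(key: jnp.ndarray,
                      scale: RealArray,
                      shape: Sequence[int] = (),
                      dtype: DTypeLikeFloat = jnp.float32) -> jnp.ndarray:
      # Hypothetical wrapper showing how the aliases read in a signature.
      return scale * jax.random.normal(key, tuple(shape), dtype)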
@@ -265,7 +274,7 @@ def _split(key, num) -> jnp.ndarray:
  return lax.reshape(threefry_2x32(key, counts), (num, 2))


-def fold_in(key, data):
+def fold_in(key: jnp.ndarray, data: int) -> jnp.ndarray:
  """Folds in data to a PRNG key to form a new PRNG key.

  Args:
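fold_in derives a new key deterministically from an existing key and an integer; a minimal sketch:

    import jax

    key = jax.random.PRNGKey(0)
    # Same (key, i) always yields the same derived key, without splitting.
    step_keys = [jax.random.fold_in(key, i) for i in range(3)]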
@@ -349,9 +358,9 @@ def _check_shape(name, shape: Union[Sequence[int], NamedShape], *param_shapes):


def uniform(key: jnp.ndarray,
            shape: Union[Sequence[int], NamedShape] = (),
-            dtype: np.dtype = dtypes.float_,
-            minval: Union[float, jnp.ndarray] = 0.,
-            maxval: Union[float, jnp.ndarray] = 1.) -> jnp.ndarray:
+            dtype: DTypeLikeFloat = dtypes.float_,
+            minval: RealArray = 0.,
+            maxval: RealArray = 1.) -> jnp.ndarray:
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
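Usage sketch for uniform with the loosened parameter types (values illustrative):

    import numpy as np
    import jax

    key = jax.random.PRNGKey(42)
    # minval/maxval may be Python scalars or arrays; dtype is anything dtype-like.
    x = jax.random.uniform(key, (3,), dtype=np.float32, minval=-1.0, maxval=1.0)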
@@ -407,9 +416,9 @@ def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:


def randint(key: jnp.ndarray,
            shape: Sequence[int],
-            minval: Union[int, jnp.ndarray],
-            maxval: Union[int, jnp.ndarray],
-            dtype: np.dtype = dtypes.int_):
+            minval: IntegerArray,
+            maxval: IntegerArray,
+            dtype: DTypeLikeInt = dtypes.int_):
  """Sample uniform random values in [minval, maxval) with given shape/dtype.

  Args:
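Usage sketch for randint:

    import jax

    key = jax.random.PRNGKey(0)
    # Five integers drawn uniformly from [0, 10).
    idx = jax.random.randint(key, (5,), 0, 10)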
@@ -488,7 +497,7 @@ def shuffle(key: jnp.ndarray, x: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
  return _shuffle(key, x, axis)  # type: ignore


-def permutation(key, x):
+def permutation(key: jnp.ndarray, x: IntegerArray) -> jnp.ndarray:
  """
  Permute elements of an array along its first axis or return a permuted range.

@@ -506,11 +515,12 @@ def permutation(key, x):
    # scalar case, must be a concrete integer
    if not np.issubdtype(lax.dtype(x), np.integer):
      raise TypeError("x must be an integer or at least 1-dimensional")
-    x = int(x)
+    x = int(x)  # type: ignore[assignment]
    return _shuffle(key, jnp.arange(x), 0)
  elif np.ndim(x) == 1:
    return _shuffle(key, x, 0)
  else:
+    assert isinstance(x, jnp.ndarray)
    ind = _shuffle(key, jnp.arange(x.shape[0]), 0)  # type: ignore[attribute-error]
    return x[ind]
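permutation accepts either a concrete integer (permuted range) or an array (shuffled along its first axis), as the branches above show. Sketch:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    p = jax.random.permutation(key, 5)                               # permuted arange(5)
    q = jax.random.permutation(key, jnp.arange(10.0).reshape(5, 2))  # rows shuffled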
@@ -543,7 +553,11 @@ def _shuffle(key, x, axis) -> jnp.ndarray:
  return x


-def choice(key, a, shape=(), replace=True, p=None):
+def choice(key: jnp.ndarray,
+           a: IntegerArray,
+           shape: Sequence[int] = (),
+           replace: bool = True,
+           p=None) -> jnp.ndarray:
  """Generates a random sample from a given 1-D array.

  Args:
@@ -572,7 +586,7 @@ def choice(key, a, shape=(), replace=True, p=None):
    a = int(a)
  else:
    a = _asarray(a)
-  n_inputs = a if np.ndim(a) == 0 else len(a)
+  n_inputs = int(a) if np.ndim(a) == 0 else len(a)  # type: ignore[arg-type]
  n_draws = prod(shape)
  if n_draws == 0:
    return jnp.zeros(shape, dtype=lax.dtype(a))
@@ -584,7 +598,7 @@ def choice(key, a, shape=(), replace=True, p=None):
  if p is None:
    if replace:
      ind = randint(key, shape, 0, n_inputs)
-      result = ind if np.ndim(a) == 0 else a[ind]
+      result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]
    else:
      result = permutation(key, a)[:n_draws]
  else:
@@ -594,18 +608,18 @@ def choice(key, a, shape=(), replace=True, p=None):
      p_cuml = jnp.cumsum(p)
      r = p_cuml[-1] * (1 - uniform(key, shape))
      ind = jnp.searchsorted(p_cuml, r)
-      result = ind if np.ndim(a) == 0 else a[ind]
+      result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]
    else:
      # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
      g = -gumbel(key, (n_inputs,)) - jnp.log(p)
      ind = jnp.argsort(g)[:n_draws]
-      result = ind if np.ndim(a) == 0 else a[ind]
+      result = ind if np.ndim(a) == 0 else a[ind]  # type: ignore[index]

  return result.reshape(shape)


def normal(key: jnp.ndarray,
           shape: Union[Sequence[int], NamedShape] = (),
-           dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
+           dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample standard normal random values with given shape and float dtype.

  Args:
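The weighted without-replacement branch above uses the Gumbel top-k trick referenced in the comment. A usage sketch of choice covering both paths:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    a = jnp.arange(10)
    p = jnp.ones(10) / 10.0
    with_repl = jax.random.choice(key, a, shape=(4,))
    without_repl = jax.random.choice(key, a, shape=(4,), replace=False, p=p)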
@@ -648,10 +662,10 @@ def _normal_real(key, shape, dtype) -> jnp.ndarray:


def multivariate_normal(key: jnp.ndarray,
-                        mean: jnp.ndarray,
-                        cov: jnp.ndarray,
+                        mean: RealArray,
+                        cov: RealArray,
                        shape: Optional[Sequence[int]] = None,
-                        dtype: np.dtype = dtypes.float_,
+                        dtype: DTypeLikeFloat = dtypes.float_,
                        method: str = 'cholesky') -> jnp.ndarray:
  """Sample multivariate normal random values with given mean and covariance.
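Usage sketch for multivariate_normal (mean and covariance are illustrative):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    mean = jnp.zeros(2)
    cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
    # Four draws from N(mean, cov); result shape (4, 2).
    samples = jax.random.multivariate_normal(key, mean, cov, shape=(4,))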
@@ -716,10 +730,10 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> jnp.ndarray:


def truncated_normal(key: jnp.ndarray,
-                     lower: Union[float, jnp.ndarray],
-                     upper: Union[float, jnp.ndarray],
+                     lower: RealArray,
+                     upper: RealArray,
                     shape: Optional[Union[Sequence[int], NamedShape]] = None,
-                     dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
+                     dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample truncated standard normal random values with given shape and dtype.

  Args:
@@ -773,7 +787,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> jnp.ndarray:


def bernoulli(key: jnp.ndarray,
-              p: jnp.ndarray = np.float32(0.5),  # type: ignore[assignment]
+              p: RealArray = np.float32(0.5),
              shape: Optional[Union[Sequence[int], NamedShape]] = None) -> jnp.ndarray:
  """Sample Bernoulli random values with given shape and mean.

@@ -810,10 +824,10 @@ def _bernoulli(key, p, shape) -> jnp.ndarray:


def beta(key: jnp.ndarray,
-         a: Union[float, jnp.ndarray],
-         b: Union[float, jnp.ndarray],
+         a: RealArray,
+         b: RealArray,
         shape: Optional[Sequence[int]] = None,
-         dtype: np.dtype = dtypes.float_) -> jnp.ndarray:
+         dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Beta random values with given shape and float dtype.

  Args:
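Usage sketch for bernoulli and beta:

    import jax

    key = jax.random.PRNGKey(0)
    flips = jax.random.bernoulli(key, p=0.3, shape=(8,))    # boolean draws
    draws = jax.random.beta(key, a=2.0, b=5.0, shape=(8,))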
@@ -856,7 +870,9 @@ def _beta(key, a, b, shape, dtype):
  return gamma_a / (gamma_a + gamma_b)


-def cauchy(key, shape=(), dtype=dtypes.float_):
+def cauchy(key: jnp.ndarray,
+           shape: Sequence[int] = (),
+           dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Cauchy random values with given shape and float dtype.

  Args:
@@ -884,7 +900,10 @@ def _cauchy(key, shape, dtype):
  return lax.tan(lax.mul(pi, lax.sub(u, _constant_like(u, 0.5))))


-def dirichlet(key, alpha, shape=None, dtype=dtypes.float_):
+def dirichlet(key: jnp.ndarray,
+              alpha: RealArray,
+              shape: Optional[Sequence[int]] = None,
+              dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Dirichlet random values with given shape and float dtype.

  Args:
@@ -928,7 +947,9 @@ def _dirichlet(key, alpha, shape, dtype):
  return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)


-def exponential(key, shape=(), dtype=dtypes.float_):
+def exponential(key: jnp.ndarray,
+                shape: Sequence[int] = (),
+                dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Exponential random values with given shape and float dtype.

  Args:
@@ -1056,7 +1077,10 @@ xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
    multiple_results=False)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

-def gamma(key, a, shape=None, dtype=dtypes.float_):
+def gamma(key: jnp.ndarray,
+          a: RealArray,
+          shape: Optional[Sequence[int]] = None,
+          dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Gamma random values with given shape and float dtype.

  Args:
@@ -1181,7 +1205,10 @@ def _poisson(key, lam, shape, dtype):
  return lax.select(lam == 0, jnp.zeros_like(result), result)


-def poisson(key, lam, shape=(), dtype=dtypes.int_):
+def poisson(key: jnp.ndarray,
+            lam: RealArray,
+            shape: Sequence[int] = (),
+            dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
  """Sample Poisson random values with given shape and integer dtype.

  Args:
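Usage sketch for the gamma and poisson samplers:

    import jax

    key = jax.random.PRNGKey(0)
    g = jax.random.gamma(key, a=3.0, shape=(4,))
    counts = jax.random.poisson(key, lam=2.5, shape=(4,))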
@@ -1203,7 +1230,9 @@ def poisson(key, lam, shape=(), dtype=dtypes.int_):
  return _poisson(key, lam, shape, dtype)


-def gumbel(key, shape=(), dtype=dtypes.float_):
+def gumbel(key: jnp.ndarray,
+           shape: Sequence[int] = (),
+           dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Gumbel random values with given shape and float dtype.

  Args:
@@ -1230,7 +1259,10 @@ def _gumbel(key, shape, dtype):
      uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.)))


-def categorical(key, logits, axis=-1, shape=None):
+def categorical(key: jnp.ndarray,
+                logits: jnp.ndarray,
+                axis: int = -1,
+                shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
  """Sample random values from categorical distributions.

  Args:
@@ -1254,13 +1286,16 @@ def categorical(key, logits, axis=-1, shape=None):
  if shape is None:
    shape = batch_shape
  else:
+    shape = tuple(shape)
    _check_shape("categorical", shape, batch_shape)

  sample_shape = shape[:len(shape)-len(batch_shape)]
  return jnp.argmax(gumbel(key, sample_shape + logits.shape, logits.dtype) + logits, axis=axis)


-def laplace(key, shape=(), dtype=dtypes.float_):
+def laplace(key: jnp.ndarray,
+            shape: Sequence[int] = (),
+            dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Laplace random values with given shape and float dtype.

  Args:
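As the body above shows, categorical draws class indices by adding Gumbel noise to the logits and taking an argmax (the Gumbel-max trick). Usage sketch:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))
    # Ten indices in {0, 1, 2}, distributed according to softmax(logits).
    draws = jax.random.categorical(key, logits, shape=(10,))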
@@ -1288,7 +1323,9 @@ def _laplace(key, shape, dtype):
  return lax.mul(lax.sign(u), lax.log1p(lax.neg(lax.abs(u))))


-def logistic(key, shape=(), dtype=dtypes.float_):
+def logistic(key: jnp.ndarray,
+             shape: Sequence[int] = (),
+             dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample logistic random values with given shape and float dtype.

  Args:
@@ -1315,12 +1352,15 @@ def _logistic(key, shape, dtype):
  return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))


-def pareto(key, b, shape=None, dtype=dtypes.float_):
+def pareto(key: jnp.ndarray,
+           b: RealArray,
+           shape: Optional[Sequence[int]] = None,
+           dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Pareto random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
-    a: a float or array of floats broadcast-compatible with ``shape``
+    b: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``b``. The default (None)
@@ -1352,7 +1392,10 @@ def _pareto(key, b, shape, dtype):
  return lax.exp(e / b)


-def t(key, df, shape=(), dtype=dtypes.float_):
+def t(key: jnp.ndarray,
+      df: RealArray,
+      shape: Sequence[int] = (),
+      dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample Student's t random values with given shape and float dtype.

  Args:
@@ -1392,7 +1435,9 @@ def _t(key, df, shape, dtype):
  return n * jnp.sqrt(half_df / g)


-def rademacher(key, shape, dtype=dtypes.int_):
+def rademacher(key: jnp.ndarray,
+               shape: Sequence[int],
+               dtype: DTypeLikeInt = dtypes.int_) -> jnp.ndarray:
  """Sample from a Rademacher distribution.

  Args:
@@ -1416,7 +1461,9 @@ def _rademacher(key, shape, dtype):
  return (2 * bernoulli_samples - 1).astype(dtype)


-def maxwell(key, shape=(), dtype=dtypes.float_):
+def maxwell(key: jnp.ndarray,
+            shape: Sequence[int] = (),
+            dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample from a one sided Maxwell distribution.

  The scipy counterpart is `scipy.stats.maxwell`.
@@ -1447,7 +1494,11 @@ def _maxwell(key, shape, dtype):
  return jnp.linalg.norm(norm_rvs, axis=-1)


-def double_sided_maxwell(key, loc, scale, shape=(), dtype=dtypes.float_):
+def double_sided_maxwell(key: jnp.ndarray,
+                         loc: RealArray,
+                         scale: RealArray,
+                         shape: Sequence[int] = (),
+                         dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample from a double sided Maxwell distribution.

  Samples using:
@@ -1488,7 +1539,11 @@ def _double_sided_maxwell(key, loc, scale, shape, dtype):
  return random_sign * maxwell_rvs * scale + loc


-def weibull_min(key, scale, concentration, shape=(), dtype=dtypes.float_):
+def weibull_min(key: jnp.ndarray,
+                scale: RealArray,
+                concentration: RealArray,
+                shape: Sequence[int] = (),
+                dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
  """Sample from a Weibull distribution.

  The scipy counterpart is `scipy.stats.weibull_min`.
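Usage sketch covering the remaining samplers annotated above (parameter names follow the signatures in the diff):

    import jax

    key = jax.random.PRNGKey(0)
    signs = jax.random.rademacher(key, (6,))   # values in {-1, +1}
    speeds = jax.random.maxwell(key, (6,))
    noise = jax.random.double_sided_maxwell(key, loc=0.0, scale=1.0, shape=(6,))
    lifetimes = jax.random.weibull_min(key, scale=1.0, concentration=2.0, shape=(6,))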