Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Documentation improvements.
parent 6af073e4f4
commit c0a7bcb6fa
128 jax/lax.py
@@ -448,14 +448,60 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
                           dimension_numbers=dimension_numbers, lhs_shape=lhs.shape,
                           rhs_shape=rhs.shape)
 
-def dot(lhs, rhs): return dot_p.bind(lhs, rhs)
+def dot(lhs, rhs):
+  """Vector/vector, matrix/vector, and matrix/matrix multiplication.
+
+  Wraps XLA's `Dot
+  <https://www.tensorflow.org/xla/operation_semantics#dot>`_
+  operator.
+
+  For more general contraction, see the `dot_general` operator.
+
+  Args:
+    lhs: an array of rank 1 or 2.
+    rhs: an array of rank 1 or 2.
+
+  Returns:
+    An array containing the product.
+  """
+  return dot_p.bind(lhs, rhs)
 
 def dot_general(lhs, rhs, dimension_numbers):
+  """More general contraction operator.
+
+  Wraps XLA's `DotGeneral
+  <https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
+  operator.
+
+  Args:
+    lhs: an array
+    rhs: an array
+    dimension_numbers: a tuple of tuples of the form
+      `((lhs_contracting_dims, rhs_contracting_dims),
+      (lhs_batch_dims, rhs_batch_dims))`
+
+  Returns:
+    An array containing the result.
+  """
   lhs_dims, rhs_dims = dimension_numbers
   dimension_numbers = (tuple(map(tuple, lhs_dims)), tuple(map(tuple, rhs_dims)))
   return dot_general_p.bind(lhs, rhs, dimension_numbers=dimension_numbers)
 
 def broadcast(operand, sizes):
+  """Broadcasts an array, adding new major dimensions.
+
+  Wraps XLA's `Broadcast
+  <https://www.tensorflow.org/xla/operation_semantics#broadcast>`_
+  operator.
+
+  Args:
+    operand: an array
+    sizes: a sequence of integers, giving the sizes of new major dimensions
+      to add.
+
+  Returns:
+    An array containing the result.
+  """
   return broadcast_p.bind(operand, sizes=tuple(sizes))
 
 def broadcast_in_dim(operand, shape, broadcast_dimensions):
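A quick sketch of the newly documented functions (illustrative only, not part of this commit; it assumes these `lax` functions are exposed on the public `jax` namespaces as usual):

    import numpy as onp
    import jax.numpy as jnp
    from jax import lax

    A = jnp.ones((2, 3))
    B = jnp.ones((3, 4))

    # dot: straightforward matrix/matrix multiplication.
    C = lax.dot(A, B)                    # shape (2, 4)

    # The same contraction via dot_general: contract A's dim 1 with B's
    # dim 0, with no batch dimensions.
    dims = (((1,), (0,)), ((), ()))
    C2 = lax.dot_general(A, B, dims)     # shape (2, 4)
    assert onp.allclose(C, C2)

    # broadcast adds new major (leading) dimensions.
    y = lax.broadcast(jnp.arange(3.), (4, 2))   # shape (4, 2, 3)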
@@ -467,6 +513,10 @@ def broadcast_in_dim(operand, shape, broadcast_dimensions):
                               broadcast_dimensions=tuple(broadcast_dimensions))
 
 def reshape(operand, new_sizes, dimensions=None):
+  """Wraps XLA's `Reshape
+  <https://www.tensorflow.org/xla/operation_semantics#reshape>`_
+  operator.
+  """
   same_shape = onp.shape(operand) == tuple(new_sizes)
   same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
   if same_shape and same_dims:
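For instance (illustrative, not part of the diff):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6.)
    a = lax.reshape(x, (2, 3))          # plain reshape
    # With `dimensions`, the operand is first transposed into that order;
    # an identity permutation like (0,) makes this equivalent to `a`.
    b = lax.reshape(x, (2, 3), dimensions=(0,))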
@@ -478,27 +528,51 @@ def reshape(operand, new_sizes, dimensions=None):
                         old_sizes=onp.shape(operand))
 
 def pad(operand, padding_value, padding_config):
+  """Wraps XLA's `Pad
+  <https://www.tensorflow.org/xla/operation_semantics#pad>`_
+  operator.
+  """
   return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
 
 def rev(operand, dimensions):
+  """Wraps XLA's `Rev
+  <https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
+  operator.
+  """
   return rev_p.bind(operand, dimensions=tuple(dimensions))
 
 def select(pred, on_true, on_false):
+  """Wraps XLA's `Select
+  <https://www.tensorflow.org/xla/operation_semantics#select>`_
+  operator.
+  """
   return select_p.bind(pred, on_true, on_false)
 
 def slice(operand, start_indices, limit_indices, strides=None):
+  """Wraps XLA's `Slice
+  <https://www.tensorflow.org/xla/operation_semantics#slice>`_
+  operator.
+  """
   return slice_p.bind(operand, start_indices=tuple(start_indices),
                       limit_indices=tuple(limit_indices),
                       strides=None if strides is None else tuple(strides),
                       operand_shape=operand.shape)
 
 def dynamic_slice(operand, start_indices, slice_sizes):
+  """Wraps XLA's `DynamicSlice
+  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
+  operator.
+  """
   start_indices = _dynamic_slice_indices(operand, start_indices)
   return dynamic_slice_p.bind(
       operand, start_indices, slice_sizes=tuple(slice_sizes),
       operand_shape=operand.shape)
 
 def dynamic_update_slice(operand, update, start_indices):
+  """Wraps XLA's `DynamicUpdateSlice
+  <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
+  operator.
+  """
   start_indices = _dynamic_slice_indices(operand, start_indices)
   return dynamic_update_slice_p.bind(operand, update, start_indices,
                                      update_shape=update.shape)
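A sketch of the slicing and padding operators documented above (illustrative, not part of the commit):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(12.).reshape(3, 4)

    # Static slice: per-dimension start and (exclusive) limit indices.
    s = lax.slice(x, start_indices=(0, 1), limit_indices=(2, 4))  # shape (2, 3)

    # Dynamic slice: start indices may be traced values; sizes are static.
    d = lax.dynamic_slice(x, (1, 0), slice_sizes=(2, 2))

    # Write a block back at a (possibly traced) offset.
    u = lax.dynamic_update_slice(x, jnp.zeros((2, 2)), (0, 0))

    # pad takes a (low, high, interior) config per dimension.
    p = lax.pad(x, 0., ((1, 1, 0), (0, 0, 0)))                    # shape (5, 4)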
@@ -570,6 +644,10 @@ def index_take(src, idxs, axes):
   return gather(src, indices, dimension_numbers=dnums, slice_sizes=slice_sizes)
 
 def transpose(operand, permutation):
+  """Wraps XLA's `Transpose
+  <https://www.tensorflow.org/xla/operation_semantics#transpose>`_
+  operator.
+  """
   permutation = tuple(permutation)
   if permutation == tuple(range(len(permutation))):
     return operand
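For example (illustrative):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((2, 3, 4))
    y = lax.transpose(x, (2, 0, 1))   # shape (4, 2, 3)
    # An identity permutation short-circuits and returns the operand,
    # per the early return in the hunk above.
    z = lax.transpose(x, (0, 1, 2))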
@@ -577,6 +655,10 @@ def transpose(operand, permutation):
   return transpose_p.bind(operand, permutation=permutation)
 
 def reduce(operand, init_value, computation, dimensions):
+  """Wraps XLA's `Reduce
+  <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
+  operator.
+  """
   monoid_reducer = _get_monoid_reducer(computation, init_value)
   if monoid_reducer:
     return monoid_reducer(operand, dimensions)
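For example, a sum over one axis (illustrative; `lax.add` with identity 0 is one of the monoids recognized by `_get_monoid_reducer`):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(6.).reshape(2, 3)
    col_sums = lax.reduce(x, 0., lax.add, dimensions=(0,))        # shape (3,)
    col_maxes = lax.reduce(x, -jnp.inf, lax.max, dimensions=(0,))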
@@ -637,6 +719,10 @@ def _reduce_and(operand, axes):
 
 def reduce_window(operand, init_value, computation, window_dimensions,
                   window_strides, padding):
+  """Wraps XLA's `ReduceWindow
+  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
+  operator.
+  """
   monoid_reducer = _get_monoid_window_reducer(computation, init_value)
   if monoid_reducer:
     return monoid_reducer(operand, window_dimensions, window_strides, padding)
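For example, 2x2 sum-pooling (illustrative, not part of the commit):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(9.).reshape(3, 3)
    pooled = lax.reduce_window(x, 0., lax.add,
                               window_dimensions=(2, 2),
                               window_strides=(1, 1),
                               padding='VALID')   # shape (2, 2)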
@@ -706,6 +792,10 @@ def _select_and_gather_add(tangents, operand, select_prim, window_dimensions,
                              window_strides=tuple(window_strides), padding=padding)
 
 def sort(operand, dimension=-1):
+  """Wraps XLA's `Sort
+  <https://www.tensorflow.org/xla/operation_semantics#sort>`_
+  operator.
+  """
   return sort_p.bind(operand, dimension=dimension)
 
 def sort_key_val(keys, values, dimension=-1):
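For example (illustrative):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([3., 1., 2.])
    s = lax.sort(x)                            # [1., 2., 3.]
    # sort_key_val sorts keys and applies the same permutation to values.
    keys, vals = lax.sort_key_val(x, jnp.array([30., 10., 20.]))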
@@ -792,23 +882,31 @@ def full(shape, fill_value, dtype=None):
   dtype = xla_bridge.canonicalize_dtype(dtype)
 
   # For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
-  # create a FilledConstant value, otherwise just call broadcast.
+  # create a _FilledConstant value, otherwise just call broadcast.
   if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
-    return FilledConstant(onp.asarray(fill_value, dtype), shape)
+    return _FilledConstant(onp.asarray(fill_value, dtype), shape)
   elif isinstance(fill_value, xla.DeviceValue):
     val = onp.asarray(fill_value, dtype)
-    return FilledConstant(val, shape)
+    return _FilledConstant(val, shape)
   else:
     return broadcast(convert_element_type(fill_value, dtype), shape)
 
 def iota(dtype, size):
+  """Wraps XLA's `Iota
+  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
+  operator.
+  """
   return broadcasted_iota(dtype, (int(size),), 0)
 
 def broadcasted_iota(dtype, shape, dimension):
+  """Wraps XLA's `Iota
+  <https://www.tensorflow.org/xla/operation_semantics#iota>`_
+  operator.
+  """
   dtype = xla_bridge.canonicalize_dtype(dtype)
   shape = tuple(map(int, shape))
   dimension = int(dimension)
-  return IotaConstant(dtype, shape, dimension)
+  return _IotaConstant(dtype, shape, dimension)
 
 def eye(dtype, size):
   return broadcasted_eye(dtype, (size, size), (0, 1))
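For example (illustrative):

    import jax.numpy as jnp
    from jax import lax

    i = lax.iota(jnp.int32, 5)                         # [0, 1, 2, 3, 4]
    rows = lax.broadcasted_iota(jnp.int32, (3, 4), 0)  # row index at each position
    f = lax.full((2, 2), 7., jnp.float32)              # constant-filled array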
@@ -819,7 +917,7 @@ def broadcasted_eye(dtype, shape, axes):
   dtype = xla_bridge.canonicalize_dtype(dtype)
   shape = tuple(map(int, shape))
   axes = tuple(map(int, axes))
-  return EyeConstant(shape, axes, dtype)
+  return _EyeConstant(shape, axes, dtype)
 
 
 def stop_gradient(x):
@@ -1727,8 +1825,8 @@ def _dot_batch_rule(batched_args, batch_dims):
   dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
   return dot_general(lhs, rhs, dim_nums), 0
 
-dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
-dot_p = standard_primitive(_dot_shape_rule, dot_dtype_rule, 'dot')
+_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
+dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot')
 ad.defbilinear(dot_p, _dot_transpose_lhs, _dot_transpose_rhs)
 batching.primitive_batchers[dot_p] = _dot_batch_rule
 
@@ -1903,10 +2001,10 @@ def _clamp_shape_rule(min, operand, max):
     raise TypeError(m.format(max.shape))
   return operand.shape
 
-clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
-                           'clamp')
+_clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
+                            'clamp')
 
-clamp_p = standard_primitive(_clamp_shape_rule, clamp_dtype_rule, 'clamp')
+clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
 ad.defjvp(clamp_p,
           lambda g, min, operand, max:
           select(bitwise_and(gt(min, operand), lt(min, max)),
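For example (illustrative; differentiating through `clamp` passes gradient only where the operand lies strictly inside the clamp interval):

    import jax.numpy as jnp
    from jax import grad, lax

    x = jnp.array([-2., 0.5, 3.])
    y = lax.clamp(0., x, 1.)                               # [0., 0.5, 1.]
    g = grad(lambda v: jnp.sum(lax.clamp(0., v, 1.)))(x)   # [0., 1., 0.]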
@@ -3328,7 +3426,7 @@ batching.primitive_batchers[shaped_identity_p] = \
 ### constants
 
 
-class FilledConstant(xla.DeviceConstant):
+class _FilledConstant(xla.DeviceConstant):
   __slots__ = ["fill_value"]
 
   def __init__(self, fill_value, shape):
@@ -3352,7 +3450,7 @@ class FilledConstant(xla.DeviceConstant):
                            filled_const.shape)
 
 
-class IotaConstant(xla.DeviceConstant):
+class _IotaConstant(xla.DeviceConstant):
   __slots__ = ["axis"]
 
   def __init__(self, dtype, shape, axis):
@@ -3381,7 +3479,7 @@ class IotaConstant(xla.DeviceConstant):
     return c.BroadcastedIota(dtype, iota_constant.shape, iota_constant.axis)
 
 
-class EyeConstant(xla.DeviceConstant):
+class _EyeConstant(xla.DeviceConstant):
   __slots__ = ["axes"]
 
   def __init__(self, shape, axes, dtype):
@@ -3417,7 +3515,7 @@ class EyeConstant(xla.DeviceConstant):
     return c.ConvertElementType(_reduce(c.And, eyes), etype)
 
 
-for _t in [FilledConstant, IotaConstant, EyeConstant]:
+for _t in [_FilledConstant, _IotaConstant, _EyeConstant]:
   xla_bridge.register_constant_handler(_t, _t.constant_handler)
   core.pytype_aval_mappings[_t] = ConcreteArray
   xla.pytype_aval_mappings[_t] = xla.pytype_aval_mappings[xla.DeviceArray]
@@ -134,25 +134,24 @@ _LOGNDTR_FLOAT32_UPPER = onp.array(5, onp.float32)
 
 
 def ndtr(x):
-  """Normal distribution function.
+  r"""Normal distribution function.
 
   Returns the area under the Gaussian probability density function, integrated
   from minus infinity to x:
 
-  ```
-                    1       / x
-     ndtr(x)  = ----------  |    exp(-0.5 t**2) dt
-                sqrt(2 pi)  /-inf
-
-              = 0.5 (1 + erf(x / sqrt(2)))
-              = 0.5 erfc(x / sqrt(2))
-  ```
+  .. math::
+    \begin{align}
+    \mathrm{ndtr}(x) =&
+      \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\
+    =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\
+    =&\ \frac{1}{2} \mathrm{erfc}(-\frac{x}{\sqrt{2}})
+    \end{align}
 
   Args:
     x: An array of type `float32`, `float64`.
 
   Returns:
-    ndtr: An array with `dtype=x.dtype`.
+    An array with `dtype=x.dtype`.
 
   Raises:
     TypeError: if `x` is not floating-type.
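A numerical check of the erf identity above (illustrative; this hunk belongs to a second file whose name is not shown in the visible diff stat, so the `jax.scipy.special` import path is an assumption):

    import numpy as onp
    from scipy.special import erf
    from jax.scipy.special import ndtr   # assumed module path

    x = onp.linspace(-3., 3., 7, dtype=onp.float32)
    expected = 0.5 * (1 + erf(x / onp.sqrt(2)))   # ndtr(x) = (1 + erf(x/sqrt(2)))/2
    assert onp.allclose(ndtr(x), expected, atol=1e-6)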
@@ -181,19 +180,19 @@ def _ndtr(x):
 
 
 def ndtri(p):
-  """The inverse of the CDF of the Normal distribution function.
+  r"""The inverse of the CDF of the Normal distribution function.
 
-  Returns x such that the area under the pdf from minus infinity to x is equal
-  to p.
+  Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is equal
+  to `p`.
 
   A piece-wise rational approximation is done for the function.
-  This is a port of the implementation in netlib.
+  This is based on the implementation in netlib.
 
   Args:
     p: an array of type `float32`, `float64`.
 
   Returns:
-    x: an array with `dtype=p.dtype`.
+    an array with `dtype=p.dtype`.
 
   Raises:
     TypeError: if `p` is not floating-type.
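A round-trip check (illustrative; import path assumed as above):

    import numpy as onp
    from jax.scipy.special import ndtr, ndtri   # assumed module path

    p = onp.array([0.1, 0.5, 0.9])
    x = ndtri(p)                                  # quantiles; ndtri(0.5) == 0
    assert onp.allclose(ndtr(x), p, atol=1e-5)    # inverse of the CDF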
@@ -317,48 +316,59 @@ def _ndtri(p):
 
 
 def log_ndtr(x, series_order=3):
-  """Log Normal distribution function.
+  r"""Log Normal distribution function.
 
   For details of the Normal distribution function see `ndtr`.
 
-  This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))` or
-  using an asymptotic series. Specifically:
+  This function calculates :math:`\log(\mathrm{ndtr}(x))` by either computing
+  :math:`\log(\mathrm{ndtr}(x))` directly or using an asymptotic series. Specifically:
 
   - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
-    `log(1-x) ~= -x, x << 1`.
+    :math:`\log(1-x) \approx -x, x \ll 1`.
   - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
     and take a log.
-  - For `x <= lower_segment`, we use the series approximation of erf to compute
+  - For `x <= lower_segment`, we use the series approximation of `erf` to compute
     the log CDF directly.
 
   The `lower_segment` is set based on the precision of the input:
 
-  ```
-  lower_segment = { -20,  x.dtype=float64
-                  { -10,  x.dtype=float32
-  upper_segment = {   8,  x.dtype=float64
-                  {   5,  x.dtype=float32
-  ```
+  .. math::
+    \begin{align}
+    \mathit{lower\_segment} =&
+      \ \begin{cases}
+        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
+        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
+      \end{cases} \\
+    \mathit{upper\_segment} =&
+      \ \begin{cases}
+        8&  x.\mathrm{dtype}=\mathit{float64} \\
+        5&  x.\mathrm{dtype}=\mathit{float32} \\
+      \end{cases}
+    \end{align}
+
 
   When `x < lower_segment`, the `ndtr` asymptotic series approximation is:
 
-  ```
-  ndtr(x) = scale * (1 + sum) + R_N
-  scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
-  sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
-  R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
-  ```
+  .. math::
+    \begin{align}
+    \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\
+    \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
+    \mathit{sum} =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
+    R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
+    \end{align}
 
-  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a
-  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).
+  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a
+  `double-factorial
+  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.
+
 
   Args:
     x: an array of type `float32`, `float64`.
-    series_order: Positive Python `integer`. Maximum depth to
+    series_order: Positive Python integer. Maximum depth to
       evaluate the asymptotic expansion. This is the `N` above.
 
   Returns:
-    log_ndtr: an array with `dtype=x.dtype`.
+    an array with `dtype=x.dtype`.
 
   Raises:
     TypeError: if `x.dtype` is not handled.
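In the moderate regime the result matches a naive log(ndtr(x)); the asymptotic series only takes over below `lower_segment` (illustrative; import path assumed as above):

    import numpy as onp
    from jax.scipy.special import log_ndtr, ndtr   # assumed module path

    x = onp.linspace(-5., 5., 11, dtype=onp.float32)
    assert onp.allclose(log_ndtr(x), onp.log(ndtr(x)), atol=1e-5)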