Documentation improvements.

This commit is contained in:
Peter Hawkins 2019-02-28 22:48:31 -05:00
parent 6af073e4f4
commit c0a7bcb6fa
2 changed files with 159 additions and 51 deletions

View File

@ -448,14 +448,60 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
dimension_numbers=dimension_numbers, lhs_shape=lhs.shape,
rhs_shape=rhs.shape)
def dot(lhs, rhs): return dot_p.bind(lhs, rhs)
def dot(lhs, rhs):
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
<https://www.tensorflow.org/xla/operation_semantics#dot>`_
operator.
For more general contraction, see the `dot_general` operator.
Args:
lhs: an array of rank 1 or 2.
rhs: an array of rank 1 or 2.
Returns:
An array containing the product.
"""
return dot_p.bind(lhs, rhs)
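# A minimal usage sketch for `dot`, assuming it is reachable as `jax.lax.dot`:
#
#   import numpy as onp
#   from jax import lax
#   lax.dot(onp.ones((2, 3), onp.float32),
#           onp.ones((3, 4), onp.float32))   # matrix/matrix product, shape (2, 4)
#   lax.dot(onp.ones((2, 3), onp.float32),
#           onp.ones((3,), onp.float32))     # matrix/vector product, shape (2,)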
def dot_general(lhs, rhs, dimension_numbers):
"""More general contraction operator.
Wraps XLA's `DotGeneral
<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`_
operator.
Args:
lhs: an array
rhs: an array
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`
Returns:
An array containing the result.
"""
lhs_dims, rhs_dims = dimension_numbers
dimension_numbers = (tuple(map(tuple, lhs_dims)), tuple(map(tuple, rhs_dims)))
return dot_general_p.bind(lhs, rhs, dimension_numbers=dimension_numbers)
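# A hedged sketch of a batched matrix multiply via `dot_general`, assuming it is
# reachable as `jax.lax.dot_general`:
#
#   import numpy as onp
#   from jax import lax
#   lhs = onp.ones((8, 2, 3), onp.float32)   # (batch, m, k)
#   rhs = onp.ones((8, 3, 5), onp.float32)   # (batch, k, n)
#   dims = (((2,), (1,)), ((0,), (0,)))      # contract k; batch over dim 0
#   lax.dot_general(lhs, rhs, dims)          # shape (8, 2, 5)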
def broadcast(operand, sizes):
"""Broadcasts an array, adding new major dimensions.
Wraps XLA's `Broadcast
<https://www.tensorflow.org/xla/operation_semantics#broadcast>`_
operator.
Args:
operand: an array
sizes: a sequence of integers, giving the sizes of new major dimensions
to add.
Returns:
An array containing the result.
"""
return broadcast_p.bind(operand, sizes=tuple(sizes))
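# A small `broadcast` sketch adding new major (leading) dimensions, assuming the
# `jax.lax` re-export:
#
#   import numpy as onp
#   from jax import lax
#   lax.broadcast(onp.arange(3), (4, 2))   # shape (4, 2, 3)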
def broadcast_in_dim(operand, shape, broadcast_dimensions):
@ -467,6 +513,10 @@ def broadcast_in_dim(operand, shape, broadcast_dimensions):
broadcast_dimensions=tuple(broadcast_dimensions))
def reshape(operand, new_sizes, dimensions=None):
"""Wraps XLA's `Reshape
<https://www.tensorflow.org/xla/operation_semantics#reshape>`_
operator.
"""
same_shape = onp.shape(operand) == tuple(new_sizes)
same_dims = dimensions is None or tuple(dimensions) == tuple(range(onp.ndim(operand)))
if same_shape and same_dims:
@ -478,27 +528,51 @@ def reshape(operand, new_sizes, dimensions=None):
old_sizes=onp.shape(operand))
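# An illustrative `reshape` sketch; the optional `dimensions` argument is
# assumed (per XLA Reshape semantics) to reorder the operand's axes before
# reshaping:
#
#   import numpy as onp
#   from jax import lax
#   x = onp.arange(6).reshape(2, 3)
#   lax.reshape(x, (3, 2))            # row-major reshape, shape (3, 2)
#   lax.reshape(x, (3, 2), (1, 0))    # permute axes first; equivalent to x.T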
def pad(operand, padding_value, padding_config):
"""Wraps XLA's `Pad
<https://www.tensorflow.org/xla/operation_semantics#pad>`_
operator.
"""
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
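# A hedged `pad` sketch; `padding_config` is assumed to be one
# (low, high, interior) triple per dimension, as in XLA's Pad:
#
#   import numpy as onp
#   from jax import lax
#   x = onp.array([1., 2., 3.], onp.float32)
#   lax.pad(x, onp.float32(0), [(1, 2, 0)])   # [0., 1., 2., 3., 0., 0.]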
def rev(operand, dimensions):
"""Wraps XLA's `Rev
<https://www.tensorflow.org/xla/operation_semantics#rev_reverse>`_
operator.
"""
return rev_p.bind(operand, dimensions=tuple(dimensions))
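# A one-line `rev` sketch (assuming the `jax.lax` re-export):
#
#   import numpy as onp
#   from jax import lax
#   lax.rev(onp.arange(4), (0,))   # [3, 2, 1, 0]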
def select(pred, on_true, on_false):
"""Wraps XLA's `Select
<https://www.tensorflow.org/xla/operation_semantics#select>`_
operator.
"""
return select_p.bind(pred, on_true, on_false)
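# An elementwise `select` sketch (assuming the `jax.lax` re-export):
#
#   import numpy as onp
#   from jax import lax
#   pred = onp.array([True, False, True])
#   lax.select(pred, onp.array([1, 2, 3]), onp.array([10, 20, 30]))   # [1, 20, 3]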
def slice(operand, start_indices, limit_indices, strides=None):
"""Wraps XLA's `Slice
<https://www.tensorflow.org/xla/operation_semantics#slice>`_
operator.
"""
return slice_p.bind(operand, start_indices=tuple(start_indices),
limit_indices=tuple(limit_indices),
strides=None if strides is None else tuple(strides),
operand_shape=operand.shape)
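# A static `slice` sketch, with and without strides (assuming the `jax.lax`
# re-export):
#
#   import numpy as onp
#   from jax import lax
#   x = onp.arange(10)
#   lax.slice(x, (2,), (8,))         # elements 2..7
#   lax.slice(x, (2,), (8,), (2,))   # strided: elements 2, 4, 6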
def dynamic_slice(operand, start_indices, slice_sizes):
"""Wraps XLA's `DynamicSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
operator.
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_slice_p.bind(
operand, start_indices, slice_sizes=tuple(slice_sizes),
operand_shape=operand.shape)
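# A `dynamic_slice` sketch, assuming the start indices may be passed as an
# integer array (as `_dynamic_slice_indices` above appears to accept):
#
#   import numpy as onp
#   from jax import lax
#   x = onp.arange(10)
#   lax.dynamic_slice(x, onp.array([3]), (4,))   # elements 3, 4, 5, 6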
def dynamic_update_slice(operand, update, start_indices):
"""Wraps XLA's `DynamicUpdateSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
operator.
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_update_slice_p.bind(operand, update, start_indices,
update_shape=update.shape)
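# A matching `dynamic_update_slice` sketch, under the same assumption about the
# index argument:
#
#   import numpy as onp
#   from jax import lax
#   x = onp.zeros((6,), onp.float32)
#   upd = onp.ones((2,), onp.float32)
#   lax.dynamic_update_slice(x, upd, onp.array([2]))   # [0., 0., 1., 1., 0., 0.]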
@ -570,6 +644,10 @@ def index_take(src, idxs, axes):
return gather(src, indices, dimension_numbers=dnums, slice_sizes=slice_sizes)
def transpose(operand, permutation):
"""Wraps XLA's `Transpose
<https://www.tensorflow.org/xla/operation_semantics#transpose>`_
operator.
"""
permutation = tuple(permutation)
if permutation == tuple(range(len(permutation))):
return operand
@ -577,6 +655,10 @@ def transpose(operand, permutation):
return transpose_p.bind(operand, permutation=permutation)
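# A `transpose` sketch (assuming the `jax.lax` re-export); output axis `i` takes
# its size from operand axis `permutation[i]`:
#
#   import numpy as onp
#   from jax import lax
#   lax.transpose(onp.ones((2, 3, 4)), (2, 0, 1))   # shape (4, 2, 3)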
def reduce(operand, init_value, computation, dimensions):
"""Wraps XLA's `Reduce
<https://www.tensorflow.org/xla/operation_semantics#reduce>`_
operator.
"""
monoid_reducer = _get_monoid_reducer(computation, init_value)
if monoid_reducer:
return monoid_reducer(operand, dimensions)
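# A `reduce` sketch using the built-in `add` monoid (assuming the `jax.lax`
# re-export):
#
#   import numpy as onp
#   from jax import lax
#   x = onp.arange(6, dtype=onp.float32).reshape(2, 3)
#   lax.reduce(x, onp.float32(0), lax.add, (1,))   # row sums, shape (2,)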
@ -637,6 +719,10 @@ def _reduce_and(operand, axes):
def reduce_window(operand, init_value, computation, window_dimensions,
window_strides, padding):
"""Wraps XLA's `ReduceWindow
<https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
operator.
"""
monoid_reducer = _get_monoid_window_reducer(computation, init_value)
if monoid_reducer:
return monoid_reducer(operand, window_dimensions, window_strides, padding)
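# A `reduce_window` sketch for a windowed max, assuming `padding` is given as
# the XLA-style string 'VALID' (or 'SAME'):
#
#   import numpy as onp
#   from jax import lax
#   x = onp.arange(6, dtype=onp.float32)
#   lax.reduce_window(x, -onp.inf, lax.max, (2,), (2,), 'VALID')   # [1., 3., 5.]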
@ -706,6 +792,10 @@ def _select_and_gather_add(tangents, operand, select_prim, window_dimensions,
window_strides=tuple(window_strides), padding=padding)
def sort(operand, dimension=-1):
"""Wraps XLA's `Sort
<https://www.tensorflow.org/xla/operation_semantics#sort>`_
operator.
"""
return sort_p.bind(operand, dimension=dimension)
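# A `sort` sketch (assuming the `jax.lax` re-export):
#
#   import numpy as onp
#   from jax import lax
#   lax.sort(onp.array([3, 1, 2]))             # [1, 2, 3]
#   lax.sort(onp.array([[3, 1], [2, 0]]), 0)   # sort each column: [[2, 0], [3, 1]]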
def sort_key_val(keys, values, dimension=-1):
@ -792,23 +882,31 @@ def full(shape, fill_value, dtype=None):
dtype = xla_bridge.canonicalize_dtype(dtype)
# For constants (defined as Python scalars, raw ndarrays, or DeviceValues),
# create a FilledConstant value, otherwise just call broadcast.
# create a _FilledConstant value, otherwise just call broadcast.
if onp.isscalar(fill_value) or type(fill_value) is onp.ndarray:
return FilledConstant(onp.asarray(fill_value, dtype), shape)
return _FilledConstant(onp.asarray(fill_value, dtype), shape)
elif isinstance(fill_value, xla.DeviceValue):
val = onp.asarray(fill_value, dtype)
return FilledConstant(val, shape)
return _FilledConstant(val, shape)
else:
return broadcast(convert_element_type(fill_value, dtype), shape)
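# A `full` sketch (assuming the `jax.lax` re-export):
#
#   import numpy as onp
#   from jax import lax
#   lax.full((2, 3), onp.float32(1.5))   # a (2, 3) array filled with 1.5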
def iota(dtype, size):
"""Wraps XLA's `Iota
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
return broadcasted_iota(dtype, (int(size),), 0)
def broadcasted_iota(dtype, shape, dimension):
"""Wraps XLA's `Iota
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
dtype = xla_bridge.canonicalize_dtype(dtype)
shape = tuple(map(int, shape))
dimension = int(dimension)
return IotaConstant(dtype, shape, dimension)
return _IotaConstant(dtype, shape, dimension)
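# An `iota` / `broadcasted_iota` sketch (assuming the `jax.lax` re-export):
#
#   import numpy as onp
#   from jax import lax
#   lax.iota(onp.int32, 4)                       # [0, 1, 2, 3]
#   lax.broadcasted_iota(onp.int32, (2, 3), 1)   # [[0, 1, 2], [0, 1, 2]]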
def eye(dtype, size):
return broadcasted_eye(dtype, (size, size), (0, 1))
@ -819,7 +917,7 @@ def broadcasted_eye(dtype, shape, axes):
dtype = xla_bridge.canonicalize_dtype(dtype)
shape = tuple(map(int, shape))
axes = tuple(map(int, axes))
return EyeConstant(shape, axes, dtype)
return _EyeConstant(shape, axes, dtype)
def stop_gradient(x):
@ -1727,8 +1825,8 @@ def _dot_batch_rule(batched_args, batch_dims):
dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
return dot_general(lhs, rhs, dim_nums), 0
dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
dot_p = standard_primitive(_dot_shape_rule, dot_dtype_rule, 'dot')
_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot')
ad.defbilinear(dot_p, _dot_transpose_lhs, _dot_transpose_rhs)
batching.primitive_batchers[dot_p] = _dot_batch_rule
@ -1903,10 +2001,10 @@ def _clamp_shape_rule(min, operand, max):
raise TypeError(m.format(max.shape))
return operand.shape
clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
'clamp')
_clamp_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_any, _any, _any],
'clamp')
clamp_p = standard_primitive(_clamp_shape_rule, clamp_dtype_rule, 'clamp')
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
ad.defjvp(clamp_p,
lambda g, min, operand, max:
select(bitwise_and(gt(min, operand), lt(min, max)),
@ -3328,7 +3426,7 @@ batching.primitive_batchers[shaped_identity_p] = \
### constants
class FilledConstant(xla.DeviceConstant):
class _FilledConstant(xla.DeviceConstant):
__slots__ = ["fill_value"]
def __init__(self, fill_value, shape):
@ -3352,7 +3450,7 @@ class FilledConstant(xla.DeviceConstant):
filled_const.shape)
class IotaConstant(xla.DeviceConstant):
class _IotaConstant(xla.DeviceConstant):
__slots__ = ["axis"]
def __init__(self, dtype, shape, axis):
@ -3381,7 +3479,7 @@ class IotaConstant(xla.DeviceConstant):
return c.BroadcastedIota(dtype, iota_constant.shape, iota_constant.axis)
class EyeConstant(xla.DeviceConstant):
class _EyeConstant(xla.DeviceConstant):
__slots__ = ["axes"]
def __init__(self, shape, axes, dtype):
@ -3417,7 +3515,7 @@ class EyeConstant(xla.DeviceConstant):
return c.ConvertElementType(_reduce(c.And, eyes), etype)
for _t in [FilledConstant, IotaConstant, EyeConstant]:
for _t in [_FilledConstant, _IotaConstant, _EyeConstant]:
xla_bridge.register_constant_handler(_t, _t.constant_handler)
core.pytype_aval_mappings[_t] = ConcreteArray
xla.pytype_aval_mappings[_t] = xla.pytype_aval_mappings[xla.DeviceArray]

View File

@ -134,25 +134,24 @@ _LOGNDTR_FLOAT32_UPPER = onp.array(5, onp.float32)
def ndtr(x):
"""Normal distribution function.
r"""Normal distribution function.
Returns the area under the Gaussian probability density function, integrated
from minus infinity to x:
```
1 / x
ndtr(x) = ---------- | exp(-0.5 t**2) dt
sqrt(2 pi) /-inf
= 0.5 (1 + erf(x / sqrt(2)))
= 0.5 erfc(x / sqrt(2))
```
.. math::
\begin{align}
\mathrm{ndtr}(x) =&
\ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\
=&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\
=&\ \frac{1}{2} \mathrm{erfc}(-\frac{x}{\sqrt{2}})
\end{align}
Args:
x: An array of type `float32`, `float64`.
Returns:
ndtr: An array with `dtype=x.dtype`.
An array with `dtype=x.dtype`.
Raises:
TypeError: if `x` is not floating-type.
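# A minimal `ndtr` sketch, assuming these functions are exposed via
# `jax.scipy.special`:
#
#   import numpy as onp
#   from jax.scipy import special
#   special.ndtr(onp.array([-1., 0., 1.], onp.float32))   # ~ [0.1587, 0.5, 0.8413]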
@ -181,19 +180,19 @@ def _ndtr(x):
def ndtri(p):
"""The inverse of the CDF of the Normal distribution function.
r"""The inverse of the CDF of the Normal distribution function.
Returns x such that the area under the pdf from minus infinity to x is equal
to p.
Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is equal
to `p`.
A piece-wise rational approximation is done for the function.
This is a port of the implementation in netlib.
This is based on the implementation in netlib.
Args:
p: an array of type `float32`, `float64`.
Returns:
x: an array with `dtype=p.dtype`.
an array with `dtype=p.dtype`.
Raises:
TypeError: if `p` is not floating-type.
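# An `ndtri` sketch (inverse of `ndtr`), under the same `jax.scipy.special`
# import assumption:
#
#   import numpy as onp
#   from jax.scipy import special
#   special.ndtri(onp.array([0.5, 0.8413], onp.float32))   # ~ [0.0, 1.0]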
@ -317,48 +316,59 @@ def _ndtri(p):
def log_ndtr(x, series_order=3):
"""Log Normal distribution function.
r"""Log Normal distribution function.
For details of the Normal distribution function see `ndtr`.
This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))` or
using an asymptotic series. Specifically:
This function calculates :math:`(\log \circ \mathrm{ndtr})(x)` by either calling
:math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically:
- For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
`log(1-x) ~= -x, x << 1`.
:math:`\log(1-x) \approx -x, x \ll 1`.
- For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
and take a log.
- For `x <= lower_segment`, we use the series approximation of erf to compute
- For `x <= lower_segment`, we use the series approximation of `erf` to compute
the log CDF directly.
The `lower_segment` is set based on the precision of the input:
```
lower_segment = { -20, x.dtype=float64
{ -10, x.dtype=float32
upper_segment = { 8, x.dtype=float64
{ 5, x.dtype=float32
```
.. math::
\begin{align}
\mathit{lower\_segment} =&
\ \begin{cases}
-20 & x.\mathrm{dtype}=\mathit{float64} \\
-10 & x.\mathrm{dtype}=\mathit{float32} \\
\end{cases} \\
\mathit{upper\_segment} =&
\ \begin{cases}
8& x.\mathrm{dtype}=\mathit{float64} \\
5& x.\mathrm{dtype}=\mathit{float32} \\
\end{cases}
\end{align}
When `x < lower_segment`, the `ndtr` asymptotic series approximation is:
```
ndtr(x) = scale * (1 + sum) + R_N
scale = exp(-0.5 x**2) / (-x sqrt(2 pi))
sum = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
R_N = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
```
.. math::
\begin{align}
\mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\
\mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
\mathit{sum} =&\ \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
\end{align}
where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a
[double-factorial](https://en.wikipedia.org/wiki/Double_factorial).
where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a
`double-factorial
<https://en.wikipedia.org/wiki/Double_factorial>`_ operator.
Args:
x: an array of type `float32`, `float64`.
series_order: Positive Python `integer`. Maximum depth to
series_order: Positive Python integer. Maximum depth to
evaluate the asymptotic expansion. This is the `N` above.
Returns:
log_ndtr: an array with `dtype=x.dtype`.
an array with `dtype=x.dtype`.
Raises:
TypeError: if `x.dtype` is not handled.
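# A `log_ndtr` sketch, under the same `jax.scipy.special` import assumption; for
# moderate `x` it agrees with taking the log of `ndtr` directly:
#
#   import numpy as onp
#   from jax.scipy import special
#   special.log_ndtr(onp.array(0., onp.float32))    # ~ log(0.5) = -0.6931
#   special.log_ndtr(onp.array(-20., onp.float32))  # asymptotic-series branch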