Add some doc strings to lax primitives.

Since lax is a semipublic API, its public methods need at least minimal documentation. Many of the docstrings added in this PR are somewhat redundant, but at least a few contain useful information, and the documentation reads better with at least some minimal text for each function.

Hide some methods that shouldn't be public from the lax API docs.
This commit is contained in:
Peter Hawkins 2019-02-19 11:30:31 -05:00
parent febadd7354
commit c5aa87f4f1

View File

@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
`lax` is a library of primitives that underpins libraries such as `jax.numpy`.
Many of the primitives are thin wrappers around equivalent XLA operations,
described by the `XLA operation semantics
<https://www.tensorflow.org/xla/operation_semantics>`_ documentation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -70,77 +78,240 @@ def broadcast_shapes(*shapes):
return tuple(result_shape)
def identity(x): return x
def identity(x):
r"""Identity function: :math:`x`."""
return x
### traceables
def neg(x): return neg_p.bind(x)
def sign(x): return sign_p.bind(x)
def floor(x): return floor_p.bind(x)
def ceil(x): return ceil_p.bind(x)
def round(x): return round_p.bind(x)
def neg(x):
r"""Elementwise negation: :math:`-x`."""
return neg_p.bind(x)
def is_finite(x): return is_finite_p.bind(x)
def sign(x):
r"""Elementwise sign.
def exp(x): return exp_p.bind(x)
def expm1(x): return expm1_p.bind(x)
def log(x): return log_p.bind(x)
def log1p(x): return log1p_p.bind(x)
def tanh(x): return tanh_p.bind(x)
def sin(x): return sin_p.bind(x)
def cos(x): return cos_p.bind(x)
def atan2(x, y): return atan2_p.bind(x, y)
:math:`\mathrm{sign}(x) = \begin{cases}
-1 & x < 0\\
-0 & x = -0\\
\mathit{NaN} & x = \mathit{NaN}\\
+0 & x = +0\\
1 & x > 0
\end{cases}`.
"""
return sign_p.bind(x)
def lgamma(x): return lgamma_p.bind(x)
def digamma(x): return digamma_p.bind(x)
def erf(x): return erf_p.bind(x)
def erfc(x): return erfc_p.bind(x)
def erf_inv(x): return erf_inv_p.bind(x)
def floor(x):
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
return floor_p.bind(x)
def real(x): return real_p.bind(x)
def imag(x): return imag_p.bind(x)
def complex(x, y): return complex_p.bind(_brcast(x, y), _brcast(y, x))
def conj(x): return conj_p.bind(x, input_dtype=_dtype(x))
def abs(x): return abs_p.bind(x)
def pow(x, y): return pow_p.bind(x, y)
def ceil(x):
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
return ceil_p.bind(x)
def bitwise_not(x): return not_p.bind(x)
def bitwise_and(x, y): return and_p.bind(x, y)
def bitwise_or(x, y): return or_p.bind(x, y)
def bitwise_xor(x, y): return xor_p.bind(x, y)
def round(x):
r"""Elementwise round.
def add(x, y): return add_p.bind(x, y)
def sub(x, y): return sub_p.bind(x, y)
def mul(x, y): return mul_p.bind(x, y)
def div(x, y): return div_p.bind(x, y)
def rem(x, y): return rem_p.bind(x, y)
Rounds values to the nearest integer. Halfway values (e.g., `0.5`) are rounded
away from zero."""
return round_p.bind(x)
def is_finite(x):
r"""Elementwise :math:`\mathrm{isfinite}`.
For each element x returns `True` if and only if x is not :math:`\pm\infty` or
:math:`\mathit{NaN}`.
"""
return is_finite_p.bind(x)
def exp(x):
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)
def expm1(x):
r"""Elementwise :math:`e^{x - 1}`."""
return expm1_p.bind(x)
def log(x):
r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`."""
return log_p.bind(x)
def log1p(x):
r"""Elementwise :math:`\mathrm{log}(1 + x)`."""
return log1p_p.bind(x)
def tanh(x):
r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`."""
return tanh_p.bind(x)
def sin(x):
r"""Elementwise sine: :math:`\mathrm{sin}(x)`."""
return sin_p.bind(x)
def cos(x):
r"""Elementwise cosine: :math:`\mathrm{cos}(x)`."""
return cos_p.bind(x)
def atan2(x, y):
r"""Elementwise arc tangent of two variables:
:math:`\mathrm{atan}({x \over y})`."""
return atan2_p.bind(x, y)
def lgamma(x):
r"""Elementwise log gamma: :math:`\mathrm{log}(\Gamma(x))`."""
return lgamma_p.bind(x)
def digamma(x):
r"""Elementwise digamma: :math:`\psi(x)`."""
return digamma_p.bind(x)
def erf(x):
r"""Elementwise error function: :math:`\mathrm{erf}(x)`."""
return erf_p.bind(x)
def erfc(x):
r"""Elementwise complementary error function:
:math:`\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)`."""
return erfc_p.bind(x)
def erf_inv(x):
r"""Elementwise inverse error function: :math:`\mathrm{erf}^{-1}(x)`."""
return erf_inv_p.bind(x)
def real(x):
r"""Elementwise extract real part: :math:`\mathrm{Re}(x)`.
Returns the real part of a complex number.
"""
return real_p.bind(x)
def imag(x):
r"""Elementwise extract imaginary part: :math:`\mathrm{Im}(x)`.
Returns the imaginary part of a complex number.
"""
return imag_p.bind(x)
def complex(x, y):
r"""Elementwise make complex number: :math:`x + jy`.
Builds a complex number from real and imaginary parts.
"""
return complex_p.bind(_brcast(x, y), _brcast(y, x))
def conj(x):
r"""Elementwise complex conjugate function: :math:`\overline{x}`."""
return conj_p.bind(x, input_dtype=_dtype(x))
def abs(x):
r"""Elementwise absolute value: :math:`|x|`."""
return abs_p.bind(x)
def pow(x, y):
r"""Elementwise power: :math:`x^y`."""
return pow_p.bind(x, y)
def bitwise_not(x):
r"""Elementwise NOT: :math:`\neg x`."""
return not_p.bind(x)
def bitwise_and(x, y):
r"""Elementwise AND: :math:`x \wedge y`."""
return and_p.bind(x, y)
def bitwise_or(x, y):
r"""Elementwise OR: :math:`x \vee y`."""
return or_p.bind(x, y)
def bitwise_xor(x, y):
r"""Elementwise exclusive OR: :math:`x \oplus y`."""
return xor_p.bind(x, y)
def add(x, y):
r"""Elementwise addition: :math:`x + y`."""
return add_p.bind(x, y)
def sub(x, y):
r"""Elementwise subtraction: :math:`x - y`."""
return sub_p.bind(x, y)
def mul(x, y):
r"""Elementwise multiplication: :math:`x \times y`."""
return mul_p.bind(x, y)
def div(x, y):
r"""Elementwise division: :math:`x \over y`."""
return div_p.bind(x, y)
def rem(x, y):
r"""Elementwise remainder: :math:`x \bmod y`."""
return rem_p.bind(x, y)
def max(x, y):
"""Elementwise maximum.
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return max_p.bind(x, y)
def min(x, y):
"""Elementwise minimum.
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the
`(real, imaginary)` pairs."""
return min_p.bind(x, y)
def shift_left(x, y): return shift_left_p.bind(x, y)
def shift_right_arithmetic(x, y): return shift_right_arithmetic_p.bind(x, y)
def shift_right_logical(x, y): return shift_right_logical_p.bind(x, y)
def shift_left(x, y):
r"""Elementwise left shift: :math:`x \ll y`."""
return shift_left_p.bind(x, y)
def eq(x, y): return eq_p.bind(x, y)
def ne(x, y): return ne_p.bind(x, y)
def ge(x, y): return ge_p.bind(x, y)
def gt(x, y): return gt_p.bind(x, y)
def le(x, y): return le_p.bind(x, y)
def lt(x, y): return lt_p.bind(x, y)
def shift_right_arithmetic(x, y):
r"""Elementwise arithmetic right shift: :math:`x \gg y`."""
return shift_right_arithmetic_p.bind(x, y)
def shift_right_logical(x, y):
r"""Elementwise logical right shift: :math:`x \gg y`."""
return shift_right_logical_p.bind(x, y)
def eq(x, y):
r"""Elementwise equals: :math:`x = y`."""
return eq_p.bind(x, y)
def ne(x, y):
r"""Elementwise not-equals: :math:`x \neq y`."""
return ne_p.bind(x, y)
def ge(x, y):
r"""Elementwise greater-than-or-equals: :math:`x \geq y`."""
return ge_p.bind(x, y)
def gt(x, y):
r"""Elementwise greater-than: :math:`x > y`."""
return gt_p.bind(x, y)
def le(x, y):
r"""Elementwise less-than-or-equals: :math:`x \leq y`."""
return le_p.bind(x, y)
def lt(x, y):
r"""Elementwise less-than: :math:`x < y`."""
return lt_p.bind(x, y)
def convert_element_type(operand, new_dtype):
"""Elementwise cast.
Wraps XLA's `ConvertElementType
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
operator, which performs an elementwise conversion from one type to another.
Similar to a C++ `static_cast`.
Args:
operand: an array or scalar value to be cast
new_dtype: the new type. Should be a NumPy type.
Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
"""
new_dtype = xla_bridge.canonicalize_dtype(new_dtype)
old_dtype = _dtype(operand)
if old_dtype != new_dtype:
@ -155,6 +326,21 @@ def convert_element_type(operand, new_dtype):
return operand
def bitcast_convert_type(operand, new_dtype):
"""Elementwise bitcast.
Wraps XLA's `BitcastConvertType
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another. The bitwidth
of the source and destination types must match.
Args:
operand: an array or scalar value to be cast
new_dtype: the new type. Should be a NumPy type.
Returns:
An array with the same shape as `operand`, bitcast elementwise to
`new_dtype`.
"""
new_dtype = xla_bridge.canonicalize_dtype(new_dtype)
old_dtype = _dtype(operand)
if old_dtype != new_dtype:
@ -162,8 +348,16 @@ def bitcast_convert_type(operand, new_dtype):
else:
return operand
def clamp(min, operand, max):
return clamp_p.bind(min, operand, max)
def clamp(min, x, max):
r"""Elementwise clamp.
Returns :math:`\mathrm{clamp}(x) = \begin{cases}
\mathit{min} & \text{if } x < \mathit{min},\\
\mathit{max} & \text{if } x > \mathit{max},\\
x & \text{otherwise}
\end{cases}`.
"""
return clamp_p.bind(min, x, max)
def concatenate(operands, dimension):
return concatenate_p.bind(*operands, dimension=dimension,
@ -689,47 +883,60 @@ def batch_matmul(lhs, rhs):
# as non-primitive to maintain a smaller set of autodiff primitives.
def sqrt(x):
r"""Elementwise square root: :math:`\sqrt{x}`."""
return pow(x, _const(x, 0.5))
def rsqrt(x):
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
return pow(x, _const(x, -0.5))
def square(x):
r"""Elementwise square: :math:`x^2`."""
return mul(x, x)
def reciprocal(x):
r"""Elementwise reciprocal: :math:`1 \over x`."""
return div(_const(x, 1), x)
def tan(x):
r"""Elementwise tangent: :math:`\mathrm{tan}(x)`."""
return div(sin(x), cos(x))
def asin(x):
r"""Elementwise arc sine: :math:`\mathrm{asin}(x)`."""
return mul(_const(x, 2),
atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x))))))
def acos(x):
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
return mul(_const(x, 2),
atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x)))
def atan(x):
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
return atan2(x, _const(x, 1))
def sinh(x):
r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`."""
return mul(_const(x, 0.5), sub(exp(x), exp(neg(x))))
def cosh(x):
r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`."""
return mul(_const(x, 0.5), add(exp(x), exp(neg(x))))
def asinh(x):
r"""Elementwise arc hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
# asinh(x) = log(x + sqrt(x**2 + 1))
return log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
def acosh(x):
r"""Elementwise arc hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
# acosh(x) = log(x + sqrt((x + 1) * (x - 1)))
return log(add(x, mul(sqrt(add(x, _const(x, 1))),
sqrt(sub(x, _const(x, 1))))))
def atanh(x):
r"""Elementwise arc hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
# atanh(x) = 0.5 * log((1 + x) / (1 - x))
return mul(_const(x, 0.5), log(div(add(_const(x, 1), x),
sub(_const(x, 1), x))))
@ -825,7 +1032,7 @@ def binop_dtype_rule(result_dtype, accepted_dtypes, name, *avals, **kwargs):
return result_dtype(*avals)
def broadcasting_shape_rule(name, *avals):
def _broadcasting_shape_rule(name, *avals):
shapes = onp.array([aval.shape for aval in avals if aval.shape])
if not shapes.size:
return ()
@ -843,7 +1050,7 @@ def broadcasting_shape_rule(name, *avals):
def binop(result_dtype, accepted_dtypes, name, translation_rule=None):
dtype_rule = partial(binop_dtype_rule, result_dtype, accepted_dtypes, name)
shape_rule = partial(broadcasting_shape_rule, name)
shape_rule = partial(_broadcasting_shape_rule, name)
prim = standard_primitive(shape_rule, dtype_rule, name,
translation_rule=translation_rule)
batching.defbroadcasting(prim)
@ -984,17 +1191,17 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
# TODO handle broadcasting
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
def pow_jvp_lhs(g, x, y):
def _pow_jvp_lhs(g, x, y):
# we call _safe_mul here so that we get the behavior 0*inf = 0, since when a
# coefficient in `g` is zero we want to keep it at zero, not produce a nan.
# see https://github.com/google/jax/pull/383
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return _safe_mul(_brcast(g, y), jac)
def pow_jvp_rhs(g, x, y):
def _pow_jvp_rhs(g, x, y):
return mul(_brcast(g, x), mul(log(_replace_zero(x)), pow(x, y)))
ad.defjvp(pow_p, pow_jvp_lhs, pow_jvp_rhs)
ad.defjvp(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
not_p = standard_unop(_int | _bool, 'not')
@ -3015,46 +3222,46 @@ class EyeConstant(xla.DeviceConstant):
return c.ConvertElementType(_reduce(c.And, eyes), etype)
for t in [FilledConstant, IotaConstant, EyeConstant]:
xla_bridge.register_constant_handler(t, t.constant_handler)
core.pytype_aval_mappings[t] = ConcreteArray
xla.pytype_aval_mappings[t] = xla.pytype_aval_mappings[xla.DeviceArray]
xla.canonicalize_dtype_handlers[t] = identity
batching.pytype_aval_mappings[t] = make_shaped_array
ad_util.jaxval_adders[t] = add
ad_util.jaxval_zeros_likers[t] = zeros_like_array
for _t in [FilledConstant, IotaConstant, EyeConstant]:
xla_bridge.register_constant_handler(_t, _t.constant_handler)
core.pytype_aval_mappings[_t] = ConcreteArray
xla.pytype_aval_mappings[_t] = xla.pytype_aval_mappings[xla.DeviceArray]
xla.canonicalize_dtype_handlers[_t] = identity
batching.pytype_aval_mappings[_t] = make_shaped_array
ad_util.jaxval_adders[_t] = add
ad_util.jaxval_zeros_likers[_t] = zeros_like_array
### parallel
def PmapPrimitive(name):
prim = Primitive(name)
prim.def_impl(partial(unbound_name_error, name))
prim.def_impl(partial(_unbound_name_error, name))
prim.def_abstract_eval(lambda x, *args, **kwargs: x) # default
return prim
def unbound_name_error(primitive_name, *args, **kwargs):
def _unbound_name_error(primitive_name, *args, **kwargs):
axis_name = kwargs['axis_name']
msg = "axis name '{}' is unbound for primitive {}."
raise NameError(msg.format(axis_name, primitive_name))
def psum_transpose_rule(t, axis_name):
def _psum_transpose_rule(t, axis_name):
return [t]
def psum_parallel_translation_rule(c, val, device_groups):
def _psum_parallel_translation_rule(c, val, device_groups):
if len(device_groups) > 1:
return c.CrossReplicaSum(val, device_groups)
else:
return c.CrossReplicaSum(val)
def psum_pmap_rule(val, axis):
def _psum_pmap_rule(val, axis):
return _reduce_sum(val, [axis]), None
psum_p = PmapPrimitive('psum')
parallel.pmap_primitive_rules[psum_p] = psum_pmap_rule
pxla.parallel_translation_rules[psum_p] = psum_parallel_translation_rule
ad.deflinear(psum_p, psum_transpose_rule)
parallel.pmap_primitive_rules[psum_p] = _psum_pmap_rule
pxla.parallel_translation_rules[psum_p] = _psum_parallel_translation_rule
ad.deflinear(psum_p, _psum_transpose_rule)
parallel.defreducer(reduce_sum_p, psum_p)