Document valid enum values for precision. (#3441)

This is a little tricky to figure out otherwise.
This commit is contained in:
Stephan Hoyer 2020-06-14 21:42:45 -07:00 committed by GitHub
parent c9d1b99e51
commit 3deada9ede
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 13 deletions

View File

@ -497,8 +497,9 @@ def conv_general_dilated(
of length `n+2`.
feature_group_count: integer, default 1. See XLA HLO docs.
batch_group_count: integer, default 1. See XLA HLO docs.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
An array containing the convolution result.
@ -566,8 +567,9 @@ def dot(lhs: Array, rhs: Array, precision: Optional[PrecisionType] = None) -> Ar
Args:
lhs: an array of rank 1 or 2.
rhs: an array of rank 1 or 2.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
An array containing the product.
@ -597,8 +599,9 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
An array containing the result.
@ -1345,8 +1348,9 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
An array containing the convolution result.
@ -1376,8 +1380,9 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
An array containing the convolution result.
@ -1448,8 +1453,9 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
to the gradient-derived functions like keras.layers.Conv2DTranspose
applied to the same kernel. For typical use in neural nets this is completely
pointless and just makes input/output channel specification confusing.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
precision: Optional. Either ``None``, which means the default precision for
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``).
Returns:
Transposed N-d convolution, with output padding following the conventions of

View File

@ -66,7 +66,9 @@ newaxis = None
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. See :py:func:`jax.lax.dot` for details.
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, or any ``jax.lax.Precision`` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
"""
# We replace some builtin names to follow Numpy's API, so we capture here.