mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Document valid enum values for precision. (#3441)
This is a little tricky to figure out otherwise.
This commit is contained in:
parent
c9d1b99e51
commit
3deada9ede
@ -497,8 +497,9 @@ def conv_general_dilated(
|
||||
of length `n+2`.
|
||||
feature_group_count: integer, default 1. See XLA HLO docs.
|
||||
batch_group_count: integer, default 1. See XLA HLO docs.
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
@ -566,8 +567,9 @@ def dot(lhs: Array, rhs: Array, precision: Optional[PrecisionType] = None) -> Ar
|
||||
Args:
|
||||
lhs: an array of rank 1 or 2.
|
||||
rhs: an array of rank 1 or 2.
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
An array containing the product.
|
||||
@ -597,8 +599,9 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
|
||||
dimension_numbers: a tuple of tuples of the form
|
||||
`((lhs_contracting_dims, rhs_contracting_dims),
|
||||
(lhs_batch_dims, rhs_batch_dims))`
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
An array containing the result.
|
||||
@ -1345,8 +1348,9 @@ def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
|
||||
window_strides: a sequence of `n` integers, representing the inter-window
|
||||
strides.
|
||||
padding: either the string `'SAME'`, the string `'VALID'`.
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
@ -1376,8 +1380,9 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
|
||||
rhs_dilation: `None`, or a sequence of `n` integers, giving the
|
||||
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
|
||||
is also known as atrous convolution.
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
An array containing the convolution result.
|
||||
@ -1448,8 +1453,9 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
|
||||
to the gradient-derived functions like keras.layers.Conv2DTranspose
|
||||
applied to the same kernel. For typical use in neural nets this is completely
|
||||
pointless and just makes input/output channel specification confusing.
|
||||
precision: Optional. Either `None`, which means the default precision for
|
||||
the backend, or a `Precision` enum value.
|
||||
precision: Optional. Either ``None``, which means the default precision for
|
||||
the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
|
||||
``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
|
||||
Returns:
|
||||
Transposed N-d convolution, with output padding following the conventions of
|
||||
|
@ -66,7 +66,9 @@ newaxis = None
|
||||
_PRECISION_DOC = """\
|
||||
In addition to the original NumPy arguments listed below, also supports
|
||||
``precision`` for extra control over matrix-multiplication precision
|
||||
on supported devices. See :py:func:`jax.lax.dot` for details.
|
||||
on supported devices. ``precision`` may be set to ``None``, which means
|
||||
default precision for the backend, or any ``jax.lax.Precision`` enum value
|
||||
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``).
|
||||
"""
|
||||
|
||||
# We replace some builtin names to follow Numpy's API, so we capture here.
|
||||
|
Loading…
x
Reference in New Issue
Block a user