mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
436 lines
17 KiB
Python
436 lines
17 KiB
Python
# Copyright 2020 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import re
|
|
import textwrap
|
|
from typing import (
|
|
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Type, TypeVar
|
|
)
|
|
import warnings
|
|
|
|
from jax._src import dtypes
|
|
from jax._src import api
|
|
from jax._src import core
|
|
from jax._src.config import config
|
|
from jax._src.lax import lax
|
|
from jax._src.numpy.ndarray import ndarray
|
|
from jax._src.util import safe_zip, safe_map
|
|
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
|
|
|
|
import numpy as np
|
|
|
|
# Rebind `zip`/`map` in this module to the `safe_*` helpers imported from
# jax._src.util; the builtins stay reachable as `unsafe_zip`/`unsafe_map`.
unsafe_zip = zip
zip = safe_zip
unsafe_map = map
map = safe_map

# Type variable used to annotate decorators that return the object they were
# given (see the return annotation of `_wraps`).
_T = TypeVar("_T")
|
|
|
|
_parameter_break = re.compile("\n(?=[A-Za-z_])")
|
|
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
|
|
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE)
|
|
_versionadded = re.compile(r'^\s+\.\.\s+versionadded::', re.MULTILINE)
|
|
_docreference = re.compile(r':doc:`(.*?)\s*<.*?>`')
|
|
|
|
class ParsedDoc(NamedTuple):
  """Structured pieces of a numpy-style docstring.

  Fields:
    docstr: the complete docstring (may be None when there was none).
    signature: signature text found in the docstring, if any.
    summary: the summary line.
    front_matter: text that precedes the first titled section.
    sections: mapping from section title to the text of that section.
  """
  docstr: Optional[str]
  signature: str = ""
  summary: str = ""
  front_matter: str = ""
  # NOTE: this default is one dict object shared by every instance created
  # without an explicit `sections` argument; do not mutate it in place.
  sections: Dict[str, str] = {}
|
|
|
|
|
|
def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
  """Split a numpy-style docstring into its structural parts.

  Args:
    docstr: raw docstring text taken from a function; may be None or blank.

  Returns:
    A ParsedDoc. For a missing or whitespace-only docstring only the `docstr`
    field is populated; otherwise signature, summary, front matter and the
    titled sections are filled in.
  """
  if docstr is None or not docstr.strip():
    return ParsedDoc(docstr)

  # Replace :doc: roles by their display text to avoid sphinx errors.
  docstr = _docreference.sub(lambda m: m.group(1), docstr)

  # A signature may sit at the very start of the docstring...
  signature = ""
  remainder = docstr
  sig_match = _numpy_signature_re.match(remainder)
  if sig_match:
    signature = sig_match.group()
    remainder = docstr[sig_match.end():]

  firstline, _, remainder = remainder.partition('\n')
  remainder = textwrap.dedent(remainder.lstrip('\n'))

  # ...or directly after the first line, once the body has been dedented.
  sig_match = _numpy_signature_re.match(remainder)
  if sig_match:
    signature = sig_match.group()
    remainder = remainder[sig_match.end():]

  # If the first line was empty, the summary is the next non-blank line.
  summary = firstline
  if not summary:
    summary, _, remainder = remainder.lstrip('\n').partition('\n')
    remainder = textwrap.dedent(remainder.lstrip('\n'))

  # The leading newline lets a section header at the top of the body match.
  front_matter = ""
  pieces = _section_break.split("\n" + remainder)
  if not _section_break.match(pieces[0]):
    front_matter, *pieces = pieces

  # Each section is keyed by its title line and keeps its full text.
  sections = {piece.split('\n', 1)[0]: piece for piece in pieces}

  return ParsedDoc(docstr=docstr, signature=signature, summary=summary,
                   front_matter=front_matter, sections=sections)
|
|
|
|
|
|
def _parse_parameters(body: str) -> Dict[str, str]:
  """Split a "Parameters" section into a {name: entry text} mapping.

  `body` must start with the title line "Parameters" followed by a dashed
  underline; both are checked with assertions. When an entry declares several
  comma-separated names, only the first name is used as the key.
  """
  title, underline, content = body.split('\n', 2)
  assert title == 'Parameters'
  assert underline and not underline.strip('-')

  by_name: Dict[str, str] = {}
  for entry in _parameter_break.split(content):
    declared_names = entry.partition(' : ')[0]
    by_name[declared_names.partition(', ')[0]] = entry
  return by_name
|
|
|
|
|
|
def _parse_extra_params(extra_params: str) -> Dict[str, str]:
  """Turn the `extra_params` text given to _wraps() into {name: entry text}.

  Leading and trailing newlines are dropped before splitting. When an entry
  declares several comma-separated names, only the first is used as the key.
  """
  by_name: Dict[str, str] = {}
  for entry in _parameter_break.split(extra_params.strip('\n')):
    declared_names = entry.partition(' : ')[0]
    by_name[declared_names.partition(', ')[0]] = entry
  return by_name
|
|
|
|
|
|
def _wraps(
    fun: Optional[Callable[..., Any]],
    update_doc: bool = True,
    lax_description: str = "",
    sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
    skip_params: Sequence[str] = (),
    extra_params: Optional[str] = None,
    module: Optional[str] = None,
) -> Callable[[_T], _T]:
  """Specialized version of functools.wraps for wrapping numpy functions.

  This produces a wrapped function with a modified docstring. In particular, if
  `update_doc` is True, parameters listed in the wrapped function that are not
  supported by the decorated function will be removed from the docstring. For
  this reason, it is important that parameter names match those in the original
  numpy function.

  Args:
    fun: The function being wrapped. If None, the decorated function is
      returned with only ``__np_wrapped__ = None`` set on it.
    update_doc: whether to transform the numpy docstring to remove references of
      parameters that are supported by the numpy version but not the JAX version.
      If False, include the numpy docstring verbatim.
    lax_description: a string description that will be added to the beginning of
      the docstring.
    sections: a list of sections to include in the docstring. The default is
      ["Parameters", "Returns", "References"]
    skip_params: a list of strings containing names of parameters accepted by the
      function that should be skipped in the parameter list.
    extra_params: an optional string containing additional parameter descriptions.
      When ``update_doc=True``, these will be added to the list of parameter
      descriptions in the updated doc.
    module: an optional string specifying the module from which the wrapped function
      is imported. This is useful for objects such as ufuncs, where the module cannot
      be determined from the wrapped function itself.

  Returns:
    A decorator that sets ``__np_wrapped__``, ``__doc__`` and, when `fun`
    defines them, ``__name__`` and ``__qualname__`` on the decorated object,
    and returns that same object.

  Raises:
    ValueError: (from the decorator) if `fun` has no ``__module__``, `module`
      was not given and ``config.jax_enable_checks`` is set.
    Exception: (from the decorator) any error raised while rewriting the
      docstring is re-raised when ``config.jax_enable_checks`` is set;
      otherwise the original docstring of `fun` is used unchanged.
  """
  def wrap(op):
    op.__np_wrapped__ = fun
    # Allows this pattern: @wraps(getattr(np, 'new_function', None))
    if fun is None:
      return op
    docstr = getattr(fun, "__doc__", None)
    name = getattr(fun, "__name__", getattr(op, "__name__", str(op)))
    try:
      mod = module or fun.__module__
    except AttributeError:
      if config.jax_enable_checks:
        raise ValueError(f"function {fun} defines no __module__; pass module keyword to _wraps.")
    else:
      name = f"{mod}.{name}"
    if docstr:
      try:
        parsed = _parse_numpydoc(docstr)

        if update_doc and 'Parameters' in parsed.sections:
          code = getattr(getattr(op, "__wrapped__", op), "__code__", None)
          # Remove unrecognized parameter descriptions.
          # NOTE: co_varnames also lists local variables, so the filter keeps
          # any documented name that matches a parameter *or* a local of `op`.
          parameters = _parse_parameters(parsed.sections['Parameters'])
          if extra_params:
            parameters.update(_parse_extra_params(extra_params))
          parameters = {p: desc for p, desc in parameters.items()
                        if (code is None or p in code.co_varnames)
                        and p not in skip_params}
          if parameters:
            # Drop everything from a ".. versionadded::" directive onwards in
            # each entry before re-assembling the section.
            parsed.sections['Parameters'] = (
              "Parameters\n"
              "----------\n" +
              "\n".join(_versionadded.split(desc)[0].rstrip()
                        for desc in parameters.values())
            )
          else:
            del parsed.sections['Parameters']

        docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
        docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
        if lax_description:
          docstr += "\n" + lax_description.strip() + "\n"
        docstr += "\n*Original docstring below.*\n"

        # We remove signatures from the docstrings, because they are redundant
        # at best and misleading at worst: e.g. JAX wrappers don't implement
        # all ufunc keyword arguments.
        # if parsed.signature:
        #   docstr += "\n" + parsed.signature.strip() + "\n"

        if parsed.front_matter:
          docstr += "\n" + parsed.front_matter.strip() + "\n"
        # Must be a list, not a generator: a generator object is always truthy,
        # which would append blank lines even when no section is kept.
        kept_sections = [content.strip()
                         for section, content in parsed.sections.items()
                         if section in sections]
        if kept_sections:
          docstr += "\n" + "\n\n".join(kept_sections) + "\n"
      except Exception:
        # Best effort: fall back to the unmodified docstring unless checks are
        # enabled. KeyboardInterrupt/SystemExit are deliberately not caught.
        if config.jax_enable_checks:
          raise
        docstr = fun.__doc__

    op.__doc__ = docstr
    for attr in ['__name__', '__qualname__']:
      try:
        value = getattr(fun, attr)
      except AttributeError:
        pass
      else:
        setattr(op, attr, value)
    return op
  return wrap
|
|
|
|
# `dtypes.dtype` with `canonicalize=True` pre-bound.
_dtype = partial(dtypes.dtype, canonicalize=True)
|
|
|
|
def _asarray(arr: ArrayLike) -> Array:
  """Pared-down conversion of a single array-like object to an array.

  The input is validated with `_check_arraylike`, so lists and tuples are not
  handled here. The element type and weak-type flag come from
  `dtypes._lattice_result_type`.
  """
  _check_arraylike("_asarray", arr)
  result_dtype, is_weak = dtypes._lattice_result_type(arr)
  return lax._convert_element_type(arr, result_dtype, is_weak)
|
|
|
|
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Make `args` rank-compatible for lax, following NumPy broadcasting rules.

  Args:
    fun_name: name of the calling function, used in rank-promotion messages.
    *args: array-like operands.

  Returns:
    A list with one array per operand. Operands are only reshaped when their
    ranks actually differ; otherwise they are just converted to arrays.
  """
  if len(args) < 2:
    return [_asarray(arg) for arg in args]

  shapes = [np.shape(arg) for arg in args]

  if config.jax_dynamic_shapes:
    # With dynamic shapes we don't support singleton-dimension broadcasting;
    # we instead broadcast out to the full shape as a temporary workaround.
    # TODO(mattjj): revise this workaround
    res_shape = lax.broadcast_shapes(*shapes)  # Can raise an error!
    return [_broadcast_to(arg, res_shape) for arg in args]

  # All ranks equal: no rank promotion needed, rely on lax promotion.
  first_rank = len(shapes[0])
  if all(len(s) == first_rank for s in shapes[1:]):
    return [_asarray(arg) for arg in args]

  # Only scalars mixed with a single non-scalar rank: rely on lax scalar
  # promotion.
  nonscalar_ranks = {len(shp) for shp in shapes if shp}
  if len(nonscalar_ranks) < 2:
    return [_asarray(arg) for arg in args]

  if config.jax_numpy_rank_promotion != "allow":
    _rank_promotion_warning_or_error(fun_name, shapes)
  result_rank = len(lax.broadcast_shapes(*shapes))
  # Left-pad every shape with 1s up to the result rank.
  return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
          for arg, shp in zip(args, shapes)]
|
|
|
|
|
|
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
  """Warn or raise about implicit rank promotion, depending on configuration.

  Emits a warning when `config.jax_numpy_rank_promotion` is "warn", raises
  ValueError when it is "raise", and does nothing for any other value.

  Args:
    fun_name: name of the function performing the promotion.
    shapes: the operand shapes, reported in the message.
  """
  mode = config.jax_numpy_rank_promotion
  if mode == "warn":
    template = ("Following NumPy automatic rank promotion for {} on shapes {}. "
                "Set the jax_numpy_rank_promotion config option to 'allow' to "
                "disable this warning; for more information, see "
                "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    shapes_text = ' '.join(str(shape) for shape in shapes)
    warnings.warn(template.format(fun_name, shapes_text))
  elif mode == "raise":
    template = ("Operands could not be broadcast together for {} on shapes {} "
                "and with the config option jax_numpy_rank_promotion='raise'. "
                "For more information, see "
                "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    shapes_text = ' '.join(str(shape) for shape in shapes)
    raise ValueError(template.format(fun_name, shapes_text))
|
|
|
|
|
|
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
  """Convert all arguments to their common promoted dtype.

  With fewer than two arguments each one is simply converted to an array.
  Otherwise the target dtype is the canonicalized result of
  `dtypes._lattice_result_type` over all arguments.
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return [_asarray(arg) for arg in args]

  result_dtype, is_weak = dtypes._lattice_result_type(*args)
  result_dtype = dtypes.canonicalize_dtype(result_dtype)
  return [lax._convert_element_type(arg, result_dtype, is_weak)
          for arg in args]
|
|
|
|
|
|
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
  """Convert all arguments to a common inexact dtype.

  The common dtype is computed as in NumPy-style argument promotion and then
  mapped through `dtypes.to_inexact_dtype`.
  """
  lattice_dtype, is_weak = dtypes._lattice_result_type(*args)
  canonical = dtypes.canonicalize_dtype(lattice_dtype)
  target = dtypes.to_inexact_dtype(canonical)
  return [lax._convert_element_type(arg, target, is_weak) for arg in args]
|
|
|
|
|
|
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
  """Convert all arguments to a common numeric (non-bool) dtype.

  The common dtype is computed as in NumPy-style argument promotion and then
  mapped through `dtypes.to_numeric_dtype`.
  """
  lattice_dtype, is_weak = dtypes._lattice_result_type(*args)
  canonical = dtypes.canonicalize_dtype(lattice_dtype)
  target = dtypes.to_numeric_dtype(canonical)
  return [lax._convert_element_type(arg, target, is_weak) for arg in args]
|
|
|
|
|
|
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
  """Convert all arguments to a common complex dtype.

  The common dtype is computed as in NumPy-style argument promotion and then
  mapped through `dtypes.to_complex_dtype`.
  """
  lattice_dtype, is_weak = dtypes._lattice_result_type(*args)
  canonical = dtypes.canonicalize_dtype(lattice_dtype)
  target = dtypes.to_complex_dtype(canonical)
  return [lax._convert_element_type(arg, target, is_weak) for arg in args]
|
|
|
|
|
|
def _complex_elem_type(dtype: DTypeLike) -> DType:
|
|
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
|
return np.abs(np.zeros((), dtype)).dtype
|
|
|
|
|
|
def _arraylike(x: ArrayLike) -> bool:
|
|
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
|
hasattr(x, '__jax_array__') or np.isscalar(x))
|
|
|
|
|
|
def _stackable(*args: Any) -> bool:
|
|
return all(type(arg) in stackables for arg in args)
|
|
stackables: Set[Type] = set()
|
|
_register_stackable: Callable[[Type], None] = stackables.add
|
|
|
|
|
|
def _check_arraylike(fun_name: str, *args: Any):
  """Raise TypeError unless every argument satisfies `_arraylike`.

  Args:
    fun_name: name of the calling function, reported in the error message.
    *args: values to validate.

  Raises:
    TypeError: naming the type and position of the first offending argument.
  """
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  for pos, arg in enumerate(args):
    if not _arraylike(arg):
      msg = "{} requires ndarray or scalar arguments, got {} at position {}."
      raise TypeError(msg.format(fun_name, type(arg), pos))
|
|
|
|
|
|
def _check_no_float0s(fun_name: str, *args: Any):
|
|
"""Check if none of the args have dtype float0."""
|
|
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
|
|
raise TypeError(
|
|
f"Called {fun_name} with a float0 array. "
|
|
"float0s do not support any operations by design because they "
|
|
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
|
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
|
"to cast a float0 array to a regular zeros array. \n"
|
|
"If you didn't expect to get a float0 you might have accidentally "
|
|
"taken a gradient with respect to an integer argument.")
|
|
|
|
|
|
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate `args`, then apply NumPy-style dtype and shape promotion.

  Arguments must be array-like and must not have dtype float0; `fun_name` is
  used in the resulting error and rank-promotion messages.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  same_dtype_args = _promote_dtypes(*args)
  return _promote_shapes(fun_name, *same_dtype_args)
|
|
|
|
|
|
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate `args`, then promote them to a common numeric dtype and shape.

  Arguments must be array-like and must not have dtype float0; dtype promotion
  goes through `_promote_dtypes_numeric`.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  numeric_args = _promote_dtypes_numeric(*args)
  return _promote_shapes(fun_name, *numeric_args)
|
|
|
|
|
|
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
  """Validate `args`, then promote them to a common inexact dtype and shape.

  Arguments must be array-like and must not have dtype float0; non-inexact
  types are promoted to an inexact type via `_promote_dtypes_inexact`.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  inexact_args = _promote_dtypes_inexact(*args)
  return _promote_shapes(fun_name, *inexact_args)
|
|
|
|
|
|
@partial(api.jit, inline=True)
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
  """Like NumPy's broadcast_arrays, but the results are not views.

  When there are no arguments, or all shapes already compare equal, the
  arguments are only converted to arrays; otherwise each one is broadcast to
  the common shape.
  """
  shapes = [np.shape(arg) for arg in args]
  needs_broadcast = bool(shapes) and not all(
      core.symbolic_equal_shape(shapes[0], s) for s in shapes)
  if needs_broadcast:
    result_shape = lax.broadcast_shapes(*shapes)
    return [_broadcast_to(arg, result_shape) for arg in args]
  return [_asarray(arg) for arg in args]
|
|
|
|
|
|
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
  """Broadcast `arr` to `shape`.

  Args:
    arr: an array-like value. Objects that define their own `broadcast_to`
      method are delegated to directly.
    shape: the requested shape; a non-tuple scalar is treated as a
      one-element shape.

  Returns:
    `arr` itself when its shape already equals `shape`, otherwise the result
    of `lax.broadcast_in_dim` on the squeezed input.

  Raises:
    TypeError: if `arr` is not array-like.
    ValueError: if `shape` has fewer dimensions than `arr`, or if some
      dimension of `arr` is neither 1 nor equal to the matching trailing
      dimension of `shape`.
  """
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)  # type: ignore[union-attr]
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr

  msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
  nlead = len(shape) - len(arr_shape)
  # Reject a lower-rank target before zipping: with nlead < 0 the two
  # sequences below have different lengths, so safe_zip would fail on the
  # length mismatch instead of reporting the broadcasting error.
  if nlead < 0:
    raise ValueError(msg.format(arr_shape, shape))
  shape_tail = shape[nlead:]
  compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
                   for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    raise ValueError(msg.format(arr_shape, shape))
  # Indices (within arr) of the dimensions that differ from the target and
  # therefore have to be squeezed out and re-created by the broadcast.
  diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                         for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
|
|
|
|
|
|
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
  """Select elements from `x` or `y` depending on `condition`.

  Args:
    condition: selector; a non-boolean condition is compared against zero.
    x: values chosen where `condition` is true.
    y: values chosen where `condition` is false.

  Returns:
    The result of `lax.select` on the dtype-promoted, broadcast operands, or
    the broadcast `x` when the result shape is known to be empty.

  Raises:
    ValueError: if `x` or `y` is None.
  """
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    condition = lax.ne(condition, lax._zero(condition))
  x, y = _promote_dtypes(x, y)
  condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
  try:
    is_always_empty = core.is_empty_shape(x_arr.shape)
  except Exception:
    # Can fail with dynamic shapes. Only ordinary errors are swallowed here,
    # so KeyboardInterrupt/SystemExit still propagate.
    is_always_empty = False
  return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr
|