lax_numpy.py: factor out some common utilities

Re-lands part of #9724

PiperOrigin-RevId: 433838553
commit ddf23dead3 (parent 5abd664938)
Jake VanderPlas, 2022-03-10 13:34:42 -08:00; committed by jax authors
2 changed files with 213 additions and 174 deletions

jax/_src/numpy/lax_numpy.py

@@ -29,7 +29,7 @@ import collections
from functools import partial
import operator
import types
from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union, Set, Type, Callable
from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union
from textwrap import dedent as _dedent
import warnings
@@ -42,7 +42,6 @@ from jax import core
from jax import errors
from jax import lax
from jax.core import ShapedArray, DShapedArray, ConcreteArray, canonicalize_shape
from jax.config import config
from jax.interpreters import pxla
from jax.tree_util import tree_leaves, tree_flatten, tree_map
@@ -52,7 +51,10 @@ from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.lax import lax as lax_internal
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import (  # noqa: F401
    _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _complex_elem_type, _promote_args,
    _promote_args_inexact, _promote_dtypes, _promote_dtypes_inexact, _promote_shapes, _register_stackable,
    _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
@@ -206,74 +208,6 @@ _INT_DTYPES = {
_lax_const = lax_internal._const

def _promote_shapes(fun_name, *args):
  """Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
  if len(args) < 2:
    return args
  else:
    shapes = [shape(arg) for arg in args]
    if _all(len(shapes[0]) == len(s) for s in shapes[1:]):
      return args  # no need for rank promotion, so rely on lax promotion
    nonscalar_ranks = {len(shp) for shp in shapes if shp}
    if len(nonscalar_ranks) < 2:
      return args
    else:
      if config.jax_numpy_rank_promotion != "allow":
        _rank_promotion_warning_or_error(fun_name, shapes)
      if config.jax_dynamic_shapes:
        # With dynamic shapes we don't support singleton-dimension broadcasting;
        # we instead broadcast out to the full shape as a temporary workaround.
        res_shape = lax.broadcast_shapes(*shapes)
        return [broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
      else:
        result_rank = len(lax.broadcast_shapes(*shapes))
        return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
                for arg, shp in zip(args, shapes)]

def _rank_promotion_warning_or_error(fun_name, shapes):
  if config.jax_numpy_rank_promotion == "warn":
    msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
           "Set the jax_numpy_rank_promotion config option to 'allow' to "
           "disable this warning; for more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
  elif config.jax_numpy_rank_promotion == "raise":
    msg = ("Operands could not be broadcast together for {} on shapes {} "
           "and with the config option jax_numpy_rank_promotion='raise'. "
           "For more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))

def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  else:
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    return [lax_internal._convert_element_type(x, to_dtype, weak_type)
            for x in args]

def _promote_dtypes_inexact(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_inexact = _to_inexact_dtype(to_dtype)
  weak_type = (weak_type and to_dtype == to_dtype_inexact)
  return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
          for x in args]

def _to_inexact_dtype(dtype):
  """Promotes a dtype into an inexact dtype, if it is not already one."""
  return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_)

def _complex_elem_type(dtype):
  """Returns the float type of the real/imaginary parts of a complex dtype."""
  return np.abs(np.zeros((), dtype)).dtype

def _result_dtype(op, *args):
  """Compute result dtype of applying op to arguments with given dtypes."""
@@ -281,51 +215,6 @@ def _result_dtype(op, *args):
  return _dtype(op(*args))

def _arraylike(x):
  return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
          hasattr(x, '__jax_array__') or isscalar(x))

def _stackable(*args):
  return _all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add

def _check_arraylike(fun_name, *args):
  """Check if all args fit JAX's definition of arraylike."""
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  if _any(not _arraylike(arg) for arg in args):
    pos, arg = next((i, arg) for i, arg in enumerate(args)
                    if not _arraylike(arg))
    msg = "{} requires ndarray or scalar arguments, got {} at position {}."
    raise TypeError(msg.format(fun_name, type(arg), pos))

def _check_no_float0s(fun_name, *args):
  """Check if none of the args have dtype float0."""
  if _any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
    raise TypeError(
        f"Called {fun_name} with a float0 array. "
        "float0s do not support any operations by design because they "
        "are not compatible with non-trivial vector spaces. No implicit dtype "
        "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
        "to cast a float0 array to a regular zeros array. \n"
        "If you didn't expect to get a float0 you might have accidentally "
        "taken a gradient with respect to an integer argument.")

def _promote_args(fun_name, *args):
  """Convenience function to apply Numpy argument shape and dtype promotion."""
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return _promote_shapes(fun_name, *_promote_dtypes(*args))

def _promote_args_inexact(fun_name, *args):
  """Convenience function to apply Numpy argument shape and dtype promotion.

  Promotes non-inexact types to an inexact type."""
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

def _convert_and_clip_integer(val, dtype):
  """
  Convert integer-typed val to specified integer dtype, clipping to dtype
@@ -1868,24 +1757,6 @@ def isin(element, test_elements, assume_unique=False, invert=False):
  return result.reshape(shape(element))

# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@jit
def _where(condition, x=None, y=None):
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not issubdtype(_dtype(condition), bool_):
    condition = lax.ne(condition, zeros_like(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = broadcast_arrays(condition, x, y)
  try: is_always_empty = core.is_empty_shape(np.shape(x))
  except: is_always_empty = False  # can fail with dynamic shapes
  return lax.select(condition, x, y) if not is_always_empty else x

@_wraps(np.where,
  lax_description=_dedent("""
    At present, JAX does not support JIT-compilation of the single-argument form
@@ -1966,47 +1837,15 @@ def broadcast_shapes(*shapes):
  shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes]
  return lax.broadcast_shapes(*shapes)
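
(Usage sketch, not part of the diff — broadcast_shapes wraps rank-0 entries
into 1-tuples before deferring to lax:)

import jax.numpy as jnp

print(jnp.broadcast_shapes((2, 3), (3,)))  # (2, 3)
print(jnp.broadcast_shapes(4, (1,)))       # (4,): the int 4 is treated as (4,)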

@partial(jit, inline=True)
def broadcast_arrays(*args):
  """Like Numpy's broadcast_arrays but doesn't return views."""
  # Avoid calling _check_arraylike() here to allow passing through objects
  # like PRNGKeyArray which are specially handled in broadcast_to() below.
  shapes = [shape(arg) for arg in args]
  if not shapes or _all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
    # TODO(mattjj): remove the array(arg) here
    return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
            for arg in args]
  result_shape = lax.broadcast_shapes(*shapes)
  return [broadcast_to(arg, result_shape) for arg in args]

@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(arr, shape):
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else array(arr)
  if not isinstance(shape, tuple) and ndim(shape) == 0:
    shape = (shape,)
  shape = canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = _shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr
  else:
    nlead = len(shape) - len(arr_shape)
    shape_tail = shape[nlead:]
    compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
                      for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    if nlead < 0 or not compatible:
      msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
      raise ValueError(msg.format(arr_shape, shape))
    diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                           for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
    return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)

broadcast_arrays = _wraps(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")(_broadcast_arrays)

broadcast_to = _wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")(_broadcast_to)

def _split(op, ary, indices_or_sections, axis=0):

jax/_src/numpy/util.py

@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import re
import textwrap
from typing import Callable, NamedTuple, Optional, Dict, Sequence
from typing import Callable, NamedTuple, Optional, Dict, Sequence, Set, Type
import warnings
from jax._src.config import config
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.numpy.ndarray import ndarray
from jax._src.util import safe_zip
from jax._src import api
from jax import core
from jax._src.lax import lax
import numpy as np
_parameter_break = re.compile("\n(?=[A-Za-z_])")
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
@@ -178,3 +189,192 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
      setattr(op, attr, value)
    return op
  return wrap

_dtype = partial(dtypes.dtype, canonicalize=True)

def _asarray(arr):
  """
  Pared-down utility to convert an object to a DeviceArray.

  Note this will not correctly handle lists or tuples.
  """
  _check_arraylike("_asarray", arr)
  dtype, weak_type = dtypes._lattice_result_type(arr)
  return lax_internal._convert_element_type(arr, dtype, weak_type)
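
(Usage sketch, not part of the diff — the weak_type preserved here is what
lets Python scalars adapt to the other operand's dtype:)

import jax.numpy as jnp

x = jnp.float16(1) + 2.0  # 2.0 enters as a weakly-typed float
print(x.dtype)            # float16, not float32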

def _promote_shapes(fun_name, *args):
  """Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
  if len(args) < 2:
    return args
  else:
    shapes = [np.shape(arg) for arg in args]
    if all(len(shapes[0]) == len(s) for s in shapes[1:]):
      return args  # no need for rank promotion, so rely on lax promotion
    nonscalar_ranks = {len(shp) for shp in shapes if shp}
    if len(nonscalar_ranks) < 2:
      return args
    else:
      if config.jax_numpy_rank_promotion != "allow":
        _rank_promotion_warning_or_error(fun_name, shapes)
      if config.jax_dynamic_shapes:
        # With dynamic shapes we don't support singleton-dimension broadcasting;
        # we instead broadcast out to the full shape as a temporary workaround.
        res_shape = lax.broadcast_shapes(*shapes)
        return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
      else:
        result_rank = len(lax.broadcast_shapes(*shapes))
        return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
                for arg, shp in zip(args, shapes)]
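
(Usage sketch, not part of the diff — rank promotion pads leading dimensions
so a rank-1 operand can combine with a rank-2 one:)

import jax.numpy as jnp

a = jnp.ones((3,))    # treated as shape (1, 3) during broadcasting
b = jnp.ones((2, 3))
print((a + b).shape)  # (2, 3)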

def _rank_promotion_warning_or_error(fun_name, shapes):
  if config.jax_numpy_rank_promotion == "warn":
    msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
           "Set the jax_numpy_rank_promotion config option to 'allow' to "
           "disable this warning; for more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
  elif config.jax_numpy_rank_promotion == "raise":
    msg = ("Operands could not be broadcast together for {} on shapes {} "
           "and with the config option jax_numpy_rank_promotion='raise'. "
           "For more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
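
(Usage sketch, not part of the diff — the config option checked above is
user-settable:)

import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_rank_promotion", "raise")
try:
  jnp.ones((3,)) + jnp.ones((2, 3))
except ValueError as e:
  print(e)  # "Operands could not be broadcast together ..."
jax.config.update("jax_numpy_rank_promotion", "allow")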

def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  else:
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
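
(Usage sketch, not part of the diff — JAX's lattice-based promotion can differ
from NumPy's value-dependent rules:)

import numpy as np
import jax.numpy as jnp

print(np.promote_types(np.int32, np.float16))     # float64 under NumPy
print(jnp.promote_types(jnp.int32, jnp.float16))  # float16 under JAX's lattice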

def _promote_dtypes_inexact(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_inexact = _to_inexact_dtype(to_dtype)
  weak_type = (weak_type and to_dtype == to_dtype_inexact)
  return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
          for x in args]

def _to_inexact_dtype(dtype):
  """Promotes a dtype into an inexact dtype, if it is not already one."""
  return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
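
(Usage sketch, not part of the diff — this inexact promotion is what makes
integer inputs to functions like mean come out floating-point:)

import jax.numpy as jnp

x = jnp.arange(4)                   # int32
print(jnp.mean(x).dtype)            # float32
print(jnp.true_divide(x, 2).dtype)  # float32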

def _complex_elem_type(dtype):
  """Returns the float type of the real/imaginary parts of a complex dtype."""
  return np.abs(np.zeros((), dtype)).dtype
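
(Usage sketch, not part of the diff — the np.abs trick maps each complex dtype
to its component float type:)

import numpy as np

print(np.abs(np.zeros((), np.complex64)).dtype)   # float32
print(np.abs(np.zeros((), np.complex128)).dtype)  # float64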

def _arraylike(x):
  return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
          hasattr(x, '__jax_array__') or np.isscalar(x))

def _stackable(*args):
  return all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add

def _check_arraylike(fun_name, *args):
  """Check if all args fit JAX's definition of arraylike."""
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  if any(not _arraylike(arg) for arg in args):
    pos, arg = next((i, arg) for i, arg in enumerate(args)
                    if not _arraylike(arg))
    msg = "{} requires ndarray or scalar arguments, got {} at position {}."
    raise TypeError(msg.format(fun_name, type(arg), pos))
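
(Usage sketch, not part of the diff — this check is why jax.numpy functions
reject Python lists instead of silently converting them:)

import jax.numpy as jnp

try:
  jnp.sin([1.0, 2.0])
except TypeError as e:
  print(e)  # "sin requires ndarray or scalar arguments, got <class 'list'> at position 0."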

def _check_no_float0s(fun_name, *args):
  """Check if none of the args have dtype float0."""
  if any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
    raise TypeError(
        f"Called {fun_name} with a float0 array. "
        "float0s do not support any operations by design because they "
        "are not compatible with non-trivial vector spaces. No implicit dtype "
        "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
        "to cast a float0 array to a regular zeros array. \n"
        "If you didn't expect to get a float0 you might have accidentally "
        "taken a gradient with respect to an integer argument.")

def _promote_args(fun_name, *args):
  """Convenience function to apply Numpy argument shape and dtype promotion."""
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return _promote_shapes(fun_name, *_promote_dtypes(*args))

def _promote_args_inexact(fun_name, *args):
  """Convenience function to apply Numpy argument shape and dtype promotion.

  Promotes non-inexact types to an inexact type."""
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

@partial(api.jit, inline=True)
def _broadcast_arrays(*args):
  """Like Numpy's broadcast_arrays but doesn't return views."""
  shapes = [np.shape(arg) for arg in args]
  if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
    # TODO(mattjj): remove the array(arg) here
    return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
            for arg in args]
  result_shape = lax.broadcast_shapes(*shapes)
  return [_broadcast_to(arg, result_shape) for arg in args]
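
(Usage sketch, not part of the diff — unlike NumPy, the jnp version built on
this returns copies rather than views:)

import jax.numpy as jnp

a, b = jnp.broadcast_arrays(jnp.arange(3), jnp.ones((2, 3)))
print(a.shape, b.shape)  # (2, 3) (2, 3)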

def _broadcast_to(arr, shape):
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr
  else:
    nlead = len(shape) - len(arr_shape)
    shape_tail = shape[nlead:]
    compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
                     for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    if nlead < 0 or not compatible:
      msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
      raise ValueError(msg.format(arr_shape, shape))
    diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                           for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
    return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
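
(Usage sketch, not part of the diff — only leading dimensions may be added and
only size-1 dimensions stretched; anything else is rejected:)

import jax.numpy as jnp

x = jnp.arange(3)                         # shape (3,)
print(jnp.broadcast_to(x, (2, 3)).shape)  # (2, 3)
try:
  jnp.broadcast_to(x, (3, 2))
except ValueError as e:
  print(e)  # "Incompatible shapes for broadcasting: (3,) and requested shape (3, 2)"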

# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition, x=None, y=None):
  if x is None or y is None:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not np.issubdtype(_dtype(condition), np.bool_):
    condition = lax.ne(condition, lax_internal._zero(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = _broadcast_arrays(condition, x, y)
  try: is_always_empty = core.is_empty_shape(np.shape(x))
  except: is_always_empty = False  # can fail with dynamic shapes
  return lax.select(condition, x, y) if not is_always_empty else x
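
(Usage sketch, not part of the diff — the three-argument form and the error
path guarded above:)

import jax.numpy as jnp

print(jnp.where(jnp.array([True, False]), 1, 0))  # [1 0]
try:
  jnp.where(jnp.array([True, False]), 1)          # x given without y
except ValueError as e:
  print(e)  # "Either both or neither of the x and y arguments should be provided ..."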