Add types to jax/_src/numpy/util.py

This commit is contained in:
Jake VanderPlas 2022-10-03 15:55:56 -07:00
parent ae49d2e033
commit 069866e07a
3 changed files with 44 additions and 38 deletions

View File

@ -4353,7 +4353,7 @@ mlir.register_lowering(rng_bit_generator_p,
_rng_bit_generator_lowering)
def _array_copy(arr):
def _array_copy(arr: ArrayLike) -> Array:
return copy_p.bind(arr)
# The copy_p primitive exists for expressing making copies of runtime arrays.

View File

@ -76,6 +76,7 @@ from jax._src.numpy.util import ( # noqa: F401
_register_stackable, _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl
@ -1838,7 +1839,8 @@ https://jax.readthedocs.io/en/latest/faq.html).
"""
@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
order: str = "K", ndmin: int = 0) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
@ -1878,6 +1880,8 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)
out: ArrayLike
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
@ -1902,10 +1906,10 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
raise TypeError(f"Unexpected input type for array: {type(object)}")
out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out):
out = lax.expand_dims(out, range(ndmin - ndim(out)))
return out
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
return out_array
def _convert_to_array_if_dtype_fails(x):
@ -1918,7 +1922,7 @@ def _convert_to_array_if_dtype_fails(x):
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a, dtype=None, order=None):
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order)

View File

@ -16,7 +16,7 @@ from functools import partial
import re
import textwrap
from typing import (
Any, Callable, NamedTuple, Optional, Dict, Sequence, Set, Type, TypeVar
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Type, TypeVar
)
import warnings
@ -28,6 +28,7 @@ from jax._src.util import safe_zip, safe_map
from jax._src import api
from jax import core
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
import numpy as np
@ -215,7 +216,7 @@ def _wraps(
_dtype = partial(dtypes.dtype, canonicalize=True)
def _asarray(arr):
def _asarray(arr: ArrayLike) -> Array:
"""
Pared-down utility to convert object to a DeviceArray.
Note this will not correctly handle lists or tuples.
@ -224,10 +225,10 @@ def _asarray(arr):
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax_internal._convert_element_type(arr, dtype, weak_type)
def _promote_shapes(fun_name, *args):
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
@ -238,10 +239,10 @@ def _promote_shapes(fun_name, *args):
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return args # no need for rank promotion, so rely on lax promotion
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return args # rely on lax scalar promotion
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
@ -250,7 +251,7 @@ def _promote_shapes(fun_name, *args):
for arg, shp in zip(args, shapes)]
def _rank_promotion_warning_or_error(fun_name, shapes):
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
@ -265,18 +266,18 @@ def _rank_promotion_warning_or_error(fun_name, shapes):
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def _promote_dtypes(*args):
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
def _promote_dtypes_inexact(*args):
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
@ -287,7 +288,7 @@ def _promote_dtypes_inexact(*args):
for x in args]
def _promote_dtypes_numeric(*args):
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a numeric (non-bool) type."""
@ -298,7 +299,7 @@ def _promote_dtypes_numeric(*args):
for x in args]
def _promote_dtypes_complex(*args):
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a complex type."""
@ -309,23 +310,23 @@ def _promote_dtypes_complex(*args):
for x in args]
def _complex_elem_type(dtype):
def _complex_elem_type(dtype: DTypeLike) -> DType:
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
def _arraylike(x):
def _arraylike(x: ArrayLike) -> bool:
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
hasattr(x, '__jax_array__') or np.isscalar(x))
def _stackable(*args):
def _stackable(*args: Any) -> bool:
return all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add
def _check_arraylike(fun_name, *args):
def _check_arraylike(fun_name: str, *args: Any):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if any(not _arraylike(arg) for arg in args):
@ -335,7 +336,7 @@ def _check_arraylike(fun_name, *args):
raise TypeError(msg.format(fun_name, type(arg), pos))
def _check_no_float0s(fun_name, *args):
def _check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
@ -348,20 +349,20 @@ def _check_no_float0s(fun_name, *args):
"taken a gradient with respect to an integer argument.")
def _promote_args(fun_name, *args):
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))
def _promote_args_numeric(fun_name, *args):
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))
def _promote_args_inexact(fun_name, *args):
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
@ -371,20 +372,18 @@ def _promote_args_inexact(fun_name, *args):
@partial(api.jit, inline=True)
def _broadcast_arrays(*args):
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
# TODO(mattjj): remove the array(arg) here
return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
for arg in args]
return [_asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]
def _broadcast_to(arr, shape):
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
return arr.broadcast_to(shape) # type: ignore[union-attr]
_check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
@ -412,7 +411,8 @@ def _broadcast_to(arr, shape):
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition, x=None, y=None):
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None) -> Array:
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
@ -420,7 +420,9 @@ def _where(condition, x=None, y=None):
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax_internal._zero(condition))
x, y = _promote_dtypes(x, y)
condition, x, y = _broadcast_arrays(condition, x, y)
try: is_always_empty = core.is_empty_shape(np.shape(x))
except: is_always_empty = False # can fail with dynamic shapes
return lax.select(condition, x, y) if not is_always_empty else x
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:
is_always_empty = core.is_empty_shape(x_arr.shape)
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr