mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add types to jax/_src/numpy/util.py
This commit is contained in:
parent
ae49d2e033
commit
069866e07a
@ -4353,7 +4353,7 @@ mlir.register_lowering(rng_bit_generator_p,
|
||||
_rng_bit_generator_lowering)
|
||||
|
||||
|
||||
def _array_copy(arr):
|
||||
def _array_copy(arr: ArrayLike) -> Array:
|
||||
return copy_p.bind(arr)
|
||||
|
||||
# The copy_p primitive exists for expressing making copies of runtime arrays.
|
||||
|
@ -76,6 +76,7 @@ from jax._src.numpy.util import ( # noqa: F401
|
||||
_register_stackable, _stackable, _where, _wraps)
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
canonicalize_axis as _canonicalize_axis)
|
||||
from jax._src.array import ArrayImpl
|
||||
@ -1838,7 +1839,8 @@ https://jax.readthedocs.io/en/latest/faq.html).
|
||||
"""
|
||||
|
||||
@_wraps(np.array, lax_description=_ARRAY_DOC)
|
||||
def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
|
||||
order: str = "K", ndmin: int = 0) -> Array:
|
||||
if order is not None and order != "K":
|
||||
raise NotImplementedError("Only implemented for order='K'")
|
||||
|
||||
@ -1878,6 +1880,8 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
# (See https://github.com/google/jax/issues/8950)
|
||||
ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)
|
||||
|
||||
out: ArrayLike
|
||||
|
||||
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
|
||||
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
|
||||
# containing large integers; see discussion in
|
||||
@ -1902,10 +1906,10 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
|
||||
raise TypeError(f"Unexpected input type for array: {type(object)}")
|
||||
|
||||
out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
|
||||
if ndmin > ndim(out):
|
||||
out = lax.expand_dims(out, range(ndmin - ndim(out)))
|
||||
return out
|
||||
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
|
||||
if ndmin > ndim(out_array):
|
||||
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
|
||||
return out_array
|
||||
|
||||
|
||||
def _convert_to_array_if_dtype_fails(x):
|
||||
@ -1918,7 +1922,7 @@ def _convert_to_array_if_dtype_fails(x):
|
||||
|
||||
|
||||
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
|
||||
def asarray(a, dtype=None, order=None):
|
||||
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
|
||||
lax_internal._check_user_dtype_supported(dtype, "asarray")
|
||||
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
|
||||
return array(a, dtype=dtype, copy=False, order=order)
|
||||
|
@ -16,7 +16,7 @@ from functools import partial
|
||||
import re
|
||||
import textwrap
|
||||
from typing import (
|
||||
Any, Callable, NamedTuple, Optional, Dict, Sequence, Set, Type, TypeVar
|
||||
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Type, TypeVar
|
||||
)
|
||||
import warnings
|
||||
|
||||
@ -28,6 +28,7 @@ from jax._src.util import safe_zip, safe_map
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src.lax import lax
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -215,7 +216,7 @@ def _wraps(
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def _asarray(arr):
|
||||
def _asarray(arr: ArrayLike) -> Array:
|
||||
"""
|
||||
Pared-down utility to convert object to a DeviceArray.
|
||||
Note this will not correctly handle lists or tuples.
|
||||
@ -224,10 +225,10 @@ def _asarray(arr):
|
||||
dtype, weak_type = dtypes._lattice_result_type(arr)
|
||||
return lax_internal._convert_element_type(arr, dtype, weak_type)
|
||||
|
||||
def _promote_shapes(fun_name, *args):
|
||||
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
if len(args) < 2:
|
||||
return args
|
||||
return [_asarray(arg) for arg in args]
|
||||
else:
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if config.jax_dynamic_shapes:
|
||||
@ -238,10 +239,10 @@ def _promote_shapes(fun_name, *args):
|
||||
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
|
||||
else:
|
||||
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
|
||||
return args # no need for rank promotion, so rely on lax promotion
|
||||
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
|
||||
nonscalar_ranks = {len(shp) for shp in shapes if shp}
|
||||
if len(nonscalar_ranks) < 2:
|
||||
return args # rely on lax scalar promotion
|
||||
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
|
||||
else:
|
||||
if config.jax_numpy_rank_promotion != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
@ -250,7 +251,7 @@ def _promote_shapes(fun_name, *args):
|
||||
for arg, shp in zip(args, shapes)]
|
||||
|
||||
|
||||
def _rank_promotion_warning_or_error(fun_name, shapes):
|
||||
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
|
||||
if config.jax_numpy_rank_promotion == "warn":
|
||||
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
|
||||
"Set the jax_numpy_rank_promotion config option to 'allow' to "
|
||||
@ -265,18 +266,18 @@ def _rank_promotion_warning_or_error(fun_name, shapes):
|
||||
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
|
||||
|
||||
def _promote_dtypes(*args):
|
||||
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return args
|
||||
return [_asarray(arg) for arg in args]
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_inexact(*args):
|
||||
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
@ -287,7 +288,7 @@ def _promote_dtypes_inexact(*args):
|
||||
for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_numeric(*args):
|
||||
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to a numeric (non-bool) type."""
|
||||
@ -298,7 +299,7 @@ def _promote_dtypes_numeric(*args):
|
||||
for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_complex(*args):
|
||||
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to a complex type."""
|
||||
@ -309,23 +310,23 @@ def _promote_dtypes_complex(*args):
|
||||
for x in args]
|
||||
|
||||
|
||||
def _complex_elem_type(dtype):
|
||||
def _complex_elem_type(dtype: DTypeLike) -> DType:
|
||||
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
||||
return np.abs(np.zeros((), dtype)).dtype
|
||||
|
||||
|
||||
def _arraylike(x):
|
||||
def _arraylike(x: ArrayLike) -> bool:
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
hasattr(x, '__jax_array__') or np.isscalar(x))
|
||||
|
||||
|
||||
def _stackable(*args):
|
||||
def _stackable(*args: Any) -> bool:
|
||||
return all(type(arg) in stackables for arg in args)
|
||||
stackables: Set[Type] = set()
|
||||
_register_stackable: Callable[[Type], None] = stackables.add
|
||||
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
def _check_arraylike(fun_name: str, *args: Any):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
if any(not _arraylike(arg) for arg in args):
|
||||
@ -335,7 +336,7 @@ def _check_arraylike(fun_name, *args):
|
||||
raise TypeError(msg.format(fun_name, type(arg), pos))
|
||||
|
||||
|
||||
def _check_no_float0s(fun_name, *args):
|
||||
def _check_no_float0s(fun_name: str, *args: Any):
|
||||
"""Check if none of the args have dtype float0."""
|
||||
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
|
||||
raise TypeError(
|
||||
@ -348,20 +349,20 @@ def _check_no_float0s(fun_name, *args):
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
|
||||
|
||||
def _promote_args(fun_name, *args):
|
||||
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
|
||||
def _promote_args_numeric(fun_name, *args):
|
||||
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))
|
||||
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
Promotes non-inexact types to an inexact type."""
|
||||
@ -371,20 +372,18 @@ def _promote_args_inexact(fun_name, *args):
|
||||
|
||||
|
||||
@partial(api.jit, inline=True)
|
||||
def _broadcast_arrays(*args):
|
||||
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
# TODO(mattjj): remove the array(arg) here
|
||||
return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
|
||||
for arg in args]
|
||||
return [_asarray(arg) for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, result_shape) for arg in args]
|
||||
|
||||
|
||||
def _broadcast_to(arr, shape):
|
||||
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
return arr.broadcast_to(shape) # type: ignore[union-attr]
|
||||
_check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
@ -412,7 +411,8 @@ def _broadcast_to(arr, shape):
|
||||
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
|
||||
# materialize the broadcast forms of scalar arguments.
|
||||
@api.jit
|
||||
def _where(condition, x=None, y=None):
|
||||
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
|
||||
y: Optional[ArrayLike] = None) -> Array:
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments should "
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
@ -420,7 +420,9 @@ def _where(condition, x=None, y=None):
|
||||
if not np.issubdtype(_dtype(condition), np.bool_):
|
||||
condition = lax.ne(condition, lax_internal._zero(condition))
|
||||
x, y = _promote_dtypes(x, y)
|
||||
condition, x, y = _broadcast_arrays(condition, x, y)
|
||||
try: is_always_empty = core.is_empty_shape(np.shape(x))
|
||||
except: is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition, x, y) if not is_always_empty else x
|
||||
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
|
||||
try:
|
||||
is_always_empty = core.is_empty_shape(x_arr.shape)
|
||||
except:
|
||||
is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr
|
||||
|
Loading…
x
Reference in New Issue
Block a user