Add internal jax.lax.asarray utility

This commit is contained in:
Jake VanderPlas 2023-03-30 10:21:55 -07:00
parent 67a28ce30f
commit 8f72454bdf
4 changed files with 27 additions and 32 deletions

View File

@ -122,6 +122,15 @@ def _try_broadcast_shapes(
return None
return tuple(result_shape)
def asarray(x: ArrayLike) -> Array:
"""Lightweight conversion of ArrayLike input to Array output."""
if isinstance(x, Array):
return x
elif isinstance(x, np.ndarray) or np.isscalar(x):
return api.device_put(x)
else:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...

View File

@ -41,11 +41,6 @@ _lax_const = lax_internal._const
Axis = Union[None, int, Sequence[int]]
def _asarray(a: ArrayLike) -> Array:
# simplified version of jnp.asarray() for local use.
return a if isinstance(a, Array) else api.device_put(a)
def _isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
@ -54,7 +49,7 @@ def _isscalar(element: Any) -> bool:
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
# simplified version of jnp.moveaxis() for local use.
check_arraylike("moveaxis", a)
a = _asarray(a)
a = lax_internal.asarray(a)
source = _canonicalize_axis(source, np.ndim(a))
destination = _canonicalize_axis(destination, np.ndim(a))
perm = [i for i in range(np.ndim(a)) if i != source]
@ -92,7 +87,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, Array) else _asarray(a)
a = a if isinstance(a, Array) else lax_internal.asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)
@ -135,7 +130,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
if initial_arr.shape != ():
raise ValueError("initial value must be a scalar. "
f"Got array of shape {initial_arr.shape}")
@ -434,7 +429,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = _asarray(a).astype(computation_dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = lax.sub(a, a_mean)
if dtypes.issubdtype(centered.dtype, np.complexfloating):
@ -607,7 +602,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None =
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
a = _asarray(a).astype(computation_dtype)
a = lax_internal.asarray(a).astype(computation_dtype)
a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients.
@ -701,7 +696,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ..
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
@ -717,7 +712,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int,
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True)
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:

View File

@ -29,7 +29,7 @@ from jax._src.api import jit, custom_jvp
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import (
_asarray, check_arraylike, promote_args, promote_args_inexact,
check_arraylike, promote_args, promote_args_inexact,
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
promote_shapes, _where, _wraps)
@ -140,7 +140,7 @@ fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: _asarray(x))
positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x))
floor = _one_to_one_unop(np.floor, lax.floor, True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
exp = _one_to_one_unop(np.exp, lax.exp, True)
@ -211,7 +211,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
def absolute(x: ArrayLike, /) -> Array:
check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs, module='numpy')(absolute)
@ -596,7 +596,7 @@ radians = deg2rad
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
check_arraylike("conjugate", x)
return lax.conj(x) if np.iscomplexobj(x) else _asarray(x)
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
conj = conjugate
@ -611,7 +611,7 @@ def imag(val: ArrayLike, /) -> Array:
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else _asarray(val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit

View File

@ -223,19 +223,10 @@ def _wraps(
_dtype = partial(dtypes.dtype, canonicalize=True)
def _asarray(arr: ArrayLike) -> Array:
"""
Pared-down utility to convert object to a DeviceArray.
Note this will not correctly handle lists or tuples.
"""
check_arraylike("_asarray", arr)
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax._convert_element_type(arr, dtype, weak_type)
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return [_asarray(arg) for arg in args]
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
@ -246,10 +237,10 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
@ -277,7 +268,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return [_asarray(arg) for arg in args]
return [lax.asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
@ -392,7 +383,7 @@ def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
return [_asarray(arg) for arg in args]
return [lax.asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]
@ -401,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape) # type: ignore[union-attr]
check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, Array) else _asarray(arr)
arr = arr if isinstance(arr, Array) else lax.asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
shape = (shape,)
shape = core.canonicalize_shape(shape) # check that shape is concrete