mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add internal jax.lax.asarray utility
This commit is contained in:
parent
67a28ce30f
commit
8f72454bdf
@ -122,6 +122,15 @@ def _try_broadcast_shapes(
|
||||
return None
|
||||
return tuple(result_shape)
|
||||
|
||||
def asarray(x: ArrayLike) -> Array:
|
||||
"""Lightweight conversion of ArrayLike input to Array output."""
|
||||
if isinstance(x, Array):
|
||||
return x
|
||||
elif isinstance(x, np.ndarray) or np.isscalar(x):
|
||||
return api.device_put(x)
|
||||
else:
|
||||
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")
|
||||
|
||||
@overload
|
||||
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
|
||||
|
||||
|
@ -41,11 +41,6 @@ _lax_const = lax_internal._const
|
||||
|
||||
Axis = Union[None, int, Sequence[int]]
|
||||
|
||||
|
||||
def _asarray(a: ArrayLike) -> Array:
|
||||
# simplified version of jnp.asarray() for local use.
|
||||
return a if isinstance(a, Array) else api.device_put(a)
|
||||
|
||||
def _isscalar(element: Any) -> bool:
|
||||
if hasattr(element, '__jax_array__'):
|
||||
element = element.__jax_array__()
|
||||
@ -54,7 +49,7 @@ def _isscalar(element: Any) -> bool:
|
||||
def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array:
|
||||
# simplified version of jnp.moveaxis() for local use.
|
||||
check_arraylike("moveaxis", a)
|
||||
a = _asarray(a)
|
||||
a = lax_internal.asarray(a)
|
||||
source = _canonicalize_axis(source, np.ndim(a))
|
||||
destination = _canonicalize_axis(destination, np.ndim(a))
|
||||
perm = [i for i in range(np.ndim(a)) if i != source]
|
||||
@ -92,7 +87,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
|
||||
f"where mask one has to specify 'initial'")
|
||||
|
||||
a = a if isinstance(a, Array) else _asarray(a)
|
||||
a = a if isinstance(a, Array) else lax_internal.asarray(a)
|
||||
a = preproc(a) if preproc else a
|
||||
pos_dims, dims = _reduction_dims(a, axis)
|
||||
|
||||
@ -135,7 +130,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val:
|
||||
else:
|
||||
result = lax.reduce(a, init_val, op, dims)
|
||||
if initial is not None:
|
||||
initial_arr = lax.convert_element_type(initial, _asarray(a).dtype)
|
||||
initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype)
|
||||
if initial_arr.shape != ():
|
||||
raise ValueError("initial value must be a scalar. "
|
||||
f"Got array of shape {initial_arr.shape}")
|
||||
@ -434,7 +429,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None,
|
||||
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
|
||||
|
||||
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
|
||||
a = _asarray(a).astype(computation_dtype)
|
||||
a = lax_internal.asarray(a).astype(computation_dtype)
|
||||
a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
|
||||
centered = lax.sub(a, a_mean)
|
||||
if dtypes.issubdtype(centered.dtype, np.complexfloating):
|
||||
@ -607,7 +602,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None =
|
||||
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
|
||||
|
||||
computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
|
||||
a = _asarray(a).astype(computation_dtype)
|
||||
a = lax_internal.asarray(a).astype(computation_dtype)
|
||||
a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
|
||||
|
||||
centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients.
|
||||
@ -701,7 +696,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ..
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False)
|
||||
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False)
|
||||
|
||||
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
|
||||
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
|
||||
@ -717,7 +712,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int,
|
||||
if interpolation is not None:
|
||||
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
|
||||
"Use 'method=' instead.", DeprecationWarning)
|
||||
return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True)
|
||||
return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True)
|
||||
|
||||
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
|
||||
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
|
||||
|
@ -29,7 +29,7 @@ from jax._src.api import jit, custom_jvp
|
||||
from jax._src.lax import lax
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import (
|
||||
_asarray, check_arraylike, promote_args, promote_args_inexact,
|
||||
check_arraylike, promote_args, promote_args_inexact,
|
||||
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
|
||||
promote_shapes, _where, _wraps)
|
||||
|
||||
@ -140,7 +140,7 @@ fabs = _one_to_one_unop(np.fabs, lax.abs, True)
|
||||
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
|
||||
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
|
||||
negative = _one_to_one_unop(np.negative, lax.neg)
|
||||
positive = _one_to_one_unop(np.positive, lambda x: _asarray(x))
|
||||
positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x))
|
||||
floor = _one_to_one_unop(np.floor, lax.floor, True)
|
||||
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
|
||||
exp = _one_to_one_unop(np.exp, lax.exp, True)
|
||||
@ -211,7 +211,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
def absolute(x: ArrayLike, /) -> Array:
|
||||
check_arraylike('absolute', x)
|
||||
dt = dtypes.dtype(x)
|
||||
return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
|
||||
return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
|
||||
abs = _wraps(np.abs, module='numpy')(absolute)
|
||||
|
||||
|
||||
@ -596,7 +596,7 @@ radians = deg2rad
|
||||
@partial(jit, inline=True)
|
||||
def conjugate(x: ArrayLike, /) -> Array:
|
||||
check_arraylike("conjugate", x)
|
||||
return lax.conj(x) if np.iscomplexobj(x) else _asarray(x)
|
||||
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
|
||||
conj = conjugate
|
||||
|
||||
|
||||
@ -611,7 +611,7 @@ def imag(val: ArrayLike, /) -> Array:
|
||||
@partial(jit, inline=True)
|
||||
def real(val: ArrayLike, /) -> Array:
|
||||
check_arraylike("real", val)
|
||||
return lax.real(val) if np.iscomplexobj(val) else _asarray(val)
|
||||
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)
|
||||
|
||||
@_wraps(np.modf, module='numpy', skip_params=['out'])
|
||||
@jit
|
||||
|
@ -223,19 +223,10 @@ def _wraps(
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def _asarray(arr: ArrayLike) -> Array:
|
||||
"""
|
||||
Pared-down utility to convert object to a DeviceArray.
|
||||
Note this will not correctly handle lists or tuples.
|
||||
"""
|
||||
check_arraylike("_asarray", arr)
|
||||
dtype, weak_type = dtypes._lattice_result_type(arr)
|
||||
return lax._convert_element_type(arr, dtype, weak_type)
|
||||
|
||||
def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
if len(args) < 2:
|
||||
return [_asarray(arg) for arg in args]
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
else:
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if config.jax_dynamic_shapes:
|
||||
@ -246,10 +237,10 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
|
||||
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
|
||||
else:
|
||||
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
|
||||
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
|
||||
return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
|
||||
nonscalar_ranks = {len(shp) for shp in shapes if shp}
|
||||
if len(nonscalar_ranks) < 2:
|
||||
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
|
||||
return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
|
||||
else:
|
||||
if config.jax_numpy_rank_promotion != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
@ -277,7 +268,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]:
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return [_asarray(arg) for arg in args]
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
@ -392,7 +383,7 @@ def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
return [_asarray(arg) for arg in args]
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, result_shape) for arg in args]
|
||||
|
||||
@ -401,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape) # type: ignore[union-attr]
|
||||
check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, Array) else _asarray(arr)
|
||||
arr = arr if isinstance(arr, Array) else lax.asarray(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
shape = core.canonicalize_shape(shape) # check that shape is concrete
|
||||
|
Loading…
x
Reference in New Issue
Block a user