diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 3bd9e4d70..8ee8ae37f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -122,6 +122,15 @@ def _try_broadcast_shapes( return None return tuple(result_shape) +def asarray(x: ArrayLike) -> Array: + """Lightweight conversion of ArrayLike input to Array output.""" + if isinstance(x, Array): + return x + elif isinstance(x, np.ndarray) or np.isscalar(x): + return api.device_put(x) + else: + raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") + @overload def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ... diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 7e245487d..7fa11a9cf 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -41,11 +41,6 @@ _lax_const = lax_internal._const Axis = Union[None, int, Sequence[int]] - -def _asarray(a: ArrayLike) -> Array: - # simplified version of jnp.asarray() for local use. - return a if isinstance(a, Array) else api.device_put(a) - def _isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() @@ -54,7 +49,7 @@ def _isscalar(element: Any) -> bool: def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array: # simplified version of jnp.moveaxis() for local use. check_arraylike("moveaxis", a) - a = _asarray(a) + a = lax_internal.asarray(a) source = _canonicalize_axis(source, np.ndim(a)) destination = _canonicalize_axis(destination, np.ndim(a)) perm = [i for i in range(np.ndim(a)) if i != source] @@ -92,7 +87,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: raise ValueError(f"reduction operation {name} does not have an identity, so to use a " f"where mask one has to specify 'initial'") - a = a if isinstance(a, Array) else _asarray(a) + a = a if isinstance(a, Array) else lax_internal.asarray(a) a = preproc(a) if preproc else a pos_dims, dims = _reduction_dims(a, axis) @@ -135,7 +130,7 @@ def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: else: result = lax.reduce(a, init_val, op, dims) if initial is not None: - initial_arr = lax.convert_element_type(initial, _asarray(a).dtype) + initial_arr = lax.convert_element_type(initial, lax_internal.asarray(a).dtype) if initial_arr.shape != (): raise ValueError("initial value must be a scalar. " f"Got array of shape {initial_arr.shape}") @@ -434,7 +429,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, raise NotImplementedError("The 'out' argument to jnp.var is not supported.") computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = _asarray(a).astype(computation_dtype) + a = lax_internal.asarray(a).astype(computation_dtype) a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where) centered = lax.sub(a, a_mean) if dtypes.issubdtype(centered.dtype, np.complexfloating): @@ -607,7 +602,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = _asarray(a).astype(computation_dtype) + a = lax_internal.asarray(a).astype(computation_dtype) a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where) centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients. @@ -701,7 +696,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, .. if interpolation is not None: warnings.warn("The interpolation= argument to 'quantile' is deprecated. " "Use 'method=' instead.", DeprecationWarning) - return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, False) + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, False) @_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', @@ -717,7 +712,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, if interpolation is not None: warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " "Use 'method=' instead.", DeprecationWarning) - return _quantile(_asarray(a), _asarray(q), axis, interpolation or method, keepdims, True) + return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, interpolation or method, keepdims, True) def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], interpolation: str, keepdims: bool, squash_nans: bool) -> Array: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index eab4abded..cc1322b94 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -29,7 +29,7 @@ from jax._src.api import jit, custom_jvp from jax._src.lax import lax from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( - _asarray, check_arraylike, promote_args, promote_args_inexact, + check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, _wraps) @@ -140,7 +140,7 @@ fabs = _one_to_one_unop(np.fabs, lax.abs, True) bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) invert = _one_to_one_unop(np.invert, lax.bitwise_not) negative = _one_to_one_unop(np.negative, lax.neg) -positive = _one_to_one_unop(np.positive, lambda x: _asarray(x)) +positive = _one_to_one_unop(np.positive, lambda x: lax.asarray(x)) floor = _one_to_one_unop(np.floor, lax.floor, True) ceil = _one_to_one_unop(np.ceil, lax.ceil, True) exp = _one_to_one_unop(np.exp, lax.exp, True) @@ -211,7 +211,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: def absolute(x: ArrayLike, /) -> Array: check_arraylike('absolute', x) dt = dtypes.dtype(x) - return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) + return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) abs = _wraps(np.abs, module='numpy')(absolute) @@ -596,7 +596,7 @@ radians = deg2rad @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: check_arraylike("conjugate", x) - return lax.conj(x) if np.iscomplexobj(x) else _asarray(x) + return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) conj = conjugate @@ -611,7 +611,7 @@ def imag(val: ArrayLike, /) -> Array: @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: check_arraylike("real", val) - return lax.real(val) if np.iscomplexobj(val) else _asarray(val) + return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) @_wraps(np.modf, module='numpy', skip_params=['out']) @jit diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 5127758b8..0cda54957 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -223,19 +223,10 @@ def _wraps( _dtype = partial(dtypes.dtype, canonicalize=True) -def _asarray(arr: ArrayLike) -> Array: - """ - Pared-down utility to convert object to a DeviceArray. - Note this will not correctly handle lists or tuples. - """ - check_arraylike("_asarray", arr) - dtype, weak_type = dtypes._lattice_result_type(arr) - return lax._convert_element_type(arr, dtype, weak_type) - def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]: """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.""" if len(args) < 2: - return [_asarray(arg) for arg in args] + return [lax.asarray(arg) for arg in args] else: shapes = [np.shape(arg) for arg in args] if config.jax_dynamic_shapes: @@ -246,10 +237,10 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]: return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] else: if all(len(shapes[0]) == len(s) for s in shapes[1:]): - return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion + return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion nonscalar_ranks = {len(shp) for shp in shapes if shp} if len(nonscalar_ranks) < 2: - return [_asarray(arg) for arg in args] # rely on lax scalar promotion + return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion else: if config.jax_numpy_rank_promotion != "allow": _rank_promotion_warning_or_error(fun_name, shapes) @@ -277,7 +268,7 @@ def promote_dtypes(*args: ArrayLike) -> List[Array]: """Convenience function to apply Numpy argument dtype promotion.""" # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing. if len(args) < 2: - return [_asarray(arg) for arg in args] + return [lax.asarray(arg) for arg in args] else: to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype) @@ -392,7 +383,7 @@ def _broadcast_arrays(*args: ArrayLike) -> List[Array]: """Like Numpy's broadcast_arrays but doesn't return views.""" shapes = [np.shape(arg) for arg in args] if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes): - return [_asarray(arg) for arg in args] + return [lax.asarray(arg) for arg in args] result_shape = lax.broadcast_shapes(*shapes) return [_broadcast_to(arg, result_shape) for arg in args] @@ -401,7 +392,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array: if hasattr(arr, "broadcast_to"): return arr.broadcast_to(shape) # type: ignore[union-attr] check_arraylike("broadcast_to", arr) - arr = arr if isinstance(arr, Array) else _asarray(arr) + arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) shape = core.canonicalize_shape(shape) # check that shape is concrete