mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
lax_numpy.py: factor out some common utilities
Re-lands part of #9724 PiperOrigin-RevId: 433838553
This commit is contained in:
parent
5abd664938
commit
ddf23dead3
@ -29,7 +29,7 @@ import collections
|
||||
from functools import partial
|
||||
import operator
|
||||
import types
|
||||
from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union, Set, Type, Callable
|
||||
from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union
|
||||
from textwrap import dedent as _dedent
|
||||
import warnings
|
||||
|
||||
@ -42,7 +42,6 @@ from jax import core
|
||||
from jax import errors
|
||||
from jax import lax
|
||||
from jax.core import ShapedArray, DShapedArray, ConcreteArray, canonicalize_shape
|
||||
from jax.config import config
|
||||
from jax.interpreters import pxla
|
||||
from jax.tree_util import tree_leaves, tree_flatten, tree_map
|
||||
|
||||
@ -52,7 +51,10 @@ from jax._src.api_util import _ensure_index_tuple
|
||||
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.util import ( # noqa: F401
|
||||
_arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _complex_elem_type, _promote_args,
|
||||
_promote_args_inexact, _promote_dtypes, _promote_dtypes_inexact, _promote_shapes, _register_stackable,
|
||||
_stackable, _where, _wraps)
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
@ -206,74 +208,6 @@ _INT_DTYPES = {
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
||||
def _promote_shapes(fun_name, *args):
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
shapes = [shape(arg) for arg in args]
|
||||
if _all(len(shapes[0]) == len(s) for s in shapes[1:]):
|
||||
return args # no need for rank promotion, so rely on lax promotion
|
||||
nonscalar_ranks = {len(shp) for shp in shapes if shp}
|
||||
if len(nonscalar_ranks) < 2:
|
||||
return args
|
||||
else:
|
||||
if config.jax_numpy_rank_promotion != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
if config.jax_dynamic_shapes:
|
||||
# With dynamic shapes we don't support singleton-dimension broadcasting;
|
||||
# we instead broadcast out to the full shape as a temporary workaround.
|
||||
res_shape = lax.broadcast_shapes(*shapes)
|
||||
return [broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
|
||||
else:
|
||||
result_rank = len(lax.broadcast_shapes(*shapes))
|
||||
return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
|
||||
for arg, shp in zip(args, shapes)]
|
||||
|
||||
|
||||
def _rank_promotion_warning_or_error(fun_name, shapes):
|
||||
if config.jax_numpy_rank_promotion == "warn":
|
||||
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
|
||||
"Set the jax_numpy_rank_promotion config option to 'allow' to "
|
||||
"disable this warning; for more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
elif config.jax_numpy_rank_promotion == "raise":
|
||||
msg = ("Operands could not be broadcast together for {} on shapes {} "
|
||||
"and with the config option jax_numpy_rank_promotion='raise'. "
|
||||
"For more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
|
||||
def _promote_dtypes(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype, weak_type)
|
||||
for x in args]
|
||||
|
||||
def _promote_dtypes_inexact(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = _to_inexact_dtype(to_dtype)
|
||||
weak_type = (weak_type and to_dtype == to_dtype_inexact)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_)
|
||||
|
||||
def _complex_elem_type(dtype):
|
||||
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
||||
return np.abs(np.zeros((), dtype)).dtype
|
||||
|
||||
def _result_dtype(op, *args):
|
||||
"""Compute result dtype of applying op to arguments with given dtypes."""
|
||||
@ -281,51 +215,6 @@ def _result_dtype(op, *args):
|
||||
return _dtype(op(*args))
|
||||
|
||||
|
||||
def _arraylike(x):
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
hasattr(x, '__jax_array__') or isscalar(x))
|
||||
|
||||
|
||||
def _stackable(*args):
|
||||
return _all(type(arg) in stackables for arg in args)
|
||||
stackables: Set[Type] = set()
|
||||
_register_stackable: Callable[[Type], None] = stackables.add
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
if _any(not _arraylike(arg) for arg in args):
|
||||
pos, arg = next((i, arg) for i, arg in enumerate(args)
|
||||
if not _arraylike(arg))
|
||||
msg = "{} requires ndarray or scalar arguments, got {} at position {}."
|
||||
raise TypeError(msg.format(fun_name, type(arg), pos))
|
||||
|
||||
def _check_no_float0s(fun_name, *args):
|
||||
"""Check if none of the args have dtype float0."""
|
||||
if _any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
|
||||
raise TypeError(
|
||||
f"Called {fun_name} with a float0 array. "
|
||||
"float0s do not support any operations by design because they "
|
||||
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
||||
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
||||
"to cast a float0 array to a regular zeros array. \n"
|
||||
"If you didn't expect to get a float0 you might have accidentally "
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
|
||||
def _promote_args(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
Promotes non-inexact types to an inexact type."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
|
||||
|
||||
def _convert_and_clip_integer(val, dtype):
|
||||
"""
|
||||
Convert integer-typed val to specified integer dtype, clipping to dtype
|
||||
@ -1868,24 +1757,6 @@ def isin(element, test_elements, assume_unique=False, invert=False):
|
||||
return result.reshape(shape(element))
|
||||
|
||||
|
||||
# The `jit` on `where` exists to avoid materializing constants in cases like
|
||||
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
|
||||
# materialize the broadcast forms of scalar arguments.
|
||||
@jit
|
||||
def _where(condition, x=None, y=None):
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments should "
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
.format(x, y))
|
||||
if not issubdtype(_dtype(condition), bool_):
|
||||
condition = lax.ne(condition, zeros_like(condition))
|
||||
x, y = _promote_dtypes(x, y)
|
||||
condition, x, y = broadcast_arrays(condition, x, y)
|
||||
try: is_always_empty = core.is_empty_shape(np.shape(x))
|
||||
except: is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition, x, y) if not is_always_empty else x
|
||||
|
||||
|
||||
@_wraps(np.where,
|
||||
lax_description=_dedent("""
|
||||
At present, JAX does not support JIT-compilation of the single-argument form
|
||||
@ -1966,47 +1837,15 @@ def broadcast_shapes(*shapes):
|
||||
shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes]
|
||||
return lax.broadcast_shapes(*shapes)
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def broadcast_arrays(*args):
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
# Avoid calling _check_arraylike() here to allow passing through objects
|
||||
# like PRNGKeyArray which are specially handled in broadcast_to() below.
|
||||
shapes = [shape(arg) for arg in args]
|
||||
if not shapes or _all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
# TODO(mattjj): remove the array(arg) here
|
||||
return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
|
||||
for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [broadcast_to(arg, result_shape) for arg in args]
|
||||
|
||||
|
||||
@_wraps(np.broadcast_to, lax_description="""\
|
||||
broadcast_arrays = _wraps(np.broadcast_arrays, lax_description="""\
|
||||
The JAX version does not necessarily return a view of the input.
|
||||
""")
|
||||
def broadcast_to(arr, shape):
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
_check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, ndarray) else array(arr)
|
||||
if not isinstance(shape, tuple) and ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
shape = canonicalize_shape(shape) # check that shape is concrete
|
||||
arr_shape = _shape(arr)
|
||||
if core.symbolic_equal_shape(arr_shape, shape):
|
||||
return arr
|
||||
else:
|
||||
nlead = len(shape) - len(arr_shape)
|
||||
shape_tail = shape[nlead:]
|
||||
compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
|
||||
if nlead < 0 or not compatible:
|
||||
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
|
||||
raise ValueError(msg.format(arr_shape, shape))
|
||||
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
|
||||
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
|
||||
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
|
||||
return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)
|
||||
""")(_broadcast_arrays)
|
||||
|
||||
|
||||
broadcast_to = _wraps(np.broadcast_to, lax_description="""\
|
||||
The JAX version does not necessarily return a view of the input.
|
||||
""")(_broadcast_to)
|
||||
|
||||
|
||||
def _split(op, ary, indices_or_sections, axis=0):
|
||||
|
@ -12,11 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Callable, NamedTuple, Optional, Dict, Sequence
|
||||
from typing import Callable, NamedTuple, Optional, Dict, Sequence, Set, Type
|
||||
import warnings
|
||||
|
||||
from jax._src.config import config
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.ndarray import ndarray
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src.lax import lax
|
||||
|
||||
import numpy as np
|
||||
|
||||
_parameter_break = re.compile("\n(?=[A-Za-z_])")
|
||||
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
|
||||
@ -178,3 +189,192 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
|
||||
setattr(op, attr, value)
|
||||
return op
|
||||
return wrap
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
def _asarray(arr):
|
||||
"""
|
||||
Pared-down utility to convert object to a DeviceArray.
|
||||
Note this will not correctly handle lists or tuples.
|
||||
"""
|
||||
_check_arraylike("_asarray", arr)
|
||||
dtype, weak_type = dtypes._lattice_result_type(arr)
|
||||
return lax_internal._convert_element_type(arr, dtype, weak_type)
|
||||
|
||||
def _promote_shapes(fun_name, *args):
|
||||
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
|
||||
return args # no need for rank promotion, so rely on lax promotion
|
||||
nonscalar_ranks = {len(shp) for shp in shapes if shp}
|
||||
if len(nonscalar_ranks) < 2:
|
||||
return args
|
||||
else:
|
||||
if config.jax_numpy_rank_promotion != "allow":
|
||||
_rank_promotion_warning_or_error(fun_name, shapes)
|
||||
if config.jax_dynamic_shapes:
|
||||
# With dynamic shapes we don't support singleton-dimension broadcasting;
|
||||
# we instead broadcast out to the full shape as a temporary workaround.
|
||||
res_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
|
||||
else:
|
||||
result_rank = len(lax.broadcast_shapes(*shapes))
|
||||
return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
|
||||
for arg, shp in zip(args, shapes)]
|
||||
|
||||
|
||||
def _rank_promotion_warning_or_error(fun_name, shapes):
|
||||
if config.jax_numpy_rank_promotion == "warn":
|
||||
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
|
||||
"Set the jax_numpy_rank_promotion config option to 'allow' to "
|
||||
"disable this warning; for more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
elif config.jax_numpy_rank_promotion == "raise":
|
||||
msg = ("Operands could not be broadcast together for {} on shapes {} "
|
||||
"and with the config option jax_numpy_rank_promotion='raise'. "
|
||||
"For more information, see "
|
||||
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
|
||||
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
|
||||
|
||||
|
||||
def _promote_dtypes(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
def _promote_dtypes_inexact(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion.
|
||||
|
||||
Promotes arguments to an inexact type."""
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype)
|
||||
to_dtype_inexact = _to_inexact_dtype(to_dtype)
|
||||
weak_type = (weak_type and to_dtype == to_dtype_inexact)
|
||||
return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
|
||||
for x in args]
|
||||
|
||||
|
||||
def _to_inexact_dtype(dtype):
|
||||
"""Promotes a dtype into an inexact dtype, if it is not already one."""
|
||||
return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
|
||||
|
||||
|
||||
def _complex_elem_type(dtype):
|
||||
"""Returns the float type of the real/imaginary parts of a complex dtype."""
|
||||
return np.abs(np.zeros((), dtype)).dtype
|
||||
|
||||
|
||||
def _arraylike(x):
|
||||
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
|
||||
hasattr(x, '__jax_array__') or np.isscalar(x))
|
||||
|
||||
|
||||
def _stackable(*args):
|
||||
return all(type(arg) in stackables for arg in args)
|
||||
stackables: Set[Type] = set()
|
||||
_register_stackable: Callable[[Type], None] = stackables.add
|
||||
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
|
||||
if any(not _arraylike(arg) for arg in args):
|
||||
pos, arg = next((i, arg) for i, arg in enumerate(args)
|
||||
if not _arraylike(arg))
|
||||
msg = "{} requires ndarray or scalar arguments, got {} at position {}."
|
||||
raise TypeError(msg.format(fun_name, type(arg), pos))
|
||||
|
||||
|
||||
def _check_no_float0s(fun_name, *args):
|
||||
"""Check if none of the args have dtype float0."""
|
||||
if any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
|
||||
raise TypeError(
|
||||
f"Called {fun_name} with a float0 array. "
|
||||
"float0s do not support any operations by design because they "
|
||||
"are not compatible with non-trivial vector spaces. No implicit dtype "
|
||||
"conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
|
||||
"to cast a float0 array to a regular zeros array. \n"
|
||||
"If you didn't expect to get a float0 you might have accidentally "
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
|
||||
|
||||
def _promote_args(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes(*args))
|
||||
|
||||
|
||||
def _promote_args_inexact(fun_name, *args):
|
||||
"""Convenience function to apply Numpy argument shape and dtype promotion.
|
||||
|
||||
Promotes non-inexact types to an inexact type."""
|
||||
_check_arraylike(fun_name, *args)
|
||||
_check_no_float0s(fun_name, *args)
|
||||
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
|
||||
|
||||
|
||||
@partial(api.jit, inline=True)
|
||||
def _broadcast_arrays(*args):
|
||||
"""Like Numpy's broadcast_arrays but doesn't return views."""
|
||||
shapes = [np.shape(arg) for arg in args]
|
||||
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
|
||||
# TODO(mattjj): remove the array(arg) here
|
||||
return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
|
||||
for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
return [_broadcast_to(arg, result_shape) for arg in args]
|
||||
|
||||
|
||||
def _broadcast_to(arr, shape):
|
||||
if hasattr(arr, "broadcast_to"):
|
||||
return arr.broadcast_to(shape)
|
||||
_check_arraylike("broadcast_to", arr)
|
||||
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
|
||||
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
|
||||
shape = (shape,)
|
||||
shape = core.canonicalize_shape(shape) # check that shape is concrete
|
||||
arr_shape = np.shape(arr)
|
||||
if core.symbolic_equal_shape(arr_shape, shape):
|
||||
return arr
|
||||
else:
|
||||
nlead = len(shape) - len(arr_shape)
|
||||
shape_tail = shape[nlead:]
|
||||
compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
|
||||
if nlead < 0 or not compatible:
|
||||
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
|
||||
raise ValueError(msg.format(arr_shape, shape))
|
||||
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
|
||||
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
|
||||
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
|
||||
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
|
||||
return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
|
||||
|
||||
|
||||
# The `jit` on `where` exists to avoid materializing constants in cases like
|
||||
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
|
||||
# materialize the broadcast forms of scalar arguments.
|
||||
@api.jit
|
||||
def _where(condition, x=None, y=None):
|
||||
if x is None or y is None:
|
||||
raise ValueError("Either both or neither of the x and y arguments should "
|
||||
"be provided to jax.numpy.where, got {} and {}."
|
||||
.format(x, y))
|
||||
if not np.issubdtype(_dtype(condition), np.bool_):
|
||||
condition = lax.ne(condition, lax_internal._zero(condition))
|
||||
x, y = _promote_dtypes(x, y)
|
||||
condition, x, y = _broadcast_arrays(condition, x, y)
|
||||
try: is_always_empty = core.is_empty_shape(np.shape(x))
|
||||
except: is_always_empty = False # can fail with dynamic shapes
|
||||
return lax.select(condition, x, y) if not is_always_empty else x
|
||||
|
Loading…
x
Reference in New Issue
Block a user