mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24905 from jakevdp:old-arg
PiperOrigin-RevId: 696679346
This commit is contained in:
commit
5764afb4b3
@ -82,7 +82,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
|
||||
|
||||
ReductionOp = Callable[[Any, Any], Any]
|
||||
|
||||
def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike,
|
||||
def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
|
||||
*, has_identity: bool = True,
|
||||
preproc: Callable[[ArrayLike], ArrayLike] | None = None,
|
||||
bool_op: ReductionOp | None = None,
|
||||
@ -215,7 +215,7 @@ def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None,
|
||||
promote_integers: bool = True) -> Array:
|
||||
return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
|
||||
return _reduction(a, "sum", lax.add, 0, preproc=_cast_to_numeric,
|
||||
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.psum,
|
||||
@ -301,7 +301,7 @@ def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None,
|
||||
promote_integers: bool = True) -> Array:
|
||||
return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
|
||||
return _reduction(a, "prod", lax.mul, 1, preproc=_cast_to_numeric,
|
||||
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, promote_integers=promote_integers)
|
||||
@ -386,7 +386,7 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
|
||||
return _reduction(a, "max", lax.max, -np.inf, has_identity=False,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmax)
|
||||
|
||||
@ -468,7 +468,7 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, initial: ArrayLike | None = None,
|
||||
where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
|
||||
return _reduction(a, "min", lax.min, np.inf, has_identity=False,
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmin)
|
||||
|
||||
@ -548,7 +548,7 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
|
||||
def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
|
||||
return _reduction(a, "all", lax.bitwise_and, True, preproc=_cast_to_bool,
|
||||
axis=axis, out=out, keepdims=keepdims, where_=where)
|
||||
|
||||
|
||||
@ -604,7 +604,7 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
|
||||
def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None,
|
||||
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
|
||||
return _reduction(a, "any", lax.bitwise_or, False, preproc=_cast_to_bool,
|
||||
axis=axis, out=out, keepdims=keepdims, where_=where)
|
||||
|
||||
|
||||
@ -664,7 +664,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
arr = lax_internal.asarray(a)
|
||||
init_val = np.array(-1, dtype=dtype or arr.dtype)
|
||||
return _reduction(arr, name="reduce_bitwise_and", np_fun=None, op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
|
||||
return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
@ -673,7 +673,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, name="reduce_bitwise_or", np_fun=None, op=lax.bitwise_or, init_val=0, preproc=_require_integer,
|
||||
return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
@ -682,7 +682,7 @@ def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, name="reduce_bitwise_xor", np_fun=None, op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
|
||||
return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
@ -691,7 +691,7 @@ def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, name="reduce_logical_and", np_fun=None, op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
|
||||
return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
@ -700,7 +700,7 @@ def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, name="reduce_logical_or", np_fun=None, op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
|
||||
return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
@ -709,7 +709,7 @@ def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None
|
||||
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
|
||||
out: None = None, keepdims: bool = False,
|
||||
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
|
||||
return _reduction(a, name="reduce_logical_xor", np_fun=None, op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
|
||||
return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
|
||||
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user