mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Expose existing functions in array API namespace
This commit is contained in:
parent
2c85ca6fec
commit
8b93da1830
@ -114,6 +114,7 @@ from jax.experimental.array_api._elementwise_functions import (
|
||||
ceil as ceil,
|
||||
clip as clip,
|
||||
conj as conj,
|
||||
copysign as copysign,
|
||||
cos as cos,
|
||||
cosh as cosh,
|
||||
divide as divide,
|
||||
@ -139,6 +140,8 @@ from jax.experimental.array_api._elementwise_functions import (
|
||||
logical_not as logical_not,
|
||||
logical_or as logical_or,
|
||||
logical_xor as logical_xor,
|
||||
maximum as maximum,
|
||||
minimum as minimum,
|
||||
multiply as multiply,
|
||||
negative as negative,
|
||||
not_equal as not_equal,
|
||||
@ -148,6 +151,7 @@ from jax.experimental.array_api._elementwise_functions import (
|
||||
remainder as remainder,
|
||||
round as round,
|
||||
sign as sign,
|
||||
signbit as signbit,
|
||||
sin as sin,
|
||||
sinh as sinh,
|
||||
sqrt as sqrt,
|
||||
@ -168,7 +172,9 @@ from jax.experimental.array_api._manipulation_functions import (
|
||||
concat as concat,
|
||||
expand_dims as expand_dims,
|
||||
flip as flip,
|
||||
moveaxis as moveaxis,
|
||||
permute_dims as permute_dims,
|
||||
repeat as repeat,
|
||||
reshape as reshape,
|
||||
roll as roll,
|
||||
squeeze as squeeze,
|
||||
@ -179,6 +185,7 @@ from jax.experimental.array_api._searching_functions import (
|
||||
argmax as argmax,
|
||||
argmin as argmin,
|
||||
nonzero as nonzero,
|
||||
searchsorted as searchsorted,
|
||||
where as where,
|
||||
)
|
||||
|
||||
|
@ -17,7 +17,6 @@ from jax.experimental.array_api._data_type_functions import (
|
||||
result_type as _result_type,
|
||||
isdtype as _isdtype,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _promote_dtypes(name, *args):
|
||||
@ -148,6 +147,11 @@ def conj(x, /):
|
||||
return jax.numpy.conj(x)
|
||||
|
||||
|
||||
def copysign(x1, x2, /):
|
||||
"""Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1."""
|
||||
return jax.numpy.copysign(x1, x2)
|
||||
|
||||
|
||||
def cos(x, /):
|
||||
"""Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x."""
|
||||
x, = _promote_dtypes("cos", x)
|
||||
@ -300,6 +304,18 @@ def logical_xor(x1, x2, /):
|
||||
return jax.numpy.logical_xor(x1, x2)
|
||||
|
||||
|
||||
def maximum(x1, x2, /):
|
||||
"""Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
|
||||
x1, x2 = _promote_dtypes("maximum", x1, x2)
|
||||
return jax.numpy.maximum(x1, x2)
|
||||
|
||||
|
||||
def minimum(x1, x2, /):
|
||||
"""Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
|
||||
x1, x2 = _promote_dtypes("minimum", x1, x2)
|
||||
return jax.numpy.minimum(x1, x2)
|
||||
|
||||
|
||||
def multiply(x1, x2, /):
|
||||
"""Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
|
||||
x1, x2 = _promote_dtypes("multiply", x1, x2)
|
||||
@ -356,6 +372,11 @@ def sign(x, /):
|
||||
return jax.numpy.sign(x)
|
||||
|
||||
|
||||
def signbit(x, /):
|
||||
"""Determines whether the sign bit is set for each element x_i of the input array x."""
|
||||
return jax.numpy.signbit(x)
|
||||
|
||||
|
||||
def sin(x, /):
|
||||
"""Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x."""
|
||||
x, = _promote_dtypes("sin", x)
|
||||
|
@ -47,11 +47,21 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
|
||||
return jax.numpy.flip(x, axis=axis)
|
||||
|
||||
|
||||
def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array:
|
||||
"""Moves array axes (dimensions) to new positions, while leaving other axes in their original positions."""
|
||||
return jax.numpy.moveaxis(x, source, destination)
|
||||
|
||||
|
||||
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
|
||||
"""Permutes the axes (dimensions) of an array x."""
|
||||
return jax.numpy.permute_dims(x, axes=axes)
|
||||
|
||||
|
||||
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
|
||||
"""Repeats each element of an array a specified number of times on a per-element basis."""
|
||||
return jax.numpy.repeat(x, repeats=repeats, axis=axis)
|
||||
|
||||
|
||||
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
|
||||
"""Reshapes an array without changing its data."""
|
||||
del copy # unused
|
||||
|
@ -33,6 +33,15 @@ def nonzero(x, /):
|
||||
return jax.numpy.nonzero(x)
|
||||
|
||||
|
||||
def searchsorted(x1, x2, /, *, side='left', sorter=None):
|
||||
"""
|
||||
Finds the indices into x1 such that, if the corresponding elements in x2
|
||||
were inserted before the indices, the order of x1, when sorted in ascending
|
||||
order, would be preserved.
|
||||
"""
|
||||
return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter)
|
||||
|
||||
|
||||
def where(condition, x1, x2, /):
|
||||
"""Returns elements chosen from x1 or x2 depending on condition."""
|
||||
dtype = _result_type(x1, x2)
|
||||
|
@ -65,6 +65,7 @@ MAIN_NAMESPACE = {
|
||||
'complex64',
|
||||
'concat',
|
||||
'conj',
|
||||
'copysign',
|
||||
'cos',
|
||||
'cosh',
|
||||
'divide',
|
||||
@ -115,9 +116,12 @@ MAIN_NAMESPACE = {
|
||||
'matmul',
|
||||
'matrix_transpose',
|
||||
'max',
|
||||
'maximum',
|
||||
'mean',
|
||||
'meshgrid',
|
||||
'min',
|
||||
'minimum',
|
||||
'moveaxis',
|
||||
'multiply',
|
||||
'nan',
|
||||
'negative',
|
||||
@ -133,11 +137,14 @@ MAIN_NAMESPACE = {
|
||||
'prod',
|
||||
'real',
|
||||
'remainder',
|
||||
'repeat',
|
||||
'reshape',
|
||||
'result_type',
|
||||
'roll',
|
||||
'round',
|
||||
'searchsorted',
|
||||
'sign',
|
||||
'signbit',
|
||||
'sin',
|
||||
'sinh',
|
||||
'sort',
|
||||
|
Loading…
x
Reference in New Issue
Block a user