Expose existing functions in array API namespace

This commit is contained in:
Meekail Zain 2024-04-15 16:25:30 +00:00
parent 2c85ca6fec
commit 8b93da1830
5 changed files with 55 additions and 1 deletions

View File

@ -114,6 +114,7 @@ from jax.experimental.array_api._elementwise_functions import (
ceil as ceil,
clip as clip,
conj as conj,
copysign as copysign,
cos as cos,
cosh as cosh,
divide as divide,
@ -139,6 +140,8 @@ from jax.experimental.array_api._elementwise_functions import (
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
maximum as maximum,
minimum as minimum,
multiply as multiply,
negative as negative,
not_equal as not_equal,
@ -148,6 +151,7 @@ from jax.experimental.array_api._elementwise_functions import (
remainder as remainder,
round as round,
sign as sign,
signbit as signbit,
sin as sin,
sinh as sinh,
sqrt as sqrt,
@ -168,7 +172,9 @@ from jax.experimental.array_api._manipulation_functions import (
concat as concat,
expand_dims as expand_dims,
flip as flip,
moveaxis as moveaxis,
permute_dims as permute_dims,
repeat as repeat,
reshape as reshape,
roll as roll,
squeeze as squeeze,
@ -179,6 +185,7 @@ from jax.experimental.array_api._searching_functions import (
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
searchsorted as searchsorted,
where as where,
)

View File

@ -17,7 +17,6 @@ from jax.experimental.array_api._data_type_functions import (
result_type as _result_type,
isdtype as _isdtype,
)
import numpy as np
def _promote_dtypes(name, *args):
@ -148,6 +147,11 @@ def conj(x, /):
return jax.numpy.conj(x)
def copysign(x1, x2, /):
"""Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1."""
return jax.numpy.copysign(x1, x2)
def cos(x, /):
"""Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x."""
x, = _promote_dtypes("cos", x)
@ -300,6 +304,18 @@ def logical_xor(x1, x2, /):
return jax.numpy.logical_xor(x1, x2)
def maximum(x1, x2, /):
"""Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("maximum", x1, x2)
return jax.numpy.maximum(x1, x2)
def minimum(x1, x2, /):
"""Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("minimum", x1, x2)
return jax.numpy.minimum(x1, x2)
def multiply(x1, x2, /):
"""Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("multiply", x1, x2)
@ -356,6 +372,11 @@ def sign(x, /):
return jax.numpy.sign(x)
def signbit(x, /):
"""Determines whether the sign bit is set for each element x_i of the input array x."""
return jax.numpy.signbit(x)
def sin(x, /):
"""Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x."""
x, = _promote_dtypes("sin", x)

View File

@ -47,11 +47,21 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
return jax.numpy.flip(x, axis=axis)
def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array:
"""Moves array axes (dimensions) to new positions, while leaving other axes in their original positions."""
return jax.numpy.moveaxis(x, source, destination)
def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.numpy.permute_dims(x, axes=axes)
def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
"""Repeats each element of an array a specified number of times on a per-element basis."""
return jax.numpy.repeat(x, repeats=repeats, axis=axis)
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused

View File

@ -33,6 +33,15 @@ def nonzero(x, /):
return jax.numpy.nonzero(x)
def searchsorted(x1, x2, /, *, side='left', sorter=None):
"""
Finds the indices into x1 such that, if the corresponding elements in x2
were inserted before the indices, the order of x1, when sorted in ascending
order, would be preserved.
"""
return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter)
def where(condition, x1, x2, /):
"""Returns elements chosen from x1 or x2 depending on condition."""
dtype = _result_type(x1, x2)

View File

@ -65,6 +65,7 @@ MAIN_NAMESPACE = {
'complex64',
'concat',
'conj',
'copysign',
'cos',
'cosh',
'divide',
@ -115,9 +116,12 @@ MAIN_NAMESPACE = {
'matmul',
'matrix_transpose',
'max',
'maximum',
'mean',
'meshgrid',
'min',
'minimum',
'moveaxis',
'multiply',
'nan',
'negative',
@ -133,11 +137,14 @@ MAIN_NAMESPACE = {
'prod',
'real',
'remainder',
'repeat',
'reshape',
'result_type',
'roll',
'round',
'searchsorted',
'sign',
'signbit',
'sin',
'sinh',
'sort',