mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Implement jnp.from* array creation functions
This commit is contained in:
parent
a68b0f3a0a
commit
093b7032a8
@ -15,6 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
|
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
|
||||||
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
|
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
|
||||||
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
|
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
|
||||||
|
* added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
|
||||||
|
and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
|
||||||
* Deprecations:
|
* Deprecations:
|
||||||
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
||||||
|
|
||||||
|
@ -175,6 +175,9 @@ namespace; they are listed below.
|
|||||||
fmin
|
fmin
|
||||||
fmod
|
fmod
|
||||||
frexp
|
frexp
|
||||||
|
frombuffer
|
||||||
|
fromfunction
|
||||||
|
fromstring
|
||||||
full
|
full
|
||||||
full_like
|
full_like
|
||||||
gcd
|
gcd
|
||||||
|
@ -1966,6 +1966,27 @@ empty_like = zeros_like
|
|||||||
empty = zeros
|
empty = zeros
|
||||||
|
|
||||||
|
|
||||||
|
# General np.from* style functions mostly delegate to numpy.
|
||||||
|
|
||||||
|
@_wraps(np.frombuffer)
|
||||||
|
def frombuffer(buffer, dtype=float, count=-1, offset=0):
|
||||||
|
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(np.fromfunction)
|
||||||
|
def fromfunction(function, shape, *, dtype=float, **kwargs):
|
||||||
|
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
|
||||||
|
for i in range(len(shape)):
|
||||||
|
in_axes = [0 if i == j else None for j in range(len(shape))]
|
||||||
|
function = jax.vmap(function, in_axes=tuple(in_axes[::-1]))
|
||||||
|
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@_wraps(np.fromstring)
|
||||||
|
def fromstring(string, dtype=float, count=-1, *, sep):
|
||||||
|
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
|
||||||
|
|
||||||
|
|
||||||
@_wraps(np.eye)
|
@_wraps(np.eye)
|
||||||
def eye(N, M=None, k=0, dtype=None):
|
def eye(N, M=None, k=0, dtype=None):
|
||||||
lax_internal._check_user_dtype_supported(dtype, "eye")
|
lax_internal._check_user_dtype_supported(dtype, "eye")
|
||||||
|
@ -122,6 +122,9 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
floating as floating,
|
floating as floating,
|
||||||
fmax as fmax,
|
fmax as fmax,
|
||||||
fmin as fmin,
|
fmin as fmin,
|
||||||
|
frombuffer as frombuffer,
|
||||||
|
fromfunction as fromfunction,
|
||||||
|
fromstring as fromstring,
|
||||||
full as full,
|
full as full,
|
||||||
full_like as full_like,
|
full_like as full_like,
|
||||||
gcd as gcd,
|
gcd as gcd,
|
||||||
|
@ -5922,6 +5922,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
|
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
|
||||||
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
|
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
|
||||||
|
|
||||||
|
def testFromBuffer(self):
|
||||||
|
buf = b'\x01\x02\x03'
|
||||||
|
expected = np.frombuffer(buf, dtype='uint8')
|
||||||
|
actual = jnp.frombuffer(buf, dtype='uint8')
|
||||||
|
self.assertArraysEqual(expected, actual)
|
||||||
|
|
||||||
|
def testFromFunction(self):
|
||||||
|
def f(x, y, z):
|
||||||
|
return x + 2 * y + 3 * z
|
||||||
|
shape = (3, 4, 5)
|
||||||
|
expected = np.fromfunction(f, shape=shape)
|
||||||
|
actual = jnp.fromfunction(f, shape=shape)
|
||||||
|
self.assertArraysEqual(expected, actual)
|
||||||
|
|
||||||
|
def testFromString(self):
|
||||||
|
s = "1,2,3"
|
||||||
|
expected = np.fromstring(s, sep=',', dtype=int)
|
||||||
|
actual = jnp.fromstring(s, sep=',', dtype=int)
|
||||||
|
self.assertArraysEqual(expected, actual)
|
||||||
|
|
||||||
|
|
||||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||||
# as needed for e.g. particular compound ops of interest.
|
# as needed for e.g. particular compound ops of interest.
|
||||||
|
|
||||||
@ -6093,6 +6114,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
|||||||
'identity': ['like'],
|
'identity': ['like'],
|
||||||
'full': ['order', 'like'],
|
'full': ['order', 'like'],
|
||||||
'full_like': ['subok', 'order'],
|
'full_like': ['subok', 'order'],
|
||||||
|
'fromfunction': ['like'],
|
||||||
'histogram': ['normed'],
|
'histogram': ['normed'],
|
||||||
'histogram2d': ['normed'],
|
'histogram2d': ['normed'],
|
||||||
'histogramdd': ['normed'],
|
'histogramdd': ['normed'],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user