mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Implement jnp.from* array creation functions
This commit is contained in:
parent
a68b0f3a0a
commit
093b7032a8
@ -15,6 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
|
||||
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
|
||||
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
|
||||
* added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
|
||||
and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
|
||||
* Deprecations:
|
||||
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
|
||||
|
||||
|
@ -175,6 +175,9 @@ namespace; they are listed below.
|
||||
fmin
|
||||
fmod
|
||||
frexp
|
||||
frombuffer
|
||||
fromfunction
|
||||
fromstring
|
||||
full
|
||||
full_like
|
||||
gcd
|
||||
|
@ -1966,6 +1966,27 @@ empty_like = zeros_like
|
||||
empty = zeros
|
||||
|
||||
|
||||
# General np.from* style functions mostly delegate to numpy.
|
||||
|
||||
@_wraps(np.frombuffer)
|
||||
def frombuffer(buffer, dtype=float, count=-1, offset=0):
|
||||
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
|
||||
|
||||
|
||||
@_wraps(np.fromfunction)
|
||||
def fromfunction(function, shape, *, dtype=float, **kwargs):
|
||||
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
|
||||
for i in range(len(shape)):
|
||||
in_axes = [0 if i == j else None for j in range(len(shape))]
|
||||
function = jax.vmap(function, in_axes=tuple(in_axes[::-1]))
|
||||
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
|
||||
|
||||
|
||||
@_wraps(np.fromstring)
|
||||
def fromstring(string, dtype=float, count=-1, *, sep):
|
||||
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
|
||||
|
||||
|
||||
@_wraps(np.eye)
|
||||
def eye(N, M=None, k=0, dtype=None):
|
||||
lax_internal._check_user_dtype_supported(dtype, "eye")
|
||||
|
@ -122,6 +122,9 @@ from jax._src.numpy.lax_numpy import (
|
||||
floating as floating,
|
||||
fmax as fmax,
|
||||
fmin as fmin,
|
||||
frombuffer as frombuffer,
|
||||
fromfunction as fromfunction,
|
||||
fromstring as fromstring,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
gcd as gcd,
|
||||
|
@ -5922,6 +5922,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
|
||||
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
|
||||
|
||||
def testFromBuffer(self):
|
||||
buf = b'\x01\x02\x03'
|
||||
expected = np.frombuffer(buf, dtype='uint8')
|
||||
actual = jnp.frombuffer(buf, dtype='uint8')
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
def testFromFunction(self):
|
||||
def f(x, y, z):
|
||||
return x + 2 * y + 3 * z
|
||||
shape = (3, 4, 5)
|
||||
expected = np.fromfunction(f, shape=shape)
|
||||
actual = jnp.fromfunction(f, shape=shape)
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
def testFromString(self):
|
||||
s = "1,2,3"
|
||||
expected = np.fromstring(s, sep=',', dtype=int)
|
||||
actual = jnp.fromstring(s, sep=',', dtype=int)
|
||||
self.assertArraysEqual(expected, actual)
|
||||
|
||||
|
||||
# Most grad tests are at the lax level (see lax_test.py), but we add some here
|
||||
# as needed for e.g. particular compound ops of interest.
|
||||
|
||||
@ -6093,6 +6114,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'identity': ['like'],
|
||||
'full': ['order', 'like'],
|
||||
'full_like': ['subok', 'order'],
|
||||
'fromfunction': ['like'],
|
||||
'histogram': ['normed'],
|
||||
'histogram2d': ['normed'],
|
||||
'histogramdd': ['normed'],
|
||||
|
Loading…
x
Reference in New Issue
Block a user