Implement jnp.from* array creation functions

This commit is contained in:
Jake VanderPlas 2022-03-29 10:52:47 -07:00
parent a68b0f3a0a
commit 093b7032a8
5 changed files with 51 additions and 0 deletions

View File

@ -15,6 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
* added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
* Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).

View File

@ -175,6 +175,9 @@ namespace; they are listed below.
fmin
fmod
frexp
frombuffer
fromfunction
fromstring
full
full_like
gcd

View File

@ -1966,6 +1966,27 @@ empty_like = zeros_like
empty = zeros
# General np.from* style functions mostly delegate to numpy.
@_wraps(np.frombuffer)
def frombuffer(buffer, dtype=float, count=-1, offset=0):
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
@_wraps(np.fromfunction)
def fromfunction(function, shape, *, dtype=float, **kwargs):
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
for i in range(len(shape)):
in_axes = [0 if i == j else None for j in range(len(shape))]
function = jax.vmap(function, in_axes=tuple(in_axes[::-1]))
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
@_wraps(np.fromstring)
def fromstring(string, dtype=float, count=-1, *, sep):
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
@_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None):
lax_internal._check_user_dtype_supported(dtype, "eye")

View File

@ -122,6 +122,9 @@ from jax._src.numpy.lax_numpy import (
floating as floating,
fmax as fmax,
fmin as fmin,
frombuffer as frombuffer,
fromfunction as fromfunction,
fromstring as fromstring,
full as full,
full_like as full_like,
gcd as gcd,

View File

@ -5922,6 +5922,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
def testFromBuffer(self):
buf = b'\x01\x02\x03'
expected = np.frombuffer(buf, dtype='uint8')
actual = jnp.frombuffer(buf, dtype='uint8')
self.assertArraysEqual(expected, actual)
def testFromFunction(self):
def f(x, y, z):
return x + 2 * y + 3 * z
shape = (3, 4, 5)
expected = np.fromfunction(f, shape=shape)
actual = jnp.fromfunction(f, shape=shape)
self.assertArraysEqual(expected, actual)
def testFromString(self):
s = "1,2,3"
expected = np.fromstring(s, sep=',', dtype=int)
actual = jnp.fromstring(s, sep=',', dtype=int)
self.assertArraysEqual(expected, actual)
# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.
@ -6093,6 +6114,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'identity': ['like'],
'full': ['order', 'like'],
'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'],
'histogram2d': ['normed'],
'histogramdd': ['normed'],