Implement jnp.from* array creation functions

This commit is contained in:
Jake VanderPlas 2022-03-29 10:52:47 -07:00
parent a68b0f3a0a
commit 093b7032a8
5 changed files with 51 additions and 0 deletions

View File

@ -15,6 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta` * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`). and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`). * the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
* added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
* Deprecations: * Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`). * {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).

View File

@ -175,6 +175,9 @@ namespace; they are listed below.
fmin fmin
fmod fmod
frexp frexp
frombuffer
fromfunction
fromstring
full full
full_like full_like
gcd gcd

View File

@ -1966,6 +1966,27 @@ empty_like = zeros_like
empty = zeros empty = zeros
# General np.from* style functions mostly delegate to numpy.
@_wraps(np.frombuffer)
def frombuffer(buffer, dtype=float, count=-1, offset=0):
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))
@_wraps(np.fromfunction)
def fromfunction(function, shape, *, dtype=float, **kwargs):
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
for i in range(len(shape)):
in_axes = [0 if i == j else None for j in range(len(shape))]
function = jax.vmap(function, in_axes=tuple(in_axes[::-1]))
return function(*(arange(s, dtype=dtype) for s in shape), **kwargs)
@_wraps(np.fromstring)
def fromstring(string, dtype=float, count=-1, *, sep):
return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep))
@_wraps(np.eye) @_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None): def eye(N, M=None, k=0, dtype=None):
lax_internal._check_user_dtype_supported(dtype, "eye") lax_internal._check_user_dtype_supported(dtype, "eye")

View File

@ -122,6 +122,9 @@ from jax._src.numpy.lax_numpy import (
floating as floating, floating as floating,
fmax as fmax, fmax as fmax,
fmin as fmin, fmin as fmin,
frombuffer as frombuffer,
fromfunction as fromfunction,
fromstring as fromstring,
full as full, full as full,
full_like as full_like, full_like as full_like,
gcd as gcd, gcd as gcd,

View File

@ -5922,6 +5922,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64)
self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128)
def testFromBuffer(self):
buf = b'\x01\x02\x03'
expected = np.frombuffer(buf, dtype='uint8')
actual = jnp.frombuffer(buf, dtype='uint8')
self.assertArraysEqual(expected, actual)
def testFromFunction(self):
def f(x, y, z):
return x + 2 * y + 3 * z
shape = (3, 4, 5)
expected = np.fromfunction(f, shape=shape)
actual = jnp.fromfunction(f, shape=shape)
self.assertArraysEqual(expected, actual)
def testFromString(self):
s = "1,2,3"
expected = np.fromstring(s, sep=',', dtype=int)
actual = jnp.fromstring(s, sep=',', dtype=int)
self.assertArraysEqual(expected, actual)
# Most grad tests are at the lax level (see lax_test.py), but we add some here # Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest. # as needed for e.g. particular compound ops of interest.
@ -6093,6 +6114,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'identity': ['like'], 'identity': ['like'],
'full': ['order', 'like'], 'full': ['order', 'like'],
'full_like': ['subok', 'order'], 'full_like': ['subok', 'order'],
'fromfunction': ['like'],
'histogram': ['normed'], 'histogram': ['normed'],
'histogram2d': ['normed'], 'histogram2d': ['normed'],
'histogramdd': ['normed'], 'histogramdd': ['normed'],