diff --git a/CHANGELOG.md b/CHANGELOG.md index 00dc3f946..32bc78f1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`). * the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`). + * added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`, + and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`). * Deprecations: * {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`). diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index e2a0384a7..72c69a680 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -175,6 +175,9 @@ namespace; they are listed below. fmin fmod frexp + frombuffer + fromfunction + fromstring full full_like gcd diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c043dbcbb..b35e9d425 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1966,6 +1966,27 @@ empty_like = zeros_like empty = zeros +# General np.from* style functions mostly delegate to numpy. + +@_wraps(np.frombuffer) +def frombuffer(buffer, dtype=float, count=-1, offset=0): + return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) + + +@_wraps(np.fromfunction) +def fromfunction(function, shape, *, dtype=float, **kwargs): + shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") + for i in range(len(shape)): + in_axes = [0 if i == j else None for j in range(len(shape))] + function = jax.vmap(function, in_axes=tuple(in_axes[::-1])) + return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) + + +@_wraps(np.fromstring) +def fromstring(string, dtype=float, count=-1, *, sep): + return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) + + @_wraps(np.eye) def eye(N, M=None, k=0, dtype=None): lax_internal._check_user_dtype_supported(dtype, "eye") diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index e460e39cb..cdb388779 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -122,6 +122,9 @@ from jax._src.numpy.lax_numpy import ( floating as floating, fmax as fmax, fmin as fmin, + frombuffer as frombuffer, + fromfunction as fromfunction, + fromstring as fromstring, full as full, full_like as full_like, gcd as gcd, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index cfe2b94e9..b0fd74bad 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5922,6 +5922,27 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) + def testFromBuffer(self): + buf = b'\x01\x02\x03' + expected = np.frombuffer(buf, dtype='uint8') + actual = jnp.frombuffer(buf, dtype='uint8') + self.assertArraysEqual(expected, actual) + + def testFromFunction(self): + def f(x, y, z): + return x + 2 * y + 3 * z + shape = (3, 4, 5) + expected = np.fromfunction(f, shape=shape) + actual = jnp.fromfunction(f, shape=shape) + self.assertArraysEqual(expected, actual) + + def testFromString(self): + s = "1,2,3" + expected = np.fromstring(s, sep=',', dtype=int) + actual = jnp.fromstring(s, sep=',', dtype=int) + self.assertArraysEqual(expected, actual) + + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest. @@ -6093,6 +6114,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): 'identity': ['like'], 'full': ['order', 'like'], 'full_like': ['subok', 'order'], + 'fromfunction': ['like'], 'histogram': ['normed'], 'histogram2d': ['normed'], 'histogramdd': ['normed'],