Implement jnp.ogrid

Related to #5850
This commit is contained in:
minoring 2021-04-05 10:35:45 +09:00
parent 3c1ee0644b
commit 4c67dd1f48
4 changed files with 72 additions and 2 deletions

View File

@ -276,6 +276,7 @@ Not every function in NumPy is implemented; contributions are welcome!
not_equal
number
object_
ogrid
ones
ones_like
outer

View File

@ -3289,6 +3289,40 @@ class _Mgrid:
mgrid = _Mgrid()
class _Ogrid:
"""Return open multi-dimensional "meshgrid".
LAX-backend implementation of `ogrid()`."""
def __getitem__(self, key):
if isinstance(key, slice):
start = core.concrete_or_error(None, key.start,
"slice start of jnp.ogrid") or 0
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.ogrid")
step = core.concrete_or_error(None, key.step,
"slice step of jnp.ogrid") or 1
if np.iscomplex(step):
return linspace(start, stop, int(_abs(step)))
return arange(start, stop, step)
output = []
for k in key:
start = core.concrete_or_error(None, k.start,
"slice start of jnp.ogrid") or 0
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.ogrid")
step = core.concrete_or_error(None, k.step,
"slice step of jnp.ogrid") or 1
if np.iscomplex(step):
output.append(linspace(start, stop, int(_abs(step))))
else:
output.append(arange(start, stop, step))
return meshgrid(*output, indexing='ij', sparse=True)
ogrid = _Ogrid()
@_wraps(np.i0)
def i0(x):
x_orig = x

View File

@ -38,7 +38,7 @@ from jax._src.numpy.lax_numpy import (
fmod, frexp, full, full_like, gcd, geomspace, gradient, greater,
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges, histogram2d, histogramdd,
hsplit, hstack, hypot, i0, identity, iinfo, imag,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
interp, intersect1d, invert,
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
@ -50,7 +50,7 @@ from jax._src.numpy.lax_numpy import (
nanmedian, nanpercentile, nanquantile,
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
negative, newaxis, nextafter, nonzero, not_equal, number, numpy_version,
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
object_, ogrid, ones, ones_like, operator_name, outer, packbits, pad, percentile,
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
prod, product, promote_types, ptp, quantile,
rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,

View File

@ -4670,6 +4670,41 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
atol=atol,
rtol=rtol)
def testOgrid(self):
def assertListOfArraysEqual(xs, ys):
self.assertIsInstance(xs, list)
self.assertIsInstance(ys, list)
self.assertEqual(len(xs), len(ys))
for x, y in zip(xs, ys):
self.assertArraysEqual(x, y)
self.assertArraysEqual(np.ogrid[:5], jnp.ogrid[:5])
self.assertArraysEqual(np.ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
self.assertArraysEqual(np.ogrid[1:7:2], jnp.ogrid[1:7:2])
# List of arrays
assertListOfArraysEqual(np.ogrid[:5,], jnp.ogrid[:5,])
assertListOfArraysEqual(np.ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
assertListOfArraysEqual(np.ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
assertListOfArraysEqual(np.ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
# Corner cases
self.assertArraysEqual(np.ogrid[:], jnp.ogrid[:])
# Complex number steps
atol = 1e-6
rtol = 1e-6
self.assertAllClose(np.ogrid[-1:1:5j],
jnp.ogrid[-1:1:5j],
atol=atol,
rtol=rtol)
# Non-integer steps
self.assertAllClose(np.ogrid[0:3.5:0.3],
jnp.ogrid[0:3.5:0.3],
atol=atol,
rtol=rtol)
self.assertAllClose(np.ogrid[1.2:4.8:0.24],
jnp.ogrid[1.2:4.8:0.24],
atol=atol,
rtol=rtol)
@parameterized.named_parameters(
jtu.cases_from_list(
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"