mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
parent
3c1ee0644b
commit
4c67dd1f48
@ -276,6 +276,7 @@ Not every function in NumPy is implemented; contributions are welcome!
|
||||
not_equal
|
||||
number
|
||||
object_
|
||||
ogrid
|
||||
ones
|
||||
ones_like
|
||||
outer
|
||||
|
@ -3289,6 +3289,40 @@ class _Mgrid:
|
||||
mgrid = _Mgrid()
|
||||
|
||||
|
||||
class _Ogrid:
|
||||
"""Return open multi-dimensional "meshgrid".
|
||||
|
||||
LAX-backend implementation of `ogrid()`."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
start = core.concrete_or_error(None, key.start,
|
||||
"slice start of jnp.ogrid") or 0
|
||||
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.ogrid")
|
||||
step = core.concrete_or_error(None, key.step,
|
||||
"slice step of jnp.ogrid") or 1
|
||||
if np.iscomplex(step):
|
||||
return linspace(start, stop, int(_abs(step)))
|
||||
return arange(start, stop, step)
|
||||
|
||||
output = []
|
||||
for k in key:
|
||||
start = core.concrete_or_error(None, k.start,
|
||||
"slice start of jnp.ogrid") or 0
|
||||
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.ogrid")
|
||||
step = core.concrete_or_error(None, k.step,
|
||||
"slice step of jnp.ogrid") or 1
|
||||
if np.iscomplex(step):
|
||||
output.append(linspace(start, stop, int(_abs(step))))
|
||||
else:
|
||||
output.append(arange(start, stop, step))
|
||||
|
||||
return meshgrid(*output, indexing='ij', sparse=True)
|
||||
|
||||
|
||||
ogrid = _Ogrid()
|
||||
|
||||
|
||||
@_wraps(np.i0)
|
||||
def i0(x):
|
||||
x_orig = x
|
||||
|
@ -38,7 +38,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
fmod, frexp, full, full_like, gcd, geomspace, gradient, greater,
|
||||
greater_equal, hamming, hanning, heaviside, histogram, histogram_bin_edges, histogram2d, histogramdd,
|
||||
hsplit, hstack, hypot, i0, identity, iinfo, imag,
|
||||
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
|
||||
indices, inexact, in1d, inf, inner, int16, int32, int64, int8, int_, integer,
|
||||
interp, intersect1d, invert,
|
||||
isclose, iscomplex, iscomplexobj, isfinite, isin, isinf, isnan, isneginf,
|
||||
isposinf, isreal, isrealobj, isscalar, issubdtype, issubsctype, iterable,
|
||||
@ -50,7 +50,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
nanmedian, nanpercentile, nanquantile,
|
||||
nanmax, nanmean, nanmin, nanprod, nanstd, nansum, nanvar, ndarray, ndim,
|
||||
negative, newaxis, nextafter, nonzero, not_equal, number, numpy_version,
|
||||
object_, ones, ones_like, operator_name, outer, packbits, pad, percentile,
|
||||
object_, ogrid, ones, ones_like, operator_name, outer, packbits, pad, percentile,
|
||||
pi, piecewise, polyadd, polyder, polyint, polymul, polysub, polyval, positive, power,
|
||||
prod, product, promote_types, ptp, quantile,
|
||||
rad2deg, radians, ravel, ravel_multi_index, real, reciprocal, remainder, repeat, reshape,
|
||||
|
@ -4670,6 +4670,41 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
|
||||
def testOgrid(self):
|
||||
def assertListOfArraysEqual(xs, ys):
|
||||
self.assertIsInstance(xs, list)
|
||||
self.assertIsInstance(ys, list)
|
||||
self.assertEqual(len(xs), len(ys))
|
||||
for x, y in zip(xs, ys):
|
||||
self.assertArraysEqual(x, y)
|
||||
|
||||
self.assertArraysEqual(np.ogrid[:5], jnp.ogrid[:5])
|
||||
self.assertArraysEqual(np.ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])())
|
||||
self.assertArraysEqual(np.ogrid[1:7:2], jnp.ogrid[1:7:2])
|
||||
# List of arrays
|
||||
assertListOfArraysEqual(np.ogrid[:5,], jnp.ogrid[:5,])
|
||||
assertListOfArraysEqual(np.ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3])
|
||||
assertListOfArraysEqual(np.ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3])
|
||||
assertListOfArraysEqual(np.ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11])
|
||||
# Corner cases
|
||||
self.assertArraysEqual(np.ogrid[:], jnp.ogrid[:])
|
||||
# Complex number steps
|
||||
atol = 1e-6
|
||||
rtol = 1e-6
|
||||
self.assertAllClose(np.ogrid[-1:1:5j],
|
||||
jnp.ogrid[-1:1:5j],
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
# Non-integer steps
|
||||
self.assertAllClose(np.ogrid[0:3.5:0.3],
|
||||
jnp.ogrid[0:3.5:0.3],
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
self.assertAllClose(np.ogrid[1.2:4.8:0.24],
|
||||
jnp.ogrid[1.2:4.8:0.24],
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user