mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
mgrid/ogrid: unify implementation & fill-out docstring
This commit is contained in:
parent
b935d33ccb
commit
ec7b10c4b6
@ -3253,72 +3253,89 @@ def meshgrid(*args, **kwargs):
|
||||
return output
|
||||
|
||||
|
||||
class _Mgrid:
|
||||
"""Return dense multi-dimensional "meshgrid".
|
||||
|
||||
LAX-backend implementation of `mgrid()`."""
|
||||
|
||||
class _IndexGrid:
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
start = core.concrete_or_error(None, key.start,
|
||||
"slice start of jnp.mgrid") or 0
|
||||
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.mgrid")
|
||||
step = core.concrete_or_error(None, key.step,
|
||||
"slice step of jnp.mgrid") or 1
|
||||
if np.iscomplex(step):
|
||||
return linspace(start, stop, int(_abs(step)))
|
||||
return arange(start, stop, step)
|
||||
|
||||
xi = []
|
||||
# Key is tuple of slices.
|
||||
single_slice = isinstance(key, slice)
|
||||
if single_slice:
|
||||
key = (key,)
|
||||
output = []
|
||||
for k in key:
|
||||
start = core.concrete_or_error(None, k.start,
|
||||
"slice start of jnp.mgrid") or 0
|
||||
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.mgrid")
|
||||
step = core.concrete_or_error(None, k.step,
|
||||
"slice step of jnp.mgrid") or 1
|
||||
|
||||
if np.iscomplex(step):
|
||||
xi.append(linspace(start, stop, int(_abs(step))))
|
||||
else:
|
||||
xi.append(arange(start, stop, step))
|
||||
|
||||
m_grid = meshgrid(*xi, indexing='ij')
|
||||
return stack(m_grid, 0)
|
||||
|
||||
|
||||
mgrid = _Mgrid()
|
||||
|
||||
|
||||
class _Ogrid:
|
||||
"""Return open multi-dimensional "meshgrid".
|
||||
|
||||
LAX-backend implementation of `ogrid()`."""
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
start = core.concrete_or_error(None, key.start,
|
||||
"slice start of jnp.ogrid") or 0
|
||||
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.ogrid")
|
||||
step = core.concrete_or_error(None, key.step,
|
||||
"slice step of jnp.ogrid") or 1
|
||||
if np.iscomplex(step):
|
||||
return linspace(start, stop, int(_abs(step)))
|
||||
return arange(start, stop, step)
|
||||
|
||||
output = []
|
||||
for k in key:
|
||||
start = core.concrete_or_error(None, k.start,
|
||||
"slice start of jnp.ogrid") or 0
|
||||
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.ogrid")
|
||||
step = core.concrete_or_error(None, k.step,
|
||||
"slice step of jnp.ogrid") or 1
|
||||
if np.iscomplex(step):
|
||||
output.append(linspace(start, stop, int(_abs(step))))
|
||||
else:
|
||||
output.append(arange(start, stop, step))
|
||||
if single_slice:
|
||||
return output[0]
|
||||
output = meshgrid(*output, indexing='ij', sparse=self.sparse)
|
||||
return output if self.sparse else stack(output, 0)
|
||||
|
||||
return meshgrid(*output, indexing='ij', sparse=True)
|
||||
|
||||
class _Mgrid(_IndexGrid):
|
||||
"""Return dense multi-dimensional "meshgrid".
|
||||
|
||||
LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for
|
||||
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``.
|
||||
|
||||
See Also:
|
||||
jnp.ogrid: open/sparse version of jnp.mgrid
|
||||
|
||||
Examples:
|
||||
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
|
||||
|
||||
>>> jnp.mgrid[0:4:1]
|
||||
DeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
|
||||
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
|
||||
|
||||
>>> jnp.mgrid[0:1:4j]
|
||||
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
|
||||
Multiple slices can be used to create broadcasted grids of indices:
|
||||
|
||||
>>> jnp.mgrid[:2, :3]
|
||||
DeviceArray([[[0, 0, 0],
|
||||
[1, 1, 1]],
|
||||
[[0, 1, 2],
|
||||
[0, 1, 2]]], dtype=int32)
|
||||
"""
|
||||
sparse = False
|
||||
|
||||
mgrid = _Mgrid()
|
||||
|
||||
|
||||
class _Ogrid(_IndexGrid):
|
||||
"""Return open multi-dimensional "meshgrid".
|
||||
|
||||
LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for
|
||||
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``.
|
||||
|
||||
See Also:
|
||||
jnp.mgrid: dense version of jnp.ogrid
|
||||
|
||||
Examples:
|
||||
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
|
||||
|
||||
>>> jnp.ogrid[0:4:1]
|
||||
DeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
|
||||
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
|
||||
|
||||
>>> jnp.ogrid[0:1:4j]
|
||||
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
|
||||
Multiple slices can be used to create sparse grids of indices:
|
||||
|
||||
>>> jnp.ogrid[:2, :3]
|
||||
[DeviceArray([[0],
|
||||
[1]], dtype=int32),
|
||||
DeviceArray([[0, 1, 2]], dtype=int32)]
|
||||
"""
|
||||
sparse = True
|
||||
|
||||
|
||||
ogrid = _Ogrid()
|
||||
|
Loading…
x
Reference in New Issue
Block a user