mgrid/ogrid: unify implementation & fill-out docstring

This commit is contained in:
Jake VanderPlas 2021-04-06 12:58:24 -07:00
parent b935d33ccb
commit ec7b10c4b6

View File

@ -3253,72 +3253,89 @@ def meshgrid(*args, **kwargs):
return output
class _Mgrid:
"""Return dense multi-dimensional "meshgrid".
LAX-backend implementation of `mgrid()`."""
class _IndexGrid:
def __getitem__(self, key):
if isinstance(key, slice):
start = core.concrete_or_error(None, key.start,
"slice start of jnp.mgrid") or 0
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.mgrid")
step = core.concrete_or_error(None, key.step,
"slice step of jnp.mgrid") or 1
if np.iscomplex(step):
return linspace(start, stop, int(_abs(step)))
return arange(start, stop, step)
xi = []
# Key is tuple of slices.
single_slice = isinstance(key, slice)
if single_slice:
key = (key,)
output = []
for k in key:
start = core.concrete_or_error(None, k.start,
"slice start of jnp.mgrid") or 0
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.mgrid")
step = core.concrete_or_error(None, k.step,
"slice step of jnp.mgrid") or 1
if np.iscomplex(step):
xi.append(linspace(start, stop, int(_abs(step))))
else:
xi.append(arange(start, stop, step))
m_grid = meshgrid(*xi, indexing='ij')
return stack(m_grid, 0)
mgrid = _Mgrid()
class _Ogrid:
"""Return open multi-dimensional "meshgrid".
LAX-backend implementation of `ogrid()`."""
def __getitem__(self, key):
if isinstance(key, slice):
start = core.concrete_or_error(None, key.start,
"slice start of jnp.ogrid") or 0
stop = core.concrete_or_error(None, key.stop, "slice stop of jnp.ogrid")
step = core.concrete_or_error(None, key.step,
"slice step of jnp.ogrid") or 1
if np.iscomplex(step):
return linspace(start, stop, int(_abs(step)))
return arange(start, stop, step)
output = []
for k in key:
start = core.concrete_or_error(None, k.start,
"slice start of jnp.ogrid") or 0
stop = core.concrete_or_error(None, k.stop, "slice stop of jnp.ogrid")
step = core.concrete_or_error(None, k.step,
"slice step of jnp.ogrid") or 1
if np.iscomplex(step):
output.append(linspace(start, stop, int(_abs(step))))
else:
output.append(arange(start, stop, step))
if single_slice:
return output[0]
output = meshgrid(*output, indexing='ij', sparse=self.sparse)
return output if self.sparse else stack(output, 0)
return meshgrid(*output, indexing='ij', sparse=True)
class _Mgrid(_IndexGrid):
"""Return dense multi-dimensional "meshgrid".
LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``.
See Also:
jnp.ogrid: open/sparse version of jnp.mgrid
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.mgrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.mgrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create broadcasted grids of indices:
>>> jnp.mgrid[:2, :3]
DeviceArray([[[0, 0, 0],
[1, 1, 1]],
[[0, 1, 2],
[0, 1, 2]]], dtype=int32)
"""
sparse = False
mgrid = _Mgrid()
class _Ogrid(_IndexGrid):
"""Return open multi-dimensional "meshgrid".
LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``.
See Also:
jnp.mgrid: dense version of jnp.ogrid
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.ogrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.ogrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create sparse grids of indices:
>>> jnp.ogrid[:2, :3]
[DeviceArray([[0],
[1]], dtype=int32),
DeviceArray([[0, 1, 2]], dtype=int32)]
"""
sparse = True
ogrid = _Ogrid()