Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid

PiperOrigin-RevId: 501827615
This commit is contained in:
jax authors 2023-01-13 06:06:52 -08:00
commit 44a044f936
2 changed files with 6 additions and 1 deletions

View File

@ -54,7 +54,11 @@ class _IndexGrid(abc.ABC):
with jax.numpy_dtype_promotion('standard'):
output = _promote_dtypes(*output)
output_arr = jnp.meshgrid(*output, indexing='ij', sparse=self.sparse)
return output_arr if self.sparse else jnp.stack(output_arr, 0)
if self.sparse:
return output_arr
if len(output_arr) == 0:
return jnp.arange(0)
return jnp.stack(output_arr, 0)
class _Mgrid(_IndexGrid):

View File

@ -4353,6 +4353,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# wrap indexer for appropriate dtype defaults.
np_mgrid = _indexer_with_default_outputs(np.mgrid)
assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0)
assertAllEqual(np_mgrid[()], jnp.mgrid[()])
assertAllEqual(np_mgrid[:4], jnp.mgrid[:4])
assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,])
assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])())