mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
PiperOrigin-RevId: 501827615
This commit is contained in:
commit
44a044f936
@ -54,7 +54,11 @@ class _IndexGrid(abc.ABC):
|
||||
with jax.numpy_dtype_promotion('standard'):
|
||||
output = _promote_dtypes(*output)
|
||||
output_arr = jnp.meshgrid(*output, indexing='ij', sparse=self.sparse)
|
||||
return output_arr if self.sparse else jnp.stack(output_arr, 0)
|
||||
if self.sparse:
|
||||
return output_arr
|
||||
if len(output_arr) == 0:
|
||||
return jnp.arange(0)
|
||||
return jnp.stack(output_arr, 0)
|
||||
|
||||
|
||||
class _Mgrid(_IndexGrid):
|
||||
|
@ -4353,6 +4353,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
# wrap indexer for appropriate dtype defaults.
|
||||
np_mgrid = _indexer_with_default_outputs(np.mgrid)
|
||||
assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0)
|
||||
assertAllEqual(np_mgrid[()], jnp.mgrid[()])
|
||||
assertAllEqual(np_mgrid[:4], jnp.mgrid[:4])
|
||||
assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,])
|
||||
assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])())
|
||||
|
Loading…
x
Reference in New Issue
Block a user