From 61f50bd3b66192d69487da249346c4116c88b808 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Aug 2023 15:19:46 -0700 Subject: [PATCH] jnp.ufunc: minor cleanups & test fixes --- jax/_src/numpy/ufunc_api.py | 6 +++--- tests/lax_numpy_ufuncs_test.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 698aa63ca..4b74c7675 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -101,7 +101,7 @@ class ufunc: if initial is None: initial = self.identity if dtype is None: - dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype if isinstance(axis, tuple): axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis) @@ -161,7 +161,7 @@ class ufunc: arr = lax_internal.asarray(arr) if dtype is None: - dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype if axis is None or isinstance(axis, tuple): raise ValueError("accumulate does not allow multiple axes") @@ -187,7 +187,7 @@ class ufunc: def _at_via_scan(self, a, indices, *args): check_arraylike(f"{self.__name__}.at", a, *args) - dtype = jax.eval_shape(self, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype + dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype a = lax_internal.asarray(a).astype(dtype) args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) indices = _eliminate_deprecated_list_indexing(indices) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 73e5cd010..5f8e684d4 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -131,7 +131,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in broadcast_compatible_shapes + for shape in nonscalar_shapes for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) @@ -150,7 +150,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in broadcast_compatible_shapes + for shape in nonscalar_shapes for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) @@ -191,7 +191,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} - for shape in broadcast_compatible_shapes + for shape in nonscalar_shapes for axis in [*range(-len(shape), len(shape))]], idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating,