jnp.ufunc: minor cleanups & test fixes

This commit is contained in:
Jake VanderPlas 2023-08-14 15:19:46 -07:00
parent d6e06f4476
commit 61f50bd3b6
2 changed files with 6 additions and 6 deletions

View File

@ -101,7 +101,7 @@ class ufunc:
if initial is None:
initial = self.identity
if dtype is None:
dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype
dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
if isinstance(axis, tuple):
axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis)
@ -161,7 +161,7 @@ class ufunc:
arr = lax_internal.asarray(arr)
if dtype is None:
dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype
dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
if axis is None or isinstance(axis, tuple):
raise ValueError("accumulate does not allow multiple axes")
@ -187,7 +187,7 @@ class ufunc:
def _at_via_scan(self, a, indices, *args):
check_arraylike(f"{self.__name__}.at", a, *args)
dtype = jax.eval_shape(self, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype)
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
indices = _eliminate_deprecated_list_indexing(indices)

View File

@ -131,7 +131,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in broadcast_compatible_shapes
for shape in nonscalar_shapes
for axis in [None, *range(-len(shape), len(shape))]],
dtype=jtu.dtypes.floating,
)
@ -150,7 +150,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in broadcast_compatible_shapes
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
@ -191,7 +191,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in broadcast_compatible_shapes
for shape in nonscalar_shapes
for axis in [*range(-len(shape), len(shape))]],
idx_shape=[(0,), (3,), (5,)],
dtype=jtu.dtypes.floating,