mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jnp.ufunc: minor cleanups & test fixes
This commit is contained in:
parent
d6e06f4476
commit
61f50bd3b6
@ -101,7 +101,7 @@ class ufunc:
|
||||
if initial is None:
|
||||
initial = self.identity
|
||||
if dtype is None:
|
||||
dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
|
||||
|
||||
if isinstance(axis, tuple):
|
||||
axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis)
|
||||
@ -161,7 +161,7 @@ class ufunc:
|
||||
arr = lax_internal.asarray(arr)
|
||||
|
||||
if dtype is None:
|
||||
dtype = jax.eval_shape(self, lax_internal._one(arr), lax_internal._one(arr)).dtype
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
|
||||
|
||||
if axis is None or isinstance(axis, tuple):
|
||||
raise ValueError("accumulate does not allow multiple axes")
|
||||
@ -187,7 +187,7 @@ class ufunc:
|
||||
|
||||
def _at_via_scan(self, a, indices, *args):
|
||||
check_arraylike(f"{self.__name__}.at", a, *args)
|
||||
dtype = jax.eval_shape(self, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
|
||||
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
|
||||
a = lax_internal.asarray(a).astype(dtype)
|
||||
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
|
||||
indices = _eliminate_deprecated_list_indexing(indices)
|
||||
|
@ -131,7 +131,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in broadcast_compatible_shapes
|
||||
for shape in nonscalar_shapes
|
||||
for axis in [None, *range(-len(shape), len(shape))]],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
@ -150,7 +150,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in broadcast_compatible_shapes
|
||||
for shape in nonscalar_shapes
|
||||
for axis in range(-len(shape), len(shape))],
|
||||
dtype=jtu.dtypes.floating,
|
||||
)
|
||||
@ -191,7 +191,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
SCALAR_FUNCS,
|
||||
[{'shape': shape, 'axis': axis}
|
||||
for shape in broadcast_compatible_shapes
|
||||
for shape in nonscalar_shapes
|
||||
for axis in [*range(-len(shape), len(shape))]],
|
||||
idx_shape=[(0,), (3,), (5,)],
|
||||
dtype=jtu.dtypes.floating,
|
||||
|
Loading…
x
Reference in New Issue
Block a user