mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix JAX functions to work if the default gather mode is set to "fill".
These functions really do want "clip".
This commit is contained in:
parent
4f6a6c5dec
commit
f8ba024621
@ -5431,7 +5431,8 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
collapsed_slice_dims=(axis_idx,),
|
||||
start_index_map=(axis_idx,))
|
||||
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
|
||||
slice_sizes=tuple(slice_sizes))
|
||||
slice_sizes=tuple(slice_sizes),
|
||||
mode="clip")
|
||||
|
||||
|
||||
def _normalize_index(index, axis_size):
|
||||
|
@ -1028,7 +1028,7 @@ def _sph_harm(m: jnp.ndarray,
|
||||
cos_colatitude = jnp.cos(phi)
|
||||
|
||||
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
|
||||
legendre_val = legendre[abs(m), n, jnp.arange(len(n))]
|
||||
legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
|
||||
|
||||
angle = abs(m) * theta
|
||||
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
|
||||
|
Loading…
x
Reference in New Issue
Block a user