Fix JAX functions to work if the default gather mode is set to "fill".

These functions really do want "clip".
This commit is contained in:
Peter Hawkins 2021-09-30 14:21:05 -04:00
parent 4f6a6c5dec
commit f8ba024621
2 changed files with 3 additions and 2 deletions

View File

@ -5431,7 +5431,8 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
collapsed_slice_dims=(axis_idx,),
start_index_map=(axis_idx,))
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes))
slice_sizes=tuple(slice_sizes),
mode="clip")
def _normalize_index(index, axis_size):

View File

@ -1028,7 +1028,7 @@ def _sph_harm(m: jnp.ndarray,
cos_colatitude = jnp.cos(phi)
legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
legendre_val = legendre[abs(m), n, jnp.arange(len(n))]
legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
angle = abs(m) * theta
vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))