Merge pull request #10935 from jakevdp:x64-linspace

PiperOrigin-RevId: 452388829
This commit is contained in:
jax authors 2022-06-01 14:44:35 -07:00
commit e9542bb61d

View File

@ -2163,6 +2163,7 @@ def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
# but can lead to rounding errors for integer outputs.
real_dtype = finfo(computation_dtype).dtype
step = reshape(lax.iota(real_dtype, div), iota_shape) / div
step = step.astype(computation_dtype)
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)
@ -2231,6 +2232,7 @@ def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
stop = asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
signflip = signflip.astype(computation_dtype)
res = signflip * logspace(log10(signflip * start),
log10(signflip * stop), num,
endpoint=endpoint, base=10.0,