mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #10935 from jakevdp:x64-linspace
PiperOrigin-RevId: 452388829
This commit is contained in:
commit
e9542bb61d
@ -2163,6 +2163,7 @@ def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
|
||||
# but can lead to rounding errors for integer outputs.
|
||||
real_dtype = finfo(computation_dtype).dtype
|
||||
step = reshape(lax.iota(real_dtype, div), iota_shape) / div
|
||||
step = step.astype(computation_dtype)
|
||||
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
|
||||
reshape(broadcast_stop, bounds_shape) * step)
|
||||
|
||||
@ -2231,6 +2232,7 @@ def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
|
||||
stop = asarray(stop, dtype=computation_dtype)
|
||||
# follow the numpy geomspace convention for negative and complex endpoints
|
||||
signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
|
||||
signflip = signflip.astype(computation_dtype)
|
||||
res = signflip * logspace(log10(signflip * start),
|
||||
log10(signflip * stop), num,
|
||||
endpoint=endpoint, base=10.0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user