Merge pull request #22960 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 661319343
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 101 KiB |
Before Width: | Height: | Size: 63 KiB |
Before Width: | Height: | Size: 64 KiB |
Before Width: | Height: | Size: 63 KiB |
Before Width: | Height: | Size: 46 KiB |
Before Width: | Height: | Size: 64 KiB |
Before Width: | Height: | Size: 18 KiB |
Before Width: | Height: | Size: 94 KiB |
@ -132,7 +132,6 @@ exclude_patterns = [
|
||||
'notebooks/*.md',
|
||||
'pallas/quickstart.md',
|
||||
'pallas/tpu/pipelining.md',
|
||||
'pallas/tpu/distributed.md',
|
||||
'pallas/tpu/matmul.md',
|
||||
'jep/9407-type-promotion.md',
|
||||
'autodidax.md',
|
||||
@ -222,7 +221,6 @@ nb_execution_excludepatterns = [
|
||||
# Requires accelerators
|
||||
'pallas/quickstart.*',
|
||||
'pallas/tpu/pipelining.*',
|
||||
'pallas/tpu/distributed.*',
|
||||
'pallas/tpu/matmul.*',
|
||||
'sharded-computation.*',
|
||||
'distributed_data_loading.*'
|
||||
|
@ -9,4 +9,3 @@ TPU specific documentation.
|
||||
details
|
||||
pipelining
|
||||
matmul
|
||||
distributed
|
||||
|
@ -986,7 +986,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
||||
"""Return sample frequencies for the discrete Fourier transform.
|
||||
|
||||
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
|
||||
for use with the outputs of :func:`~jax.numpy.fft` and :func:`~jax.numpy.ifft`.
|
||||
for use with the outputs of :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
|
||||
|
||||
Args:
|
||||
n: length of the FFT window
|
||||
@ -1000,8 +1000,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
||||
Array of sample frequencies, length ``n``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.rfft`
|
||||
and :func:`~jax.numpy.irfft`.
|
||||
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with
|
||||
:func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`.
|
||||
"""
|
||||
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
|
||||
if isinstance(n, (list, tuple)):
|
||||
@ -1037,7 +1037,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
||||
"""Return sample frequencies for the discrete Fourier transform.
|
||||
|
||||
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
|
||||
for use with the outputs of :func:`~jax.numpy.rfft` and :func:`~jax.numpy.irfft`.
|
||||
for use with the outputs of :func:`~jax.numpy.fft.rfft` and
|
||||
:func:`~jax.numpy.fft.irfft`.
|
||||
|
||||
Args:
|
||||
n: length of the FFT window
|
||||
@ -1051,8 +1052,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
|
||||
Array of sample frequencies, length ``n // 2 + 1``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft`
|
||||
and :func:`~jax.numpy.ifft`.
|
||||
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with
|
||||
:func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
|
||||
"""
|
||||
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
|
||||
if isinstance(n, (list, tuple)):
|
||||
|