Merge pull request #22960 from rajasekharporeddy:testbranch2

PiperOrigin-RevId: 661319343
This commit is contained in:
jax authors 2024-08-09 10:37:44 -07:00
commit c207ad4c04
14 changed files with 7 additions and 3288 deletions

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 58 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 101 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 63 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 64 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 63 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 46 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 64 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 18 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 94 KiB

View File

@ -132,7 +132,6 @@ exclude_patterns = [
'notebooks/*.md',
'pallas/quickstart.md',
'pallas/tpu/pipelining.md',
'pallas/tpu/distributed.md',
'pallas/tpu/matmul.md',
'jep/9407-type-promotion.md',
'autodidax.md',
@ -222,7 +221,6 @@ nb_execution_excludepatterns = [
# Requires accelerators
'pallas/quickstart.*',
'pallas/tpu/pipelining.*',
'pallas/tpu/distributed.*',
'pallas/tpu/matmul.*',
'sharded-computation.*',
'distributed_data_loading.*'

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -9,4 +9,3 @@ TPU specific documentation.
details
pipelining
matmul
distributed

View File

@ -986,7 +986,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
"""Return sample frequencies for the discrete Fourier transform.
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
for use with the outputs of :func:`~jax.numpy.fft` and :func:`~jax.numpy.ifft`.
for use with the outputs of :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
Args:
n: length of the FFT window
@ -1000,8 +1000,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
Array of sample frequencies, length ``n``.
See also:
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.rfft`
and :func:`~jax.numpy.irfft`.
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with
:func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`.
"""
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
@ -1037,7 +1037,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
"""Return sample frequencies for the discrete Fourier transform.
JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
for use with the outputs of :func:`~jax.numpy.rfft` and :func:`~jax.numpy.irfft`.
for use with the outputs of :func:`~jax.numpy.fft.rfft` and
:func:`~jax.numpy.fft.irfft`.
Args:
n: length of the FFT window
@ -1051,8 +1052,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
Array of sample frequencies, length ``n // 2 + 1``.
See also:
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft`
and :func:`~jax.numpy.ifft`.
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with
:func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`.
"""
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):