1138 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
51db1cfd0e [docs] Rename "JAX in Parallelism" files so the URL matches the title. 2022-12-01 19:53:31 +00:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Eltayeb Ahmed
b5dc0638a2
Fix typo in docs/multi_process.md 2022-11-30 16:38:27 +00:00
jax authors
7f469afe8a Merge pull request #12877 from LenaMartens:check-error-types
PiperOrigin-RevId: 491915619
2022-11-30 07:52:31 -08:00
jax authors
22f67d62dc Merge pull request #13440 from froystig:part-rng-parallel-doc
PiperOrigin-RevId: 491791000
2022-11-29 18:50:04 -08:00
jax authors
dba050a9a1 Merge pull request #13288 from imh:betaln_accuracy
PiperOrigin-RevId: 491778273
2022-11-29 17:26:24 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Roy Frostig
b6fd3ff9d7 describe partitionable RNG mode in parallelism doc
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-11-29 14:31:06 -08:00
Ian Horn
a35fe206a1 Added more accurate version of the betaln function. 2022-11-29 11:56:07 -08:00
jax authors
87408c769a Merge pull request #13421 from sharadmv:fix-rtd
PiperOrigin-RevId: 491404001
2022-11-28 11:53:17 -08:00
Sharad Vikram
c0c8eed6fa Pin IPython version in docs build to avoid RTD warning 2022-11-28 11:22:41 -08:00
TJ
7456b66d35 changed both "currently available" to "total" for mem allocation doc 2022-11-28 09:30:25 -08:00
TJ
4011d17965 Change documentation to state the correct usage of XLA_PYTHON_CLIENT_MEM_FRACTION 2022-11-27 20:05:38 -08:00
lenamartens
e4757e8410 Rewrite Checkify to support tracking different error types.
In general, behavior should remain the same and this is not a breaking
change.

There are some minor changes to the API:
  - checkify.ErrorCategory has changed type: it's no longer an Enum, but
    the JaxException type. These have not been exposed as part of the
    public API.
  - some attributes on Error have changed and made private
  - The raised error has changed type (JaxRuntimeError), and will have a
    different traceback (pointing to the origin of the error + where the
    error value was raised).
  - `checkify.check` now supports formating error message with variable
    size runtime info!

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-11-25 15:31:54 +00:00
jax authors
f33d5514c9 Merge pull request #13367 from froystig:custom-derivatives-docfix
PiperOrigin-RevId: 490383906
2022-11-22 18:30:42 -08:00
Roy Frostig
fcce6b102c remove cotangent negation in custom VJP example
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
Igor Saprykin
be527b62d7 Make it clear that fun is a function rather than a noun.
PiperOrigin-RevId: 490370522
2022-11-22 17:02:21 -08:00
Yash Katariya
1824be772e
Update jax_array_migration.md 2022-11-18 08:11:22 -08:00
Yash Katariya
29d75324a3
Add a date till which jax.Array can be disabled 2022-11-18 08:09:31 -08:00
Peter Hawkins
9f2a6acb61 Revert: Add deprecation warnings to DA, SDA and GDA.
This change is currently overly noisy for users.

PiperOrigin-RevId: 489455729
2022-11-18 06:06:13 -08:00
Yash Katariya
52a2428073 Add deprecation warnings to DA, SDA and GDA.
PiperOrigin-RevId: 489314189
2022-11-17 14:51:29 -08:00
jax authors
bbc3c6aa89 Merge pull request #13266 from skye:tensorboard_docs
PiperOrigin-RevId: 488975887
2022-11-16 10:27:15 -08:00
Skye Wanderman-Milne
66f3e0da9c Update TensorBoard install instructions for profiling 2022-11-16 01:02:45 +00:00
yashkatariya
aca7e4ade2 jax.Array tutorial 2022-11-15 16:49:17 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Yash Katariya
e6c4d4a30e Add docstrings for Sharding classes. Right now I am only documenting Sharding, XLACompatibleSharding, MeshPspecSharding and SingleDeviceSharding.
Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
2022-11-14 15:47:46 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Yash Katariya
6897d37562 Add docstrings for jax.Array APIs make_array_from_callback and make_array_from_single_device_arrays.
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
jax authors
71f92a7cd5 Merge pull request #13182 from canyon289:docs_update
PiperOrigin-RevId: 487808462
2022-11-11 06:44:29 -08:00
Yash Katariya
f0c0689a8a
Remove internal information 2022-11-10 19:09:19 -08:00
Yash Katariya
73935a5bd1
Update jax_array_migration.md 2022-11-10 17:23:16 -08:00
Yash Katariya
aa66b939f9
Fix the docs build 2022-11-10 17:08:57 -08:00
Yash Katariya
b49a1bda15 Add jax.Array migration doc to OSS
PiperOrigin-RevId: 487673643
2022-11-10 16:46:30 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Ravin Kumar
cb3762e5dd Consolidate links in JAX documentation
Move notes down
2022-11-09 15:44:52 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
jax authors
85f43dd902 Merge pull request #13061 from nouiz:test_doc
PiperOrigin-RevId: 486816419
2022-11-07 18:23:41 -08:00
8bitmp3
636cac882c
Increase visibility in index Multi-Host Multi-Process Envs Guide 2022-11-04 19:22:39 +00:00
Jake VanderPlas
acdb545941 CI: add absl-py to docs/requirements.txt 2022-11-04 11:28:15 -07:00
Frederic Bastien
fdc4c1d4ca More documentation about GPU tests. 2022-11-02 06:43:56 -07:00
Jake VanderPlas
8bde3a0a70 Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place 2022-10-28 14:13:36 -07:00
jax authors
a08ced86f3 Merge pull request #12991 from jakevdp:fix-faq
PiperOrigin-RevId: 484042682
2022-10-26 12:39:00 -07:00
Jake VanderPlas
e9194b26b0 FAQ: fix JIT numerics discussion 2022-10-26 11:30:17 -07:00
jax authors
db2c8c1bdb Merge pull request #12994 from hawkinsp:docfix
PiperOrigin-RevId: 484015353
2022-10-26 10:55:14 -07:00
Peter Hawkins
71a384d25e Clarify in JAX Basics that JAX array creation is also an operation that requires accelerator dispatch and converting to a regular Python type is a blocking operation. 2022-10-26 13:38:17 -04:00
Adam Paszke
6e43ce363e Remove a TODO from the xmap tutorial
xeinsum is already powerful enough to support the example.
2022-10-26 15:44:06 +00:00
Ikko Ashimine
28def736d1
Fix typo in 9419-jax-versioning.md
overriden -> overridden
2022-10-25 03:26:48 +09:00
jax authors
8f2f9f4563 Merge pull request #12646 from adrn:truncnorm
PiperOrigin-RevId: 483425197
2022-10-24 10:41:51 -07:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00