1113 Commits

Author SHA1 Message Date
Yash Katariya
e6c4d4a30e Add docstrings for Sharding classes. Right now I am only documenting Sharding, XLACompatibleSharding, MeshPspecSharding and SingleDeviceSharding.
Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
2022-11-14 15:47:46 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Yash Katariya
6897d37562 Add docstrings for jax.Array APIs make_array_from_callback and make_array_from_single_device_arrays.
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
jax authors
71f92a7cd5 Merge pull request #13182 from canyon289:docs_update
PiperOrigin-RevId: 487808462
2022-11-11 06:44:29 -08:00
Yash Katariya
f0c0689a8a
Remove internal information 2022-11-10 19:09:19 -08:00
Yash Katariya
73935a5bd1
Update jax_array_migration.md 2022-11-10 17:23:16 -08:00
Yash Katariya
aa66b939f9
Fix the docs build 2022-11-10 17:08:57 -08:00
Yash Katariya
b49a1bda15 Add jax.Array migration doc to OSS
PiperOrigin-RevId: 487673643
2022-11-10 16:46:30 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Ravin Kumar
cb3762e5dd Consolidate links in JAX documentation
Move notes down
2022-11-09 15:44:52 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
jax authors
85f43dd902 Merge pull request #13061 from nouiz:test_doc
PiperOrigin-RevId: 486816419
2022-11-07 18:23:41 -08:00
8bitmp3
636cac882c
Increase visibility in index Multi-Host Multi-Process Envs Guide 2022-11-04 19:22:39 +00:00
Jake VanderPlas
acdb545941 CI: add absl-py to docs/requirements.txt 2022-11-04 11:28:15 -07:00
Frederic Bastien
fdc4c1d4ca More documentation about GPU tests. 2022-11-02 06:43:56 -07:00
Jake VanderPlas
8bde3a0a70 Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place 2022-10-28 14:13:36 -07:00
jax authors
a08ced86f3 Merge pull request #12991 from jakevdp:fix-faq
PiperOrigin-RevId: 484042682
2022-10-26 12:39:00 -07:00
Jake VanderPlas
e9194b26b0 FAQ: fix JIT numerics discussion 2022-10-26 11:30:17 -07:00
jax authors
db2c8c1bdb Merge pull request #12994 from hawkinsp:docfix
PiperOrigin-RevId: 484015353
2022-10-26 10:55:14 -07:00
Peter Hawkins
71a384d25e Clarify in JAX Basics that JAX array creation is also an operation that requires accelerator dispatch and converting to a regular Python type is a blocking operation. 2022-10-26 13:38:17 -04:00
Adam Paszke
6e43ce363e Remove a TODO from the xmap tutorial
xeinsum is already powerful enough to support the example.
2022-10-26 15:44:06 +00:00
Ikko Ashimine
28def736d1
Fix typo in 9419-jax-versioning.md
overriden -> overridden
2022-10-25 03:26:48 +09:00
jax authors
8f2f9f4563 Merge pull request #12646 from adrn:truncnorm
PiperOrigin-RevId: 483425197
2022-10-24 10:41:51 -07:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Matthew Johnson
6a3d2a0dde update docs to point to jax.nn.standardize
Fixes #12909
2022-10-20 21:25:37 -07:00
Jake VanderPlas
4aceb81570 Add docs & changelog for jax.scipy.stats.mode 2022-10-20 15:55:57 -07:00
Serge Durand
27360b9988 Fix book links 2022-10-03 13:59:10 +02:00
jbushago
2038988783
Fix typo in faq.rst.
Fixed a small typo in the FAQ: "inthe" -> "in the".
2022-09-30 14:14:05 -04:00
jax authors
d2fcfb6b83 Merge pull request #12407 from hirwa-nshuti:docs-fix
PiperOrigin-RevId: 476467728
2022-09-23 14:51:11 -07:00
Felix Hirwa Nshuti
820efab6fa removed repeated nan_to_num in docs 2022-09-23 06:23:09 +00:00
Sharad Vikram
99d4d8b89a Update debugging docs to have sharding visualization 2022-09-22 19:42:36 -07:00
jax authors
bc08381da3 Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1 Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Jake VanderPlas
fce1099997 Update JEP-12049 implementation discussion 2022-09-20 09:44:29 -07:00
Jake VanderPlas
be65694ac6 docs: avoid deprecated matplotlib axis creation 2022-09-19 12:55:18 -07:00
Jake VanderPlas
5829c6ae9d Change case of typing.Dtype -> typing.DType
This follows the convention used in numpy.typing.DType.
2022-09-14 15:03:55 -07:00
Jake VanderPlas
358363e17f JEP 12049: Type Annotation Roadmap 2022-09-13 09:14:48 -07:00
Jake VanderPlas
0fb462efd7 Add jax.print_environment_info() 2022-09-12 15:39:33 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
Filippo Vicentini
52236adeed
Show link to GitHub repo in navbar 2022-09-05 13:41:49 +02:00
Roy Frostig
bb68fbeefa write in-process AOT walkthrough doc 2022-09-02 13:02:25 -07:00
Roy Frostig
43db06491c write and generate package API documentation for jax.stages 2022-09-01 19:26:53 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
jax authors
cdd16e3bbd Fix typo in jax.debug.print() documentation.
PiperOrigin-RevId: 471096544
2022-08-30 14:51:12 -07:00
jax authors
bff98655a1 Merge pull request #12112 from sharadmv:docs-typo
PiperOrigin-RevId: 470762998
2022-08-29 10:58:40 -07:00
Sharad Vikram
27b313b287 Fix typo in debugging docs 2022-08-29 10:32:08 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
Peter Hawkins
5527966b27 [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00