223 Commits

Author SHA1 Message Date
Matthew Johnson
8b219d5f7b [shard-map] add user tutorial 2024-01-18 15:30:13 -08:00
Jake VanderPlas
7d8c358fce Fix wording in TFDS example 2023-11-03 11:31:48 -07:00
Jake VanderPlas
389eb97a7c CI: update pre-commit hooks to latest version 2023-10-30 09:12:24 -07:00
parikshit adhikari
3c283db2fc fixed typo 'primtive' as 'primitive' in How_JAX_primitives_work.ipynb 2023-10-13 21:05:43 +05:45
parikshit adhikari
ba1af0114d fix: typo inside docs/notebooks/How_JAX_primitives_work.md 2023-10-13 08:54:51 +05:45
8bitmp3
7e2e4a08e4 Add jit/shard_map banner to xmap docs
Add jit/shard_map banner to xmap docs
2023-10-12 22:06:15 +00:00
Peter Hawkins
d8a0227e86 Simplify the torch data loader collate function using tree_map.
Fixes https://github.com/google/jax/issues/1004
2023-10-04 14:59:06 -04:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Jake VanderPlas
130a53f2a2 DOC: re-enable execution of thinking_in_jax.ipynb 2023-08-24 09:23:26 -07:00
8bitmp3
03ebdb7454 Add Kaggle TPU VM link, update Distributed Arrays and Auto Parallelization guide 2023-08-21 21:17:02 +00:00
8bitmp3
e0bd8a164d Add Kaggle TPU VM link, update Distributed Arrays and Auto Parallelization guide 2023-08-21 21:09:58 +00:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Peter Hawkins
47651c6a59 Remove uses of XLA translation rules.
Remove translation_rule argument to standard_primitive.

PiperOrigin-RevId: 557220350
2023-08-15 12:53:36 -07:00
Jake VanderPlas
e4701b2451 DOC: callback doc is no longer work-in-progress 2023-08-07 11:16:56 -07:00
Jake VanderPlas
7bb8312f82 CI: update jupytext to v0.14.7 2023-07-24 11:51:45 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
34d9f5a9ae Add a CI presubmit that renders the documentation. 2023-06-20 09:29:25 -04:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Matthew Johnson
690071f1de fix custom_jvp docs typo 2023-04-15 14:51:33 -07:00
jax authors
5456152f0b Fix typo: "__W__ - spatial height" -> "__W__ - spatial width"
PiperOrigin-RevId: 523365127
2023-04-11 04:37:26 -07:00
Jake VanderPlas
4473ebc9fc Add documentation for callback functions 2023-04-07 07:05:46 -07:00
Jake VanderPlas
c7c9cb652e Sharp bits: refer to ndarray.at in out-of-bound indexing discussion 2023-03-17 13:29:05 -07:00
Peter Hawkins
71f120beed Add "Open in Kaggle" buttons to Jupyter notebooks. 2023-03-01 13:15:42 -05:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
f0d816f899 Merge pull request #14673 from nouiz:gpu_doc
PiperOrigin-RevId: 512669380
2023-02-27 10:49:52 -08:00
Frederic Bastien
ec817974aa Add a new link instead of a TODO. 2023-02-24 13:54:16 -08:00
Matthew Johnson
c22da81d5d fixes from reviewers 2023-02-23 15:06:55 -08:00
Matthew Johnson
141996ec11 add remat tutorial docs 2023-02-23 14:37:52 -08:00
Peter Hawkins
0af9fff5ca Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Jake VanderPlas
c4ec2996af Sharp bits: mention alternatives to lax.cond 2023-02-02 13:19:26 -08:00
Jake VanderPlas
fd71794633 DOC: organize content in advanced guide 2023-01-26 08:20:53 -08:00
Jake VanderPlas
20b55a119e CI: update jupytext version 2023-01-23 14:42:03 -08:00
Jake VanderPlas
9e355a6606 Sharp Bits: add section on Dynamic shapes 2023-01-19 11:37:03 -08:00
Jake VanderPlas
b2141229e3 Sharp bits: raise exceptions directly 2023-01-17 14:13:38 -08:00
8bitmp3
2bf6d3dace Update jax.xmap notebook title (Named axes and easy-to-revise parallelism with xmap) 2023-01-07 00:05:46 +00:00
yashkatariya
008a6918e0 Remove experimental endpoints and update to point to 0.4.1 2022-12-13 11:53:56 -08:00
Matthew Johnson
1185c895ca in jax.Array notebook, polish beginning and tweak title and some wording 2022-12-10 22:16:54 -08:00
Jake VanderPlas
df02d7035e DOC: add example of pure_callback with custom_jvp 2022-12-09 12:43:04 -08:00
Jake VanderPlas
7b59ce2f89 DOC: pre-execute the quickstart notebook on GPU 2022-12-06 13:24:02 -08:00
yashkatariya
70d50814b1 Add cross-linking for the migration guide and the parallelism with JAX
tutorial

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-12-01 14:42:59 -08:00
Skye Wanderman-Milne
51db1cfd0e [docs] Rename "JAX in Parallelism" files so the URL matches the title. 2022-12-01 19:53:31 +00:00
Roy Frostig
b6fd3ff9d7 describe partitionable RNG mode in parallelism doc
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-11-29 14:31:06 -08:00
Roy Frostig
fcce6b102c remove cotangent negation in custom VJP example
This was originally intended to show that we can change the VJP by
customizing it, but the algebraic incorrectness is confusing.
2022-11-22 17:55:22 -08:00
yashkatariya
aca7e4ade2 jax.Array tutorial 2022-11-15 16:49:17 -08:00
Adam Paszke
6e43ce363e Remove a TODO from the xmap tutorial
xeinsum is already powerful enough to support the example.
2022-10-26 15:44:06 +00:00
Serge Durand
27360b9988 Fix book links 2022-10-03 13:59:10 +02:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
be65694ac6 docs: avoid deprecated matplotlib axis creation 2022-09-19 12:55:18 -07:00
Jake VanderPlas
eeb9b5f1f6 pre-commit hook: update flake8, mypy, & jupytext 2022-08-15 15:32:45 -07:00