
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax

PiperOrigin-RevId: 676843138
Michael Hudgins authored on 2024-09-20 07:51:48 -07:00, committed by jax authors
Parent: afaa3bf43c
Commit: d4d1518c3d
257 changed files with 906 additions and 906 deletions
Changed paths (top level): .github, CHANGELOG.md, CITATION.bib, README.md, cloud_tpu_colabs, docs, jax

@@ -20,11 +20,11 @@ body:
* If you prefer a non-templated issue report, click [here][Raw report].
-[Discussions]: https://github.com/google/jax/discussions
+[Discussions]: https://github.com/jax-ml/jax/discussions
-[issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues
+[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
-[Raw report]: http://github.com/google/jax/issues/new
+[Raw report]: http://github.com/jax-ml/jax/issues/new
- type: textarea
attributes:
label: Description

@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Have questions or need support?
-url: https://github.com/google/jax/discussions
+url: https://github.com/jax-ml/jax/discussions
about: Please ask questions on the Discussions tab

@@ -84,7 +84,7 @@ jobs:
failure()
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
-&& github.repository == 'google/jax'
+&& github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: output-${{ matrix.python-version }}-log.jsonl

@@ -279,7 +279,7 @@ See the 0.4.33 release notes for more details.
which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
-* Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
+* Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
@@ -401,7 +401,7 @@ See the 0.4.33 release notes for more details.
branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
-changed](https://github.com/google/jax/issues/19085) so that
+changed](https://github.com/jax-ml/jax/issues/19085) so that
mapping over keys results in random generation only from the first
key in the batch.
* Docs now use `jax.random.key` for construction of PRNG key arrays
@@ -433,7 +433,7 @@ See the 0.4.33 release notes for more details.
* JAX export does not support older serialization versions anymore. Version 9
has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
-See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
+See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
This change could break clients that set a specific
JAX serialization version lower than 9.
@@ -506,7 +506,7 @@ See the 0.4.33 release notes for more details.
* added the ability to specify symbolic constraints on the dimension variables.
This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* with the addition of symbolic constraints ({jax-issue}`#19235`) we now
consider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
@@ -516,7 +516,7 @@ See the 0.4.33 release notes for more details.
The scope of a symbolic expression `e` can be read with `e.scope` and passed
into the above functions to direct them to construct symbolic expressions in
a given scope.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* simplified and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
@@ -535,7 +535,7 @@ See the 0.4.33 release notes for more details.
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
* JAX default native serialization version is now 9. This is relevant
for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
-See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
+See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
@@ -781,19 +781,19 @@ See the 0.4.33 release notes for more details.
* When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See
-[here](https://github.com/google/jax/pull/16949) for an example.
+[here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
-[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
+[safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user created devices. `jax.devices()` are
already hashable.
* Breaking changes:
* jax2tf now uses native serialization by default. See
-the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
+the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
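The `JAX_TRACEBACK_FILTERING` behavior described in the hunk above can also be driven from Python; a minimal sketch, where the `broken` function and its shape mismatch are invented for illustration and are not part of the commit:

```python
import jax
import jax.numpy as jnp

# Equivalent to JAX_TRACEBACK_FILTERING=off in the environment: show one
# unfiltered traceback that includes JAX-internal frames.
jax.config.update("jax_traceback_filtering", "off")

@jax.jit
def broken(x):
    return jnp.dot(x, jnp.ones((3, 3)))  # mismatched contracting dimension

try:
    broken(jnp.ones((2, 2)))
except TypeError:
    # With "remove_frames" or the default setting, the same error arrives
    # with JAX's internal frames stripped from the traceback.
    pass
```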
@@ -922,7 +922,7 @@ See the 0.4.33 release notes for more details.
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
-[#16413](https://github.com/google/jax/issues/16413).
+[#16413](https://github.com/jax-ml/jax/issues/16413).
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
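For reference, the "common operands" form of `lax.cond` discussed in the hunk above passes the same operand(s) to whichever branch is selected; a minimal sketch, with branch functions invented for illustration:

```python
import jax.numpy as jnp
from jax import lax

def true_fun(x):
    return x + 1.0

def false_fun(x):
    return x - 1.0

x = jnp.float32(3.0)
# Both branches receive the common operand `x`; the predicate picks one.
y = lax.cond(x > 0, true_fun, false_fun, x)
print(y)  # 4.0
```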
@@ -933,7 +933,7 @@ See the 0.4.33 release notes for more details.
serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
## jaxlib 0.4.14 (July 27, 2023)
@@ -1095,14 +1095,14 @@ See the 0.4.33 release notes for more details.
{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
-tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs
+tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
-tracker](https://github.com/google/jax/issues).
+tracker](https://github.com/jax-ml/jax/issues).
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
@@ -1126,7 +1126,7 @@ See the 0.4.33 release notes for more details.
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
-See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
* JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
@@ -1403,7 +1403,7 @@ Changes:
## jaxlib 0.3.22 (Oct 11, 2022)
## jax 0.3.21 (Sep 30, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21).
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
@@ -1417,18 +1417,18 @@ Changes:
* Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`).
## jaxlib 0.3.20 (Sep 28, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
* Bug fixes
* Fixes support for limiting the visible CUDA devices via
`jax_cuda_visible_devices` in distributed jobs. This functionality is needed for
the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
## jax 0.3.19 (Sep 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19).
* Fixes required jaxlib version.
## jax 0.3.18 (Sep 26, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18).
* Changes
* Ahead-of-time lowering and compilation functionality (tracked in
{jax-issue}`#7733`) is stable and public. See [the
@@ -1446,7 +1446,7 @@ Changes:
would have been provided.
## jax 0.3.17 (Aug 31, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17).
* Bugs
* Fix corner case issue in gradient of `lax.pow` with an exponent of zero
({jax-issue}`12041`)
@@ -1462,7 +1462,7 @@ Changes:
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@@ -1486,7 +1486,7 @@ Changes:
deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
## jax 0.3.15 (July 22, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15).
* Changes
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
@@ -1507,10 +1507,10 @@ Changes:
following a similar deprecation in {func}`scipy.linalg.solve`.
## jaxlib 0.3.15 (July 22, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
## jax 0.3.14 (June 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
`max_cache_size_ bytes` anymore and will not get that as an input.
@@ -1563,22 +1563,22 @@ Changes:
coefficients have leading zeros ({jax-issue}`#11215`).
## jaxlib 0.3.14 (June 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14
was released in 2018, so this should not be a very onerous requirement.
* The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
* The Python flatbuffers package is no longer a dependency of jaxlib.
## jax 0.3.13 (May 16, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13).
## jax 0.3.12 (May 15, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12).
* Changes
-* Fixes [#10717](https://github.com/google/jax/issues/10717).
+* Fixes [#10717](https://github.com/jax-ml/jax/issues/10717).
## jax 0.3.11 (May 15, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11).
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
@@ -1592,22 +1592,22 @@ Changes:
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jax 0.3.10 (May 3, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10).
## jaxlib 0.3.10 (May 3, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* Changes
* [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a)
fixes an issue in the MHLO canonicalizer that caused constant folding to
take a long time or crash for certain programs.
## jax 0.3.9 (May 2, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9).
* Changes
* Added support for fully asynchronous checkpointing for GlobalDeviceArray.
## jax 0.3.8 (April 29 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
@@ -1666,7 +1666,7 @@ Changes:
## jax 0.3.7 (April 15, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7).
* Changes:
* Fixed a performance problem if the indices passed to
{func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
@@ -1684,17 +1684,17 @@ Changes:
## jax 0.3.6 (April 12, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6).
* Changes:
* Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU
-pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
+pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API.
## jax 0.3.5 (April 7, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5).
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
@@ -1717,17 +1717,17 @@ Changes:
## jax 0.3.4 (March 18, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4).
## jax 0.3.3 (March 17, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3).
## jax 0.3.2 (March 16, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2).
* Changes:
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
@@ -1751,7 +1751,7 @@ Changes:
## jax 0.3.1 (Feb 18, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1).
* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
@@ -1774,7 +1774,7 @@ Changes:
## jax 0.3.0 (Feb 10, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0).
* Changes
* jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
@@ -1788,7 +1788,7 @@ Changes:
## jax 0.2.28 (Feb 1, 2022)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
@@ -1813,7 +1813,7 @@ Changes:
* The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
## jax 0.2.27 (Jan 18 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27).
* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
@@ -1858,7 +1858,7 @@ Changes:
## jax 0.2.26 (Dec 8, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26).
* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
@@ -1875,7 +1875,7 @@ Changes:
## jax 0.2.25 (Nov 10, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25).
* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
@@ -1889,7 +1889,7 @@ Changes:
## jax 0.2.24 (Oct 19, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24).
* New features:
* `jax.random.choice` and `jax.random.permutation` now support
@@ -1923,7 +1923,7 @@ Changes:
## jax 0.2.22 (Oct 12, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
@@ -1958,13 +1958,13 @@ Changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
-* Fixes https://github.com/google/jax/issues/7461, which caused wrong
+* Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
@@ -1992,7 +1992,7 @@ Changes:
## jax 0.2.20 (Sept 2, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
@@ -2005,7 +2005,7 @@ Changes:
## jax 0.2.19 (Aug 12, 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@@ -2042,7 +2042,7 @@ Changes:
called in sequence.
## jax 0.2.18 (July 21 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
@@ -2065,7 +2065,7 @@ Changes:
* Fix bugs in TFRT CPU backend that results in incorrect results.
## jax 0.2.17 (July 9 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
@@ -2082,12 +2082,12 @@ Changes:
## jax 0.2.16 (June 23 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16).
## jax 0.2.15 (June 23 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
-* [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend
+* [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`).
@@ -2107,7 +2107,7 @@ Changes:
CPU.
## jax 0.2.14 (June 10 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
@@ -2165,7 +2165,7 @@ Changes:
{func}`jit` transformed functions.
## jax 0.2.13 (May 3 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
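The `static_argnames` option mentioned in the hunk above marks keyword arguments as compile-time constants; a minimal sketch against a current JAX install (the `tile` function is invented for illustration):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="repeats")
def tile(x, repeats):
    # `repeats` is a Python int baked into the compiled program, so it can
    # drive shape computations; each new value triggers a recompilation.
    return jnp.concatenate([x] * repeats)

print(tile(jnp.arange(3), repeats=2))  # [0 1 2 0 1 2]
```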
@@ -2209,7 +2209,7 @@ Changes:
## jaxlib 0.1.65 (April 7 2021)
## jax 0.2.12 (April 1 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
@@ -2222,7 +2222,7 @@ Changes:
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
-* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
+* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
@@ -2236,23 +2236,23 @@ Changes:
## jax 0.2.11 (March 23 2021)
* [GitHub
-commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
+commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
-* [#6112](https://github.com/google/jax/pull/6112) added context managers:
+* [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
-* [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
+* [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete`
* Bug fixes:
-* [#6136](https://github.com/google/jax/pull/6136) generalized
+* [#6136](https://github.com/jax-ml/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
-* [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
+* [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`
-* [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with
+* [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with
incomplete beta functions
-* [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during
+* [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during
tracing
-* [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when
+* [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when
converting some large Python integers to floats
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
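The context managers added in #6112 (listed in the hunk above) scope a debugging flag to a block of code; a minimal sketch, with placeholder computations inside the blocks:

```python
import jax
import jax.numpy as jnp

with jax.debug_nans(True):
    # Any NaN produced in this block raises an error instead of propagating.
    y = jnp.log(jnp.float32(2.0))

with jax.log_compiles(True):
    # Compilations triggered in this block are logged.
    jax.jit(lambda x: x * 2)(jnp.arange(4))
```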
@@ -2264,13 +2264,13 @@ Changes:
## jax 0.2.10 (March 5 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`#5627`)
-and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
+and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
@@ -2314,7 +2314,7 @@ Changes:
## jax 0.2.9 (January 26 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
error checking and error messages.
@@ -2330,7 +2330,7 @@ Changes:
## jax 0.2.8 (January 12 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`#5244`)
@@ -2362,7 +2362,7 @@ Changes:
## jax 0.2.7 (Dec 4 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
@@ -2382,14 +2382,14 @@ Changes:
## jax 0.2.6 (Nov 18 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
* Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
-See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
* Raise an error on non-hashable static arguments for jax.jit and
-xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
+xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`#4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
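The type-promotion behavior described in the hunk above is easy to check directly; a minimal sketch:

```python
import jax.numpy as jnp

# A Python complex scalar no longer widens a float32 operand to 64 bits.
z = jnp.float32(1) + 1j
print(z.dtype)  # complex64
```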
@@ -2441,15 +2441,15 @@ Changes:
## jax 0.2.5 (October 27 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
* Expanded the set of JAX primitives converted by jax2tf.
-See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
+See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
## jax 0.2.4 (October 19 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
* Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
* Deprecations
@@ -2461,17 +2461,17 @@ Changes:
## jax 0.2.3 (October 14 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is we need to temporarily roll back a
new jit fastpath while we look into a performance degradation
## jax 0.2.2 (October 13 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2).
## jax 0.2.1 (October 6 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
* As a benefit of omnistaging, the host_callback functions are executed (in program
order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
@@ -2479,10 +2479,10 @@ Changes:
## jax (0.2.0) (September 23 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
* Omnistaging on by default. See {jax-issue}`#3370` and
-[omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
+[omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
## jax (0.1.77) (September 15 2020)
@@ -2496,11 +2496,11 @@ Changes:
## jax 0.1.76 (September 8, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76).
## jax 0.1.75 (July 30, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
* make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
@@ -2508,7 +2508,7 @@ Changes:
## jax 0.1.74 (July 29, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU support for half-precision arithmetic (#3878)
@@ -2525,7 +2525,7 @@ Changes:
## jax 0.1.73 (July 22, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
* jax.image.resize. (#3703)
@@ -2563,14 +2563,14 @@ Changes:
## jax 0.1.72 (June 28, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
* Fix an odeint bug introduced in the previous release, see
{jax-issue}`#3587`.
## jax 0.1.71 (June 25, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
* Allow `jax.experimental.ode.odeint` dynamics functions to close over
@@ -2606,7 +2606,7 @@ Changes:
## jax 0.1.70 (June 8, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
* `lax.switch` introduces indexed conditionals with multiple
branches, together with a generalization of the `cond`
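`lax.switch`, introduced in the release noted in the hunk above, selects one of several branches by integer index; a minimal sketch with branch functions invented for illustration:

```python
import jax.numpy as jnp
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

# The index is clamped to the valid range, then the matching branch runs.
y = lax.switch(1, branches, jnp.float32(3.0))
print(y)  # 6.0
```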
@@ -2615,11 +2615,11 @@ Changes:
## jax 0.1.69 (June 3, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69).
## jax 0.1.68 (May 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
* {func}`lax.cond` supports a single-operand form, taken as the argument
to both branches
@@ -2630,7 +2630,7 @@ Changes:
## jax 0.1.67 (May 12, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
* Support for reduction over subsets of a pmapped axis using `axis_index_groups`
{jax-issue}`#2382`.
@@ -2648,7 +2648,7 @@ Changes:
## jax 0.1.66 (May 5, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
* Support for `in_axes=None` on {func}`pmap`
{jax-issue}`#2896`.
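`in_axes=None` on `pmap`, noted in the hunk above, broadcasts an argument to every device instead of splitting it along a mapped axis; a minimal sketch with shapes invented for illustration:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 3.0).reshape(n, 3)  # leading axis mapped over devices
w = jnp.ones(3)                          # broadcast unchanged to every device

out = jax.pmap(lambda x, w: x @ w, in_axes=(0, None))(xs, w)
print(out.shape)  # (n,)
```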
@@ -2661,7 +2661,7 @@ Changes:
## jax 0.1.65 (April 30, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
* Differentiation of determinants of singular matrices
{jax-issue}`#2809`.
@@ -2679,7 +2679,7 @@ Changes:
## jax 0.1.64 (April 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
* Add syntactic sugar for functional indexed updates
{jax-issue}`#2684`.
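Assuming the "syntactic sugar for functional indexed updates" above refers to the `.at[...]` property that JAX exposes today, a minimal sketch:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
# JAX arrays are immutable, so indexed updates return a new array.
y = x.at[2].set(7.0)
z = y.at[:2].add(1.0)
print(z)  # [1. 1. 7. 0. 0.]
```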
@@ -2706,7 +2706,7 @@ Changes:
## jax 0.1.63 (April 12, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
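A minimal sketch of the `jax.custom_jvp` API added in the release above, following the numerically stable `log1pexp` example used in the linked tutorial:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Hand-written, numerically stable derivative: sigmoid(x) * t.
    return log1pexp(x), t / (1.0 + jnp.exp(-x))

print(jax.grad(log1pexp)(100.0))  # 1.0, instead of nan from the naive rule
```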
@@ -2727,7 +2727,7 @@ Changes:
## jax 0.1.62 (March 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
convention `0. * nan == 0.`. This change means some programs when
@@ -2745,13 +2745,13 @@ Changes:
## jax 0.1.61 (March 17, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
## jax 0.1.60 (March 17, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
* {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows
the user to specify arguments that should be treated as compile-time
@@ -2777,7 +2777,7 @@ Changes:
## jax 0.1.59 (February 11, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes
* The minimum jaxlib version is now 0.1.38.
@@ -2809,7 +2809,7 @@ Changes:
## jax 0.1.58 (January 28, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes
* JAX has dropped Python 2 support, because Python 2 reached its end of life on

@@ -1,7 +1,7 @@
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
-url = {http://github.com/google/jax},
+url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

@@ -1,10 +1,10 @@
<div align="center">
-<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
+<img src="https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png" alt="logo"></img>
</div>
# Transformable numerical computing at scale
-![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
+![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/jax)
[**Quickstart**](#quickstart-colab-in-the-cloud)
@@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come.
This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
-bugs](https://github.com/google/jax/issues), and letting us know what you
+bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
think!
```python
@@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
-- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).
For a deeper dive into JAX: For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) - [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) - [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of - See the [full list of
notebooks](https://github.com/google/jax/tree/main/docs/notebooks). notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).
## Transformations ## Transformations
@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ] # prints [0. 0.16666667 0.33333334 0.5 ]
``` ```
You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns. sophisticated communication patterns.
It all composes, so you're free to differentiate through parallel computations: It all composes, so you're free to differentiate through parallel computations:
@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass. backward pass of the computation is parallelized just like the forward pass.
See the [SPMD See the [SPMD
Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch and the [SPMD MNIST classifier from scratch
example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
for more. for more.
## Current gotchas ## Current gotchas
@ -349,7 +349,7 @@ Some standouts:
1. [In-place mutating updates of 1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are 1. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution 1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package. they're in the `jax.lax` package.
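Picking up the first gotcha above (in-place updates): here is a minimal sketch of the functional alternative using the indexed-update syntax, with a toy array purely for illustration:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# Instead of `x[1] += 5.` (not supported on JAX arrays), build a new array functionally.
y = x.at[1].add(5.)
print(y)  # [0. 5. 0. 0.]
```

As the gotcha entry notes, under a `jit` these functional updates are typically compiled to reuse the buffer in place.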
@ -437,7 +437,7 @@ To cite this repository:
@software{jax2018github, @software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax}, url = {http://github.com/jax-ml/jax},
version = {0.3.13}, version = {0.3.13},
year = {2018}, year = {2018},
} }

@ -451,7 +451,7 @@
"id": "jC-KIMQ1q-lK" "id": "jC-KIMQ1q-lK"
}, },
"source": [ "source": [
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
] ]
}, },
{ {

@ -837,7 +837,7 @@
"id": "f-FBsWeo1AXE" "id": "f-FBsWeo1AXE"
}, },
"source": [ "source": [
"<img src=\"https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>" "<img src=\"https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>"
] ]
}, },
{ {
@ -847,7 +847,7 @@
"id": "jC-KIMQ1q-lK" "id": "jC-KIMQ1q-lK"
}, },
"source": [ "source": [
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
] ]
}, },
{ {

@ -15,13 +15,13 @@
"id": "sk-3cPGIBTq8" "id": "sk-3cPGIBTq8"
}, },
"source": [ "source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
"\n", "\n",
"This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n", "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
"\n", "\n",
"**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n", "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n",
"\n", "\n",
"The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
] ]
}, },
{ {

@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab: The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab:
### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) ### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
A guide to getting started with `pmap`, a transform for easily distributing SPMD A guide to getting started with `pmap`, a transform for easily distributing SPMD
computations across devices. computations across devices.
### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) ### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb)
Contributed by Alex Alemi (alexalemi@) Contributed by Alex Alemi (alexalemi@)
Solve and plot parallel ODE solutions with `pmap`. Solve and plot parallel ODE solutions with `pmap`.
<img src="https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/lorentz.png" width=65%></image> <img src="https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/lorentz.png" width=65%></image>
### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) ### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb)
Contributed by Stephan Hoyer (shoyer@) Contributed by Stephan Hoyer (shoyer@)
Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU. Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU.
![](https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/wave_movie.gif) ![](https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/wave_movie.gif)
### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) ### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb)
An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`. An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.
## Performance notes ## Performance notes
@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud
JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`. JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.
\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you! \* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you!
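As a small illustration (a sketch only; exact defaults can vary by backend and version), you can both cast explicitly to `bfloat16` and request higher matmul precision on a per-call basis:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4).astype(jnp.bfloat16)  # explicit bfloat16 cast
print(x.dtype)  # bfloat16

a = jnp.ones((4, 4))
b = jnp.ones((4, 4))
# Request full-precision accumulation for this particular matmul,
# overriding the accelerator's default.
y = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```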
## Running JAX on a Cloud TPU VM ## Running JAX on a Cloud TPU VM
@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU
VM), please email <cloud-tpu-support@google.com>, or <trc-support@google.com> if VM), please email <cloud-tpu-support@google.com>, or <trc-support@google.com> if
you are a [TRC](https://sites.research.google/trc/) member. You can also [file a you are a [TRC](https://sites.research.google/trc/) member. You can also [file a
JAX issue](https://github.com/google/jax/issues) or [ask a discussion JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion
question](https://github.com/google/jax/discussions) for any issues with these question](https://github.com/jax-ml/jax/discussions) for any issues with these
notebooks or using JAX in general. notebooks or using JAX in general.
If you have any other questions or comments regarding JAX on Cloud TPUs, please If you have any other questions or comments regarding JAX on Cloud TPUs, please

@ -571,7 +571,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products ### Jacobian-Matrix and Matrix-Jacobian products
Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
```{code-cell} ```{code-cell}
# Isolate the function from the weight matrix to the predictions # Isolate the function from the weight matrix to the predictions

@ -27,7 +27,7 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"[![Open in\n", "[![Open in\n",
"Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)" "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)"
] ]
}, },
{ {
@ -1781,7 +1781,7 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"This is precisely the issue that\n", "This is precisely the issue that\n",
"[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n", "[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n",
"We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n", "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n",
"applied, regardless of whether any inputs to `bind` are boxed in corresponding\n", "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n",
"`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n", "`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n",

@ -33,7 +33,7 @@ limitations under the License.
``` ```
[![Open in [![Open in
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
+++ +++
@ -1399,7 +1399,7 @@ print(jaxpr)
``` ```
This is precisely the issue that This is precisely the issue that
[omnistaging](https://github.com/google/jax/pull/3370) fixed. [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
applied, regardless of whether any inputs to `bind` are boxed in corresponding applied, regardless of whether any inputs to `bind` are boxed in corresponding
`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`

@ -27,7 +27,7 @@
# --- # ---
# [![Open in # [![Open in
# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) # Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
# # Autodidax: JAX core from scratch # # Autodidax: JAX core from scratch
# #
@ -1396,7 +1396,7 @@ print(jaxpr)
# This is precisely the issue that # This is precisely the issue that
# [omnistaging](https://github.com/google/jax/pull/3370) fixed. # [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
# We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always # We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
# applied, regardless of whether any inputs to `bind` are boxed in corresponding # applied, regardless of whether any inputs to `bind` are boxed in corresponding
# `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` # `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`

@ -52,4 +52,4 @@ questions answered are:
.. _Flax: https://flax.readthedocs.io/ .. _Flax: https://flax.readthedocs.io/
.. _Haiku: https://dm-haiku.readthedocs.io/ .. _Haiku: https://dm-haiku.readthedocs.io/
.. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax .. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax
.. _JAX GitHub discussions: https://github.com/google/jax/discussions .. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions

@ -168,7 +168,7 @@ html_theme = 'sphinx_book_theme'
# documentation. # documentation.
html_theme_options = { html_theme_options = {
'show_toc_level': 2, 'show_toc_level': 2,
'repository_url': 'https://github.com/google/jax', 'repository_url': 'https://github.com/jax-ml/jax',
'use_repository_button': True, # add a "link to repository" button 'use_repository_button': True, # add a "link to repository" button
'navigation_with_keys': False, 'navigation_with_keys': False,
} }
@ -345,7 +345,7 @@ def linkcode_resolve(domain, info):
return None return None
filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__))
lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else ""
return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}" return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}"
# Generate redirects from deleted files to new sources # Generate redirects from deleted files to new sources
rediraffe_redirects = { rediraffe_redirects = {

@ -5,22 +5,22 @@
Everyone can contribute to JAX, and we value everyone's contributions. There are several Everyone can contribute to JAX, and we value everyone's contributions. There are several
ways to contribute, including: ways to contribute, including:
- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions) - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions)
- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) - Improving or expanding JAX's [documentation](http://jax.readthedocs.io/)
- Contributing to JAX's [code-base](http://github.com/google/jax/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/)
- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries)
The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
## Ways to contribute ## Ways to contribute
We welcome pull requests, in particular for those issues marked with We welcome pull requests, in particular for those issues marked with
[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or [contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or
[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). [good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
For other proposals, we ask that you first open a GitHub For other proposals, we ask that you first open a GitHub
[Issue](https://github.com/google/jax/issues/new/choose) or [Issue](https://github.com/jax-ml/jax/issues/new/choose) or
[Discussion](https://github.com/google/jax/discussions) [Discussion](https://github.com/jax-ml/jax/discussions)
to seek feedback on your planned contribution. to seek feedback on your planned contribution.
## Contributing code using pull requests ## Contributing code using pull requests
@ -33,7 +33,7 @@ Follow these steps to contribute code:
For more information, see the Pull Request Checklist below. For more information, see the Pull Request Checklist below.
2. Fork the JAX repository by clicking the **Fork** button on the 2. Fork the JAX repository by clicking the **Fork** button on the
[repository page](http://www.github.com/google/jax). This creates [repository page](http://www.github.com/jax-ml/jax). This creates
a copy of the JAX repository in your own account. a copy of the JAX repository in your own account.
3. Install Python >= 3.10 locally in order to run tests. 3. Install Python >= 3.10 locally in order to run tests.
@ -52,7 +52,7 @@ Follow these steps to contribute code:
changes. changes.
```bash ```bash
git remote add upstream https://www.github.com/google/jax git remote add upstream https://www.github.com/jax-ml/jax
``` ```
6. Create a branch where you will develop from: 6. Create a branch where you will develop from:

@ -6,7 +6,7 @@
First, obtain the JAX source code: First, obtain the JAX source code:
``` ```
git clone https://github.com/google/jax git clone https://github.com/jax-ml/jax
cd jax cd jax
``` ```
@ -26,7 +26,7 @@ If you're only modifying Python portions of JAX, we recommend installing
pip install jaxlib pip install jaxlib
``` ```
See the [JAX readme](https://github.com/google/jax#installation) for full See the [JAX readme](https://github.com/jax-ml/jax#installation) for full
guidance on pip installation (e.g., for GPU and TPU support). guidance on pip installation (e.g., for GPU and TPU support).
### Building `jaxlib` from source ### Building `jaxlib` from source
@ -621,7 +621,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py
Keep in mind that there are several files that are marked to be skipped when the Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml) [`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
## Type checking ## Type checking
@ -712,7 +712,7 @@ jupytext --sync docs/notebooks/thinking_in_jax.ipynb
``` ```
The jupytext version should match that specified in The jupytext version should match that specified in
[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml). [.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml).
To check that the markdown and ipynb files are properly synced, you may use the To check that the markdown and ipynb files are properly synced, you may use the
[pre-commit](https://pre-commit.com/) framework to perform the same check used [pre-commit](https://pre-commit.com/) framework to perform the same check used
@ -740,12 +740,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked
Some of the notebooks are built automatically as part of the pre-submit checks and Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them, The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else
re-saves the notebook. re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations. We exclude some notebooks from the build, e.g., because they contain long computations.
See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py). See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py).
### Documentation building on `readthedocs.io` ### Documentation building on `readthedocs.io`
@ -772,7 +772,7 @@ I saw in the Readthedocs logs:
mkvirtualenv jax-docs # A new virtualenv mkvirtualenv jax-docs # A new virtualenv
mkdir jax-docs # A new directory mkdir jax-docs # A new directory
cd jax-docs cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax cd jax
git checkout --force origin/test-docs git checkout --force origin/test-docs
git clean -d -f -f git clean -d -f -f

@ -153,7 +153,7 @@ JAX runtime system that are:
an inference system that is already deployed when the exporting is done. an inference system that is already deployed when the exporting is done.
(The particular compatibility window lengths are the same that JAX (The particular compatibility window lengths are the same that JAX
[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), [promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow).
The terminology “backward compatibility” is from the perspective of the consumer, The terminology “backward compatibility” is from the perspective of the consumer,
e.g., the inference system.) e.g., the inference system.)
@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers:
June 13th, 2023 (JAX 0.4.13). June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and * Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`. for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522), since July 12th, 2023 (cl/547482522),
available in JAX serialization since July 20th, 2023 (JAX 0.4.14), available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
and the default since August 12th, 2023 (JAX 0.4.15). and the default since August 12th, 2023 (JAX 0.4.15).
@ -721,7 +721,7 @@ that live in jaxlib):
2. Day “D”, we add the new custom call target `T_NEW`. 2. Day “D”, we add the new custom call target `T_NEW`.
We should create a new custom call target, and clean up the old We should create a new custom call target, and clean up the old
target roughly after 6 months, rather than updating `T` in place: target roughly after 6 months, rather than updating `T` in place:
* See the example [PR #20997](https://github.com/google/jax/pull/20997) * See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997)
implementing the steps below. implementing the steps below.
* We add the custom call target `T_NEW`. * We add the custom call target `T_NEW`.
* We change the JAX lowering rules that were previously using `T`, * We change the JAX lowering rules that were previously using `T`,

@ -2,4 +2,4 @@
## Interoperation with TensorFlow ## Interoperation with TensorFlow
See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).

@ -372,7 +372,7 @@ device.
Jitted functions behave like any other primitive operations—they will follow the Jitted functions behave like any other primitive operations—they will follow the
data and will show errors if invoked on data committed on more than one device. data and will show errors if invoked on data committed on more than one device.
(Before `PR #6002 <https://github.com/google/jax/pull/6002>`_ in March 2021 (Before `PR #6002 <https://github.com/jax-ml/jax/pull/6002>`_ in March 2021
there was some laziness in creation of array constants, so that there was some laziness in creation of array constants, so that
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually ``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually
create the array of zeros on ``jax.devices()[1]``, instead of creating the create the array of zeros on ``jax.devices()[1]``, instead of creating the
@ -385,7 +385,7 @@ and its use is not recommended.)
For a worked-out example, we recommend reading through For a worked-out example, we recommend reading through
``test_computation_follows_data`` in ``test_computation_follows_data`` in
`multi_device_test.py <https://github.com/google/jax/blob/main/tests/multi_device_test.py>`_. `multi_device_test.py <https://github.com/jax-ml/jax/blob/main/tests/multi_device_test.py>`_.
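Before diving into that test, a minimal sketch of committed placement (assuming at
least one device is available; on a multi-device host you could use
``jax.devices()[1]`` instead)::

    import jax
    import jax.numpy as jnp

    dev = jax.devices()[0]                  # pick an explicit device
    x = jax.device_put(jnp.zeros(3), dev)   # x is now committed to dev
    y = jax.jit(lambda v: v + 1)(x)         # the jitted computation follows x
    print(y.devices())                      # a set containing dev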
.. _faq-benchmark: .. _faq-benchmark:
@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.::
Additional reading: Additional reading:
* `Issue: gradients through jnp.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_. * `Issue: gradients through jnp.where when one of branches is nan <https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352>`_.
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_. * `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.
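A minimal sketch of the double-``jnp.where`` pattern described above (an
illustrative ``safe_log``, not the FAQ's own snippet)::

    import jax
    import jax.numpy as jnp

    def safe_log(x):
        # The inner where keeps log from ever seeing non-positive inputs;
        # the outer where selects the value we actually want.
        return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), 0.)

    print(jax.grad(safe_log)(0.))  # 0.0 rather than nan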

@ -406,7 +406,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
] ]
}, },
{ {
@ -492,7 +492,7 @@
"source": [ "source": [
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n", "At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n", "One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n", "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n",
"\n", "\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n", "One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n", "It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",

@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5):
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
``` ```
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
+++ +++
@ -406,7 +406,7 @@ np.testing.assert_allclose(
At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`. At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode. One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.
JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing your use case if you hit this limitation in practice. JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing your use case if you hit this limitation in practice.
One other JAX feature that this example doesn't support is higher-order AD. One other JAX feature that this example doesn't support is higher-order AD.
It would be possible to work around this by wrapping the `rms_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here. It would be possible to work around this by wrapping the `rms_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.
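For readers who have not used {func}`~jax.custom_vjp` before, here is a minimal self-contained sketch of the decorator on a toy function (not the tutorial's `rms_norm` FFI binding):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Save x as the residual needed by the backward pass.
    return f(x), x

def f_bwd(x, g):
    # Return one cotangent per primal argument, as a tuple.
    return (g * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.5))       # cos(0.5)
# jax.jvp(f, (0.5,), (1.0,))  # forward-mode would raise: custom_vjp is reverse-mode only
```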

@ -176,7 +176,7 @@ installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation. Make sure that it is present in your CUDA installation.
Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues)
if you run into any errors or problems with the pre-built wheels. if you run into any errors or problems with the pre-built wheels.
(docker-containers-nvidia-gpu)= (docker-containers-nvidia-gpu)=
@ -216,7 +216,7 @@ refer to
**Note:** There are several caveats with the Metal plugin: **Note:** There are several caveats with the Metal plugin:
* The Metal plugin is new and experimental and has a number of * The Metal plugin is new and experimental and has a number of
[known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker. Please report any issues on the JAX issue tracker.
* The Metal plugin currently requires very specific versions of `jax` and * The Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API `jaxlib`. This restriction will be relaxed over time as the plugin API

@ -9,7 +9,7 @@ Let's first make a JAX issue.
But if you can pinpoint the commit that triggered the regression, it will really help us. But if you can pinpoint the commit that triggered the regression, it will really help us.
This document explains how we identified the commit that caused a This document explains how we identified the commit that caused a
[15% performance regression](https://github.com/google/jax/issues/17686). [15% performance regression](https://github.com/jax-ml/jax/issues/17686).
## Steps ## Steps
@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox).
- test_runner.sh: will start the containers and the test. - test_runner.sh: will start the containers and the test.
- test.sh: will install missing dependencies and run the test - test.sh: will install missing dependencies and run the test
Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686 Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686
- test_runner.sh: - test_runner.sh:
``` ```
for m in 7 8 9; do for m in 7 8 9; do

@ -14,7 +14,7 @@
## What's going on? ## What's going on?
As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
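As a reminder of what the API does (a minimal sketch that is independent of which implementation is active):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.sin(x))

# jax.checkpoint (aka jax.remat) recomputes f's intermediates on the backward
# pass instead of storing them, trading compute for memory.
df = jax.grad(jax.checkpoint(f))
print(df(1.0))
```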
## How can I disable the change, and go back to the old behavior for now? ## How can I disable the change, and go back to the old behavior for now?
@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu
As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue available. If you have an issue, please reach out on [the issue
tracker](https://github.com/google/jax/issues) so we can help fix it! tracker](https://github.com/jax-ml/jax/issues) so we can help fix it!
## Why are we doing this? ## Why are we doing this?
@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi
### Significantly less Python overhead in some cases ### Significantly less Python overhead in some cases
The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
### Enabling new JAX features by simplifying internals ### Enabling new JAX features by simplifying internals

@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to
This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX. This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.
Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally.
In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
This document details JAX's goals and recommendations for type annotations within the package. This document details JAX's goals and recommendations for type annotations within the package.
## Why type annotations? ## Why type annotations?
@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their
### Level 1: Annotations as documentation ### Level 1: Annotations as documentation
When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
```python ```python
Array = Any Array = Any
@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co
This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development. This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.
JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
```python ```python
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray: def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
... ...
``` ```
In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
![VSCode Intellisense Screenshot](../_static/vscode-completion.png) ![VSCode Intellisense Screenshot](../_static/vscode-completion.png)
@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer
``` ```
Again, there are a couple mechanisms that could be used for this: Again, there are a couple mechanisms that could be used for this:
- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). - override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer` - define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`
- restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer` - restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`
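To make the first option above concrete, here is a minimal, self-contained sketch of the metaclass approach, with placeholder classes standing in for `Array` and `Tracer`:

```python
class _ConcreteArray:   # placeholder for a concrete array type
    pass

class _Tracer:          # placeholder for a tracer type
    pass

class _ArrayMeta(type):
    def __instancecheck__(cls, obj):
        # Report both concrete arrays and tracers as instances of the annotation type.
        return isinstance(obj, (_ConcreteArray, _Tracer))

class ArrayInstance(metaclass=_ArrayMeta):
    pass

print(isinstance(_ConcreteArray(), ArrayInstance))  # True
print(isinstance(_Tracer(), ArrayInstance))         # True
```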

@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
This module could expose our mechanism for defining new RNG This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals implementations, and functions for working with PRNG key internals
(see issue [#9263](https://github.com/google/jax/issues/9263)), (see issue [#9263](https://github.com/jax-ml/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and such as the current `jax._src.prng.random_wrap` and
`random_unwrap`. `random_unwrap`.

@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali
and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`: and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`:
as of the writing of this JEP, its current implementation is a non-straightforward as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has iterative approximation that has
[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637), [convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637),
and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further
complexity. Had we more carefully weighed the complexity and robustness of the complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we may have chosen not to accept this implementation when accepting the contribution, we may have chosen not to accept this
contribution to the package. contribution to the package.

@ -35,9 +35,9 @@ behavior of their code. This customization
Python control flow and workflows for NaN debugging. Python control flow and workflows for NaN debugging.
As **JAX developers** we want to write library functions, like As **JAX developers** we want to write library functions, like
[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) [`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
and and
[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), [`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
that are defined in terms of other primitives, but for the purposes of that are defined in terms of other primitives, but for the purposes of
differentiation have primitive-like behavior in the sense that we want to define differentiation have primitive-like behavior in the sense that we want to define
custom differentiation rules for them, which may be more numerically stable or custom differentiation rules for them, which may be more numerically stable or
@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like
want to be confident we're not going to preclude good solutions to that problem. want to be confident we're not going to preclude good solutions to that problem.
That is, our primary goals are That is, our primary goals are
1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and 1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and
2. allow Python in custom VJPs, e.g. to debug NaNs 2. allow Python in custom VJPs, e.g. to debug NaNs
([#1275](https://github.com/google/jax/issues/1275)). ([#1275](https://github.com/jax-ml/jax/issues/1275)).
Secondary goals are Secondary goals are
3. clean up and simplify user experience (symbolic zeros, kwargs, etc) 3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
@ -60,18 +60,18 @@ Secondary goals are
`odeint`, `root`, etc. `odeint`, `root`, etc.
Overall, we want to close Overall, we want to close
[#116](https://github.com/google/jax/issues/116), [#116](https://github.com/jax-ml/jax/issues/116),
[#1097](https://github.com/google/jax/issues/1097), [#1097](https://github.com/jax-ml/jax/issues/1097),
[#1249](https://github.com/google/jax/issues/1249), [#1249](https://github.com/jax-ml/jax/issues/1249),
[#1275](https://github.com/google/jax/issues/1275), [#1275](https://github.com/jax-ml/jax/issues/1275),
[#1366](https://github.com/google/jax/issues/1366), [#1366](https://github.com/jax-ml/jax/issues/1366),
[#1723](https://github.com/google/jax/issues/1723), [#1723](https://github.com/jax-ml/jax/issues/1723),
[#1670](https://github.com/google/jax/issues/1670), [#1670](https://github.com/jax-ml/jax/issues/1670),
[#1875](https://github.com/google/jax/issues/1875), [#1875](https://github.com/jax-ml/jax/issues/1875),
[#1938](https://github.com/google/jax/issues/1938), [#1938](https://github.com/jax-ml/jax/issues/1938),
and replace the custom_transforms machinery (from and replace the custom_transforms machinery (from
[#636](https://github.com/google/jax/issues/636), [#636](https://github.com/jax-ml/jax/issues/636),
[#818](https://github.com/google/jax/issues/818), [#818](https://github.com/jax-ml/jax/issues/818),
and others). and others).
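To ground these goals, the kind of definition the new API enables looks like the standard numerically-stable example below (illustrative, in the spirit of the JEP rather than code from this changeset):

```python
import jax.numpy as jnp
from jax import custom_jvp, grad

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  # Numerically stable derivative: sigmoid(x) * x_dot.
  ans_dot = (1. - 1. / (1. + jnp.exp(x))) * x_dot
  return ans, ans_dot

print(grad(log1pexp)(100.))  # 1.0, where the automatically-derived rule gives nan
```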
## Non-goals ## Non-goals
@ -400,7 +400,7 @@ There are some other bells and whistles to the API:
resolved to positions using the `inspect` module. This is a bit of an experiment resolved to positions using the `inspect` module. This is a bit of an experiment
with Python 3's improved ability to programmatically inspect argument with Python 3's improved ability to programmatically inspect argument
signatures. I believe it is sound but not complete, which is a fine place to be. signatures. I believe it is sound but not complete, which is a fine place to be.
(See also [#2069](https://github.com/google/jax/issues/2069).) (See also [#2069](https://github.com/jax-ml/jax/issues/2069).)
* Arguments can be marked non-differentiable using `nondiff_argnums`, and as with * Arguments can be marked non-differentiable using `nondiff_argnums`, and as with
`jit`'s `static_argnums` these arguments don't have to be JAX types. We need to `jit`'s `static_argnums` these arguments don't have to be JAX types. We need to
set a convention for how these arguments are passed to the rules. For a primal set a convention for how these arguments are passed to the rules. For a primal
@ -433,5 +433,5 @@ There are some other bells and whistles to the API:
`custom_lin` to the tangent values; `custom_lin` carries with it the user's `custom_lin` to the tangent values; `custom_lin` carries with it the user's
custom backward-pass function, and as a primitive it only has a transpose custom backward-pass function, and as a primitive it only has a transpose
rule. rule.
* This mechanism is described more in [#636](https://github.com/google/jax/issues/636). * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636).
* To prevent * To prevent

@ -9,7 +9,7 @@ notebook.
## What to update ## What to update
After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s), which basically means to allow for containers of `Tracer`s), which basically means to allow for
arbitrarily-transformable code `nondiff_argnums` shouldn't be used for arbitrarily-transformable code `nondiff_argnums` shouldn't be used for
@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake! `nondiff_argnums` that way was a mistake!
**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure **[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp` issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts' and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For content. For all non-autodiff transformations, things will Just Work. For
@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals, non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before everything already works. (Before
[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about [#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after involving integer-valued inputs and outputs in autodiff, but after
[#4039](https://github.com/google/jax/pull/4039) those will just work!) [#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!)
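A minimal sketch of that guidance (function names here are made up for illustration): pass the array-valued, conceptually non-differentiable argument as a regular argument, stash it in the residuals, and return a zero cotangent for it instead of reaching for `nondiff_argnums`:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled_clip(x, scale):            # `scale` is array-valued but "non-differentiable"
  return jnp.clip(x, -1., 1.) * scale

def scaled_clip_fwd(x, scale):
  return scaled_clip(x, scale), (x, scale)

def scaled_clip_bwd(res, g):
  x, scale = res
  dx = jnp.where((x > -1.) & (x < 1.), g * scale, 0.)
  return dx, jnp.zeros_like(scale)    # zero cotangent for the non-diff argument

scaled_clip.defvjp(scaled_clip_fwd, scaled_clip_bwd)

x = jnp.array([0.5, 2.0])
print(jax.grad(lambda x: scaled_clip(x, jnp.array([3., 3.])).sum())(x))  # [3. 0.]
```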
Unlike `custom_vjp`, it was easy to make `custom_jvp` work with Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to `nondiff_argnums` arguments that were `Tracer`s. So these updates only need to

@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc.
### What's going on? ### What's going on?
A change to JAX's tracing infrastructure called “omnistaging” A change to JAX's tracing infrastructure called “omnistaging”
([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in ([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in
jax==0.2.0. This change improves memory performance, trace execution time, and jax==0.2.0. This change improves memory performance, trace execution time, and
simplifies jax internals, but may cause some existing code to break. Breakage is simplifies jax internals, but may cause some existing code to break. Breakage is
usually a result of buggy code, so long-term it's best to fix the bugs, but usually a result of buggy code, so long-term it's best to fix the bugs, but
@ -191,7 +191,7 @@ and potentially even fragmenting memory.
(The `broadcast` that corresponds to the construction of the zeros array for (The `broadcast` that corresponds to the construction of the zeros array for
`jnp.zeros_like(x)` is staged out because JAX is lazy about very simple `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple
expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After
omnistaging, we can remove that lazy sublanguage and simplify JAX internals.) omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of `mask` is not staged out is that, before omnistaging, The reason the creation of `mask` is not staged out is that, before omnistaging,

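As a rough sketch of the usual fix for the omnistaging breakage described above: use plain `numpy` for values that must be concrete at trace time, and `jax.numpy` for values that should be staged into the compiled computation (illustrative example, not from the diff):

```python
import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def normalize(x):
  n = np.prod(x.shape)               # concrete at trace time (plain numpy)
  scale = jnp.sqrt(jnp.float32(n))   # staged out into the compiled computation
  return x.reshape(n) / scale

print(normalize(jnp.ones((2, 3))))
```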
@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same
extended dtype mechanism elsewhere internally. For example, the extended dtype mechanism elsewhere internally. For example, the
`jax._src.core.bint` object, a bounded integer type used for experimental work `jax._src.core.bint` object, a bounded integer type used for experimental work
on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies
the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
### PRNG dtypes ### PRNG dtypes
PRNG dtypes are defined as a particular case of extended dtypes. Specifically, PRNG dtypes are defined as a particular case of extended dtypes. Specifically,

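For illustration, the typed PRNG keys built on this mechanism in recent JAX versions behave as follows (a small sketch, assuming new-style keys are available):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
print(key.shape, key.dtype)  # scalar key with an extended dtype, e.g. () key<fry>
print(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))  # True
```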
@ -8,7 +8,7 @@
"source": [ "source": [
"# Design of Type Promotion Semantics for JAX\n", "# Design of Type Promotion Semantics for JAX\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n", "\n",
"*Jake VanderPlas, December 2021*\n", "*Jake VanderPlas, December 2021*\n",
"\n", "\n",

@ -16,7 +16,7 @@ kernelspec:
# Design of Type Promotion Semantics for JAX # Design of Type Promotion Semantics for JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
*Jake VanderPlas, December 2021* *Jake VanderPlas, December 2021*

@ -58,11 +58,11 @@ These constraints imply the following rules for releases:
* If a new `jaxlib` is released, a `jax` release must be made at the same time. * If a new `jaxlib` is released, a `jax` release must be made at the same time.
These These
[version constraints](https://github.com/google/jax/blob/main/jax/version.py) [version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py)
are currently checked by `jax` at import time, instead of being expressed as are currently checked by `jax` at import time, instead of being expressed as
Python package version constraints. `jax` checks the `jaxlib` version at Python package version constraints. `jax` checks the `jaxlib` version at
runtime rather than using a `pip` package version constraint because we runtime rather than using a `pip` package version constraint because we
[provide separate `jaxlib` wheels](https://github.com/google/jax#installation) [provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation)
for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want `pip` do not know which is the right choice for any given user, we do not want `pip`
to install a `jaxlib` package for us automatically. to install a `jaxlib` package for us automatically.
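A quick way to inspect the two separately-versioned packages on a given machine (not an official recipe, just a convenient helper):

```python
import jax

# Prints the installed jax and jaxlib versions (among other environment details),
# which is the first thing to check if the import-time version check complains.
jax.print_environment_info()
```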
@ -119,7 +119,7 @@ no released `jax` version uses that API.
## How is the source to `jaxlib` laid out? ## How is the source to `jaxlib` laid out?
`jaxlib` is split across two main repositories, namely the `jaxlib` is split across two main repositories, namely the
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib) [`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib)
and in the and in the
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla). [XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the The JAX-specific pieces inside XLA are primarily in the
@ -146,7 +146,7 @@ level.
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of `jaxlib` is built using Bazel out of the `jax` repository. The pieces of
`jaxlib` from the XLA repository are incorporated into the build `jaxlib` from the XLA repository are incorporated into the build
[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). [as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overridden on a build-by-build basis. as-needed basis, but can be overridden on a build-by-build basis.

@ -32,7 +32,7 @@ should be linked to this issue.
Then create a pull request that adds a file named Then create a pull request that adds a file named
`%d-{short-title}.md` - with the number being the issue number. `%d-{short-title}.md` - with the number being the issue number.
.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP .. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-06-03' } *-->\n", "<!--* freshness: { reviewed: '2024-06-03' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
] ]
}, },
{ {
@ -661,7 +661,7 @@
"source": [ "source": [
"Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n", "Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n",
"\n", "\n",
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
] ]
}, },
{ {
@ -1003,7 +1003,7 @@
"id": "COjzGBpO4tzL" "id": "COjzGBpO4tzL"
}, },
"source": [ "source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n", "\n",
"The random state is described by a special array element that we call a __key__:" "The random state is described by a special array element that we call a __key__:"
] ]
@ -1349,7 +1349,7 @@
"\n", "\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n", "\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n", "\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n", "\n",

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-06-03' } *--> <!--* freshness: { reviewed: '2024-06-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+++ {"id": "4k5PVzEo2uJO"} +++ {"id": "4k5PVzEo2uJO"}
@ -312,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error. Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
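A small sketch of the two behaviors described above:

```python
import jax.numpy as jnp

x = jnp.arange(10.0)
print(x[11])                                           # index is clamped: returns x[9] == 9.0
print(x.at[11].get(mode='fill', fill_value=jnp.nan))   # nan instead of clamping
print(x.at[11].set(23.0))                              # out-of-bounds update is dropped
```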
+++ {"id": "LwB07Kx5sgHu"} +++ {"id": "LwB07Kx5sgHu"}
@ -460,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
+++ {"id": "COjzGBpO4tzL"} +++ {"id": "COjzGBpO4tzL"}
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a __key__: The random state is described by a special array element that we call a __key__:
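For instance, a minimal sketch of the splittable-key workflow:

```python
from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)   # explicitly fork the PRNG state
x = random.normal(subkey, (3,))   # consume the subkey exactly once
print(x)
```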
@ -623,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
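A quick way to see those abstract values is to print the argument while tracing (illustrative sketch):

```python
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  print("traced with:", x)   # runs at trace time and shows the abstract value
  return 2 * x

f(jnp.array([1., 2., 3.]))   # prints something like Traced<ShapedArray(float32[3])>
f(jnp.array([4., 5., 6.]))   # no print: the cached compilation is reused
```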

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n", "\n",
"There are two ways to define differentiation rules in JAX:\n", "There are two ways to define differentiation rules in JAX:\n",
"\n", "\n",

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
There are two ways to define differentiation rules in JAX: There are two ways to define differentiation rules in JAX:

@ -17,7 +17,7 @@
"id": "pFtQjv4SzHRj" "id": "pFtQjv4SzHRj"
}, },
"source": [ "source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"\n", "\n",
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
] ]

@ -19,7 +19,7 @@ kernelspec:
+++ {"id": "pFtQjv4SzHRj"} +++ {"id": "pFtQjv4SzHRj"}
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n", "\n",
"*necula@google.com*, October 2019.\n", "*necula@google.com*, October 2019.\n",
"\n", "\n",

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
*necula@google.com*, October 2019. *necula@google.com*, October 2019.

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n", "<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n", "\n",
"**Copyright 2018 The JAX Authors.**\n", "**Copyright 2018 The JAX Authors.**\n",
"\n", "\n",
@ -32,9 +32,9 @@
"id": "B_XlLLpcWjkA" "id": "B_XlLLpcWjkA"
}, },
"source": [ "source": [
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n", "\n",
"Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n", "\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
] ]

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-05-03' } *--> <!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
**Copyright 2018 The JAX Authors.** **Copyright 2018 The JAX Authors.**
@ -35,9 +35,9 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"} +++ {"id": "B_XlLLpcWjkA"}
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
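In that spirit, a minimal sketch of the kind of plain-`jax.numpy` model such a notebook builds (layer sizes and names here are illustrative, not the notebook's exact code):

```python
import jax.numpy as jnp
from jax import random

def init_params(key, sizes, scale=0.01):
  keys = random.split(key, len(sizes) - 1)
  return [(scale * random.normal(k, (m, n)), jnp.zeros(n))
          for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def predict(params, x):
  for w, b in params[:-1]:
    x = jnp.maximum(x @ w + b, 0.)   # ReLU hidden layers
  w, b = params[-1]
  return x @ w + b                   # logits

params = init_params(random.PRNGKey(0), [784, 512, 10])
print(predict(params, jnp.ones((1, 784))).shape)  # (1, 10)
```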

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
] ]
}, },
{ {
@ -79,7 +79,7 @@
"id": "gA8V51wZdsjh" "id": "gA8V51wZdsjh"
}, },
"source": [ "source": [
"When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README." "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README."
] ]
}, },
{ {
@ -320,7 +320,7 @@
"source": [ "source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n", "Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n", "\n",
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
] ]
}, },
{ {
@ -333,7 +333,7 @@
"\n", "\n",
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n", "An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
"\n", "\n",
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)." "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
] ]
}, },
{ {

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+++ {"id": "r-3vMiKRYXPJ"} +++ {"id": "r-3vMiKRYXPJ"}
@ -57,7 +57,7 @@ fast_f = jit(f)
+++ {"id": "gA8V51wZdsjh"} +++ {"id": "gA8V51wZdsjh"}
When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README. When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README.
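To see that trace concretely, `jax.make_jaxpr` shows the intermediate representation JAX records (a small sketch):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.exp(jnp.tanh(x))

# Prints the jaxpr that tracing f on a float32[5] input produces.
print(jax.make_jaxpr(f)(jnp.ones(5)))
```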
+++ {"id": "2Th1vYLVaFBz"} +++ {"id": "2Th1vYLVaFBz"}
@ -223,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
Notice that `eval_jaxpr` will always return a flat list even if the original function does not. Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
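For example, assuming `jax.core.eval_jaxpr` is exposed as the notebook uses it, evaluating a jaxpr directly returns a flat list of outputs:

```python
import jax
import jax.numpy as jnp
from jax.core import eval_jaxpr

def f(x):
  return jnp.sin(x) * 2.0

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
out = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
print(out)   # a flat list with one array, even though f returns a single value
```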
+++ {"id": "0vb2ZoGrCMM4"} +++ {"id": "0vb2ZoGrCMM4"}
@ -231,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit
An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry. An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234). It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234).
```{code-cell} ipython3 ```{code-cell} ipython3
:id: gSMIT2z1vUpO :id: gSMIT2z1vUpO

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n", "\n",
"JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics."
] ]
@ -255,7 +255,7 @@
"id": "cJ2NxiN58bfI" "id": "cJ2NxiN58bfI"
}, },
"source": [ "source": [
"You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
] ]
}, },
{ {
@ -1015,7 +1015,7 @@
"source": [ "source": [
"### Jacobian-Matrix and Matrix-Jacobian products\n", "### Jacobian-Matrix and Matrix-Jacobian products\n",
"\n", "\n",
"Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
] ]
}, },
{ {

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
@ -151,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b}))
+++ {"id": "cJ2NxiN58bfI"} +++ {"id": "cJ2NxiN58bfI"}
You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
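A minimal sketch of such a registration (the container name is illustrative):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class Point:
  def __init__(self, x, y):
    self.x, self.y = x, y

register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),            # flatten: children and aux data
    lambda aux, children: Point(*children),  # unflatten
)

loss = lambda p: jnp.sum((p.x - 1.0) ** 2 + (p.y + 2.0) ** 2)
g = jax.grad(loss)(Point(jnp.ones(3), jnp.zeros(3)))
print(g.x, g.y)  # gradients come back as another Point
```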
+++ {"id": "PaCHzAtGruBz"} +++ {"id": "PaCHzAtGruBz"}
@ -592,7 +592,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products ### Jacobian-Matrix and Matrix-Jacobian products
Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
```{code-cell} ipython3 ```{code-cell} ipython3
:id: asAWvxVaCmsx :id: asAWvxVaCmsx

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n", "\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n",
"\n", "\n",

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)
JAX provides a number of interfaces to compute convolutions across data, including: JAX provides a number of interfaces to compute convolutions across data, including:

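One of the interfaces referred to above is the NumPy-style `jnp.convolve`; a tiny sketch:

```python
import jax.numpy as jnp

signal = jnp.sin(jnp.linspace(0, 10, 100))
kernel = jnp.ones(8) / 8.0                      # simple moving-average window
smoothed = jnp.convolve(signal, kernel, mode='same')
print(smoothed.shape)                           # (100,)
```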
@ -40,11 +40,11 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n", "<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n", "\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n",
"\n", "\n",
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n", "\n",
"Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n", "\n",

@ -38,11 +38,11 @@ limitations under the License.
<!--* freshness: { reviewed: '2024-05-03' } *--> <!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
_Forked from_ `neural_network_and_data_loading.ipynb` _Forked from_ `neural_network_and_data_loading.ipynb`
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n", "\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."
] ]

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

@ -10,7 +10,7 @@
"\n", "\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n", "<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n", "\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n", "\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",
"\n", "\n",

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *--> <!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

@ -21,7 +21,7 @@ software emulation, and can slow down the computation.
If you see unexpected outputs, please compare them against a kernel run with If you see unexpected outputs, please compare them against a kernel run with
``interpret=True`` passed in to ``pallas_call``. If the results diverge, ``interpret=True`` passed in to ``pallas_call``. If the results diverge,
please file a `bug report <https://github.com/google/jax/issues/new/choose>`_. please file a `bug report <https://github.com/jax-ml/jax/issues/new/choose>`_.
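A minimal sketch of the ``interpret=True`` workflow described above, assuming a trivial add-one kernel and small illustrative shapes (none of this is part of the change itself):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Read the whole input block, add one, and write the result back.
    o_ref[...] = x_ref[...] + 1

x = jnp.arange(8, dtype=jnp.float32)
# interpret=True runs the kernel in interpreter mode, which is useful for
# comparing results against a compiled run when outputs look suspicious.
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)(x)
```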
What is a TPU? What is a TPU?
-------------- --------------

@ -29,7 +29,7 @@ f(x)
### Setting cache directory ### Setting cache directory
The compilation cache is enabled when the The compilation cache is enabled when the
[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) [cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as is set. This should be done prior to the first compilation. Set the location as
follows: follows:
@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
``` ```
(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) (3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python ```python
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
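A minimal sketch of how option (3) is typically completed, assuming the same `/tmp/jax_cache` path used in the surrounding examples:

```python
from jax.experimental.compilation_cache import compilation_cache as cc

# Point the persistent compilation cache at a writable directory before the
# first compilation takes place.
cc.set_cache_dir("/tmp/jax_cache")
```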

@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None,
:jax-issue:`1234` :jax-issue:`1234`
This will output a hyperlink of the form This will output a hyperlink of the form
`#1234 <http://github.com/google/jax/issues/1234>`_. These links work even `#1234 <http://github.com/jax-ml/jax/issues/1234>`_. These links work even
for PR numbers. for PR numbers.
""" """
text = text.lstrip('#') text = text.lstrip('#')
if not text.isdigit(): if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
options = {} if options is None else options options = {} if options is None else options
url = f"https://github.com/google/jax/issues/{text}" url = f"https://github.com/jax-ml/jax/issues/{text}"
node = nodes.reference(rawtext, '#' + text, refuri=url, **options) node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], [] return [node], []

@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b
2) Are we supposed to pipe all these things around manually? 2) Are we supposed to pipe all these things around manually?
The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples. The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples.

@ -29,7 +29,7 @@ except Exception as exc:
# Defensively swallow any exceptions to avoid making jax unimportable # Defensively swallow any exceptions to avoid making jax unimportable
from warnings import warn as _warn from warnings import warn as _warn
_warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report " _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
f"an issue at https://github.com/google/jax/issues") f"an issue at https://github.com/jax-ml/jax/issues")
del _warn del _warn
del _cloud_tpu_init del _cloud_tpu_init
@ -38,7 +38,7 @@ import jax.core as _core
del _core del _core
# Note: import <name> as <name> is required for names to be exported. # Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570 # See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.basearray import Array as Array from jax._src.basearray import Array as Array
from jax import tree as tree from jax import tree as tree

@ -546,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
# To avoid precision mismatches in fwd and bwd passes due to XLA excess # To avoid precision mismatches in fwd and bwd passes due to XLA excess
# precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls # precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls
# on producers of any residuals. See https://github.com/google/jax/pull/22244. # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244.
jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res) jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
# compute known outputs and residuals (hoisted out of remat primitive) # compute known outputs and residuals (hoisted out of remat primitive)

@ -956,7 +956,7 @@ def vmap(fun: F,
# list: if in_axes is not a leaf, it must be a tuple of trees. However, # list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated # in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here # essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367 # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_axes = tuple(in_axes) in_axes = tuple(in_axes)
if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}): if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
@ -2505,7 +2505,7 @@ class ShapeDtypeStruct:
def __hash__(self): def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing # TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182 # https://github.com/jax-ml/jax/issues/8182
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
def _sds_aval_mapping(x): def _sds_aval_mapping(x):

@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
device_assignment = axis_context.device_assignment device_assignment = axis_context.device_assignment
if device_assignment is None: if device_assignment is None:
raise AssertionError( raise AssertionError(
"Please file a bug at https://github.com/google/jax/issues") "Please file a bug at https://github.com/jax-ml/jax/issues")
try: try:
device_index = device_assignment.index(device) device_index = device_assignment.index(device)
except IndexError as e: except IndexError as e:

@ -1170,7 +1170,7 @@ softmax_custom_jvp = bool_state(
upgrade=True, upgrade=True,
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
'improve memory usage and stability. Set True to use new ' 'improve memory usage and stability. Set True to use new '
'behavior. See https://github.com/google/jax/pull/15677'), 'behavior. See https://github.com/jax-ml/jax/pull/15677'),
update_global_hook=lambda val: _update_global_jit_state( update_global_hook=lambda val: _update_global_jit_state(
softmax_custom_jvp=val), softmax_custom_jvp=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state( update_thread_local_hook=lambda val: update_thread_local_jit_state(

@ -935,7 +935,7 @@ aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace): class EvalTrace(Trace):
# See comments in https://github.com/google/jax/pull/3370 # See comments in https://github.com/jax-ml/jax/pull/3370
def pure(self, x): return x def pure(self, x): return x
lift = sublift = pure lift = sublift = pure
@ -998,7 +998,7 @@ class MainTrace:
return self.trace_type(self, cur_sublevel(), **self.payload) return self.trace_type(self, cur_sublevel(), **self.payload)
class TraceStack: class TraceStack:
# See comments in https://github.com/google/jax/pull/3370 # See comments in https://github.com/jax-ml/jax/pull/3370
stack: list[MainTrace] stack: list[MainTrace]
dynamic: MainTrace dynamic: MainTrace
@ -1167,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str:
# parent->child jump. We do that by setting `parent` here to be a # parent->child jump. We do that by setting `parent` here to be a
# grandparent (or great-grandparent) of `child`, and then handling that case # grandparent (or great-grandparent) of `child`, and then handling that case
# in _why_alive_container_info. See example: # in _why_alive_container_info. See example:
# https://github.com/google/jax/pull/13022#discussion_r1008456599 # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
# To prevent this collapsing behavior, just comment out this code block. # To prevent this collapsing behavior, just comment out this code block.
if (isinstance(parent, dict) and if (isinstance(parent, dict) and
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]): getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
@ -1213,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager @contextmanager
def new_main(trace_type: type[Trace], dynamic: bool = False, def new_main(trace_type: type[Trace], dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]: **payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370 # See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack stack = thread_local_state.trace_state.trace_stack
level = stack.next_level() level = stack.next_level()
main = MainTrace(level, trace_type, **payload) main = MainTrace(level, trace_type, **payload)
@ -1254,7 +1254,7 @@ def dynamic_level() -> int:
@contextmanager @contextmanager
def new_base_main(trace_type: type[Trace], def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]: **payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370 # See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type, **payload) main = MainTrace(0, trace_type, **payload)
prev_dynamic, stack.dynamic = stack.dynamic, main prev_dynamic, stack.dynamic = stack.dynamic, main
@ -1319,7 +1319,7 @@ def ensure_compile_time_eval():
else: else:
return jnp.cos(x) return jnp.cos(x)
Here's a real-world example from https://github.com/google/jax/issues/3974:: Here's a real-world example from https://github.com/jax-ml/jax/issues/3974::
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
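A minimal, self-contained sketch of what ``jax.ensure_compile_time_eval`` enables; this example is illustrative and is not the example from the issue referenced above:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Inside this context, operations on values that do not depend on traced
    # arguments are evaluated eagerly at trace time rather than staged out.
    with jax.ensure_compile_time_eval():
        offset = jnp.sin(jnp.pi / 4)
    return x + offset
```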
@ -1680,7 +1680,7 @@ class UnshapedArray(AbstractValue):
@property @property
def shape(self): def shape(self):
msg = ("UnshapedArray has no shape. Please open an issue at " msg = ("UnshapedArray has no shape. Please open an issue at "
"https://github.com/google/jax/issues because it's unexpected for " "https://github.com/jax-ml/jax/issues because it's unexpected for "
"UnshapedArray instances to ever be produced.") "UnshapedArray instances to ever be produced.")
raise TypeError(msg) raise TypeError(msg)

@ -191,7 +191,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
# TODO(frostig): assert these also equal: # TODO(frostig): assert these also equal:
# treedef_tuple((in_tree, in_tree)) # treedef_tuple((in_tree, in_tree))
# once https://github.com/google/jax/issues/9066 is fixed # once https://github.com/jax-ml/jax/issues/9066 is fixed
assert tree_ps_ts == tree_ps_ts2 assert tree_ps_ts == tree_ps_ts2
del tree_ps_ts2 del tree_ps_ts2

@ -1144,7 +1144,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool: def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value # False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise. # with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation. # See https://github.com/jax-ml/jax/issues/6415 for motivation.
x = core.full_lower(x) x = core.full_lower(x)
if not isinstance(x, core.Tracer): if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed. # If x is not a Tracer, it can't be perturbed.

@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
devices = axis_context.device_assignment devices = axis_context.device_assignment
if devices is None: if devices is None:
raise AssertionError( raise AssertionError(
'Please file a bug at https://github.com/google/jax/issues') 'Please file a bug at https://github.com/jax-ml/jax/issues')
if axis_context.mesh_shape is not None: if axis_context.mesh_shape is not None:
ma, ms = list(zip(*axis_context.mesh_shape)) ma, ms = list(zip(*axis_context.mesh_shape))
mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)

@ -389,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
devices = axis_context.device_assignment devices = axis_context.device_assignment
if devices is None: if devices is None:
raise AssertionError( raise AssertionError(
'Please file a bug at https://github.com/google/jax/issues') 'Please file a bug at https://github.com/jax-ml/jax/issues')
elif isinstance(axis_context, sharding_impls.SPMDAxisContext): elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple devices = axis_context.mesh._flat_devices_tuple
else: else:

@ -793,7 +793,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
"and will be truncated to dtype {}. To enable more dtypes, set the " "and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. " "environment variable. "
"See https://github.com/google/jax#current-gotchas for more.") "See https://github.com/jax-ml/jax#current-gotchas for more.")
fun_name = f"requested in {fun_name}" if fun_name else "" fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = canonicalize_dtype(np_dtype).name truncated_dtype = canonicalize_dtype(np_dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)

@ -61,7 +61,7 @@ def _ravel_list(lst):
if all(dt == to_dtype for dt in from_dtypes): if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
# See https://github.com/google/jax/issues/7809. # See https://github.com/jax-ml/jax/issues/7809.
del from_dtypes, to_dtype del from_dtypes, to_dtype
raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)

@ -30,7 +30,7 @@ Instead of writing this information as conditions inside one
particular test, we write them as `Limitation` objects that can be reused in particular test, we write them as `Limitation` objects that can be reused in
multiple tests and can also be used to generate documentation, e.g., multiple tests and can also be used to generate documentation, e.g.,
the report of [unsupported and partially-implemented JAX the report of [unsupported and partially-implemented JAX
primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
The limitations are used to filter out from tests the harnesses that are known The limitations are used to filter out from tests the harnesses that are known
to fail. A Limitation is specific to a harness. to fail. A Limitation is specific to a harness.
@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name,
for old_dtype in jtu.dtypes.all: for old_dtype in jtu.dtypes.all:
# TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
# point numbers and new_dtype is an unsigned integer. See issue # point numbers and new_dtype is an unsigned integer. See issue
# https://github.com/google/jax/issues/5082 for details. # https://github.com/jax-ml/jax/issues/5082 for details.
for new_dtype in (jtu.dtypes.all for new_dtype in (jtu.dtypes.all
if not (dtypes.issubdtype(old_dtype, np.floating) or if not (dtypes.issubdtype(old_dtype, np.floating) or
dtypes.issubdtype(old_dtype, np.complexfloating)) dtypes.issubdtype(old_dtype, np.complexfloating))
@ -2336,7 +2336,7 @@ _make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)
# Validate padding # Validate padding
for padding in [ for padding in [
# TODO(bchetioui): commented out the test based on # TODO(bchetioui): commented out the test based on
# https://github.com/google/jax/issues/4690 # https://github.com/jax-ml/jax/issues/4690
# ((1, 2), (2, 3), (3, 4)) # non-zero padding # ((1, 2), (2, 3), (3, 4)) # non-zero padding
((1, 1), (1, 1), (1, 1)) # non-zero padding ((1, 1), (1, 1), (1, 1)) # non-zero padding
]: ]:

@ -1262,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool], name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
msg = (f'custom-policy remat rule not implemented for {name}, ' msg = (f'custom-policy remat rule not implemented for {name}, '
'open a feature request at https://github.com/google/jax/issues!') 'open a feature request at https://github.com/jax-ml/jax/issues!')
raise NotImplementedError(msg) raise NotImplementedError(msg)
@ -2688,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
# TODO(mattjj): the following are deprecated; update callers to _nounits version # TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498 # See https://github.com/jax-ml/jax/pull/9498
@lu.transformation @lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
pvals: Sequence[PartialVal]): pvals: Sequence[PartialVal]):

@ -500,7 +500,7 @@ class MapTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros: if symbolic_zeros:
msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
"Please open an issue at https://github.com/google/jax/issues !") "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros # always base main, can drop jvp del prim, jvp, symbolic_zeros # always base main, can drop jvp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@ -513,7 +513,7 @@ class MapTrace(core.Trace):
out_trees, symbolic_zeros): out_trees, symbolic_zeros):
if symbolic_zeros: if symbolic_zeros:
msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
"Please open an issue at https://github.com/google/jax/issues !") "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg) raise NotImplementedError(msg)
del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@ -1869,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"does not preserve sharded data representations and instead collects " "does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. " "input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. " "Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.") "See https://github.com/jax-ml/jax/issues/2926.")
if nreps > xb.device_count(backend): if nreps > xb.device_count(backend):
raise ValueError( raise ValueError(

@ -389,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
branch_outs = [] branch_outs = []
for i, jaxpr in enumerate(branches_batched): for i, jaxpr in enumerate(branches_batched):
# Perform a select on the inputs for safety of reverse-mode autodiff; see # Perform a select on the inputs for safety of reverse-mode autodiff; see
# https://github.com/google/jax/issues/1052 # https://github.com/jax-ml/jax/issues/1052
predicate = lax.eq(index, lax._const(index, i)) predicate = lax.eq(index, lax._const(index, i))
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))

@ -715,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError raise NotImplementedError
if not all(init_lin): if not all(init_lin):
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963
consts, _, xs = split_list(args, [num_consts, num_carry]) consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires]) ires, _ = split_list(consts, [num_ires])
@ -1169,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
if discharged_consts: if discharged_consts:
raise NotImplementedError("Discharged jaxpr has consts. If you see this, " raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
"please open an issue at " "please open an issue at "
"https://github.com/google/jax/issues") "https://github.com/jax-ml/jax/issues")
def wrapped(*wrapped_args): def wrapped(*wrapped_args):
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs]) [n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
@ -1838,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
if body_jaxpr_consts: if body_jaxpr_consts:
raise NotImplementedError("Body jaxpr has consts. If you see this error, " raise NotImplementedError("Body jaxpr has consts. If you see this error, "
"please open an issue at " "please open an issue at "
"https://github.com/google/jax/issues") "https://github.com/jax-ml/jax/issues")
# body_jaxpr has the signature (*body_consts, *carry) -> carry. # body_jaxpr has the signature (*body_consts, *carry) -> carry.
# Some of these body_consts are actually `Ref`s so when we discharge # Some of these body_consts are actually `Ref`s so when we discharge
# them, they also turn into outputs, effectively turning those consts into # them, they also turn into outputs, effectively turning those consts into

@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths):
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients # Use JAX's convention for complex gradients
# https://github.com/google/jax/issues/6223#issuecomment-807740707 # https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707
return lax.conj(out) return lax.conj(out)
def _fft_transpose_rule(t, operand, fft_type, fft_lengths): def _fft_transpose_rule(t, operand, fft_type, fft_lengths):

@ -1077,7 +1077,7 @@ def _reduction_jaxpr(computation, aval):
if any(isinstance(c, core.Tracer) for c in consts): if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError( raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue " "Reduction computations can't close over Tracers. Please open an issue "
"at https://github.com/google/jax.") "at https://github.com/jax-ml/jax.")
return jaxpr, tuple(consts) return jaxpr, tuple(consts)
@cache() @cache()
@ -1090,7 +1090,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
if any(isinstance(c, core.Tracer) for c in consts): if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError( raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue " "Reduction computations can't close over Tracers. Please open an issue "
"at https://github.com/google/jax.") "at https://github.com/jax-ml/jax.")
return core.ClosedJaxpr(jaxpr, consts), out_tree() return core.ClosedJaxpr(jaxpr, consts), out_tree()
def _get_monoid_reducer(monoid_op: Callable, def _get_monoid_reducer(monoid_op: Callable,
@ -4911,7 +4911,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
return tree_util.tree_unflatten(p.out_tree(), out_flat) return tree_util.tree_unflatten(p.out_tree(), out_flat)
# TODO(https://github.com/google/jax/issues/13552): Look into making this a # TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a
# method on jax.Array so that we can bypass the XLA compilation here. # method on jax.Array so that we can bypass the XLA compilation here.
def _copy_impl(prim, *args, **kwargs): def _copy_impl(prim, *args, **kwargs):
a, = args a, = args

@ -781,7 +781,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
raise NotImplementedError( raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only ' 'The derivatives of eigenvectors are not implemented, only '
'eigenvalues. See ' 'eigenvalues. See '
'https://github.com/google/jax/issues/2748 for discussion.') 'https://github.com/jax-ml/jax/issues/2748 for discussion.')
# Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
# https://arxiv.org/abs/1701.00392 # https://arxiv.org/abs/1701.00392
a, = primals a, = primals

@ -27,7 +27,7 @@ try:
except ModuleNotFoundError as err: except ModuleNotFoundError as err:
raise ModuleNotFoundError( raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See ' 'jax requires jaxlib to be installed. See '
'https://github.com/google/jax#installation for installation instructions.' 'https://github.com/jax-ml/jax#installation for installation instructions.'
) from err ) from err
import jax.version import jax.version
@ -92,7 +92,7 @@ pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib pmap_lib = xla_client._xla.pmap_lib
# XLA garbage collection: see https://github.com/google/jax/issues/14882 # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
def _xla_gc_callback(*args): def _xla_gc_callback(*args):
xla_client._xla.collect_garbage() xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback) gc.callbacks.append(_xla_gc_callback)

@ -333,7 +333,7 @@ class AbstractMesh:
should use this as an input to the sharding passed to with_sharding_constraint should use this as an input to the sharding passed to with_sharding_constraint
and mesh passed to shard_map to avoid tracing and lowering cache misses when and mesh passed to shard_map to avoid tracing and lowering cache misses when
your mesh shape and names stay the same but the devices change. your mesh shape and names stay the same but the devices change.
See the description of https://github.com/google/jax/pull/23022 for more See the description of https://github.com/jax-ml/jax/pull/23022 for more
details. details.
""" """

@ -3953,7 +3953,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
arrays_out = [asarray(arr, dtype=dtype) for arr in arrays] arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a # lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode. # tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653). # (https://github.com/jax-ml/jax/issues/653).
k = 16 k = 16
while len(arrays_out) > 1: while len(arrays_out) > 1:
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis) arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
@ -4645,7 +4645,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
if all(not isinstance(leaf, Array) for leaf in leaves): if all(not isinstance(leaf, Array) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists # TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in # containing large integers; see discussion in
# https://github.com/google/jax/pull/6047. More correct would be to call # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
# coerce_to_array on each leaf, but this may have performance implications. # coerce_to_array on each leaf, but this may have performance implications.
out = np.asarray(object, dtype=dtype) out = np.asarray(object, dtype=dtype)
elif isinstance(object, Array): elif isinstance(object, Array):
@ -10150,11 +10150,11 @@ def _eliminate_deprecated_list_indexing(idx):
if any(_should_unpack_list_index(i) for i in idx): if any(_should_unpack_list_index(i) for i in idx):
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[tuple(seq)]` instead of `arr[seq]`. " "use `arr[tuple(seq)]` instead of `arr[seq]`. "
"See https://github.com/google/jax/issues/4564 for more information.") "See https://github.com/jax-ml/jax/issues/4564 for more information.")
else: else:
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[array(seq)]` instead of `arr[seq]`. " "use `arr[array(seq)]` instead of `arr[seq]`. "
"See https://github.com/google/jax/issues/4564 for more information.") "See https://github.com/jax-ml/jax/issues/4564 for more information.")
raise TypeError(msg) raise TypeError(msg)
else: else:
idx = (idx,) idx = (idx,)

@ -968,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
msg = ("jax.numpy.var does not yet support real dtype parameters when " msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The " "computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment " "semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is " "on https://github.com/jax-ml/jax/issues/2283 if this behavior is "
"important to you.") "important to you.")
raise ValueError(msg) raise ValueError(msg)
computation_dtype = dtype computation_dtype = dtype

@ -134,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
Traceback (most recent call last): Traceback (most recent call last):
... ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size`` In order to ensure statically-known output shapes, you can pass a static ``size``
argument: argument:
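A hedged illustration of that static ``size`` argument; the input values, ``size=3``, and ``fill_value`` below are made up for the example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def diff(a, b):
    # A static size makes the output shape known at trace time; any unused
    # slots are padded with fill_value.
    return jnp.setdiff1d(a, b, size=3, fill_value=0)

print(diff(jnp.array([1, 2, 3, 4]), jnp.array([2, 4])))  # [1 3 0]
```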
@ -217,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
Traceback (most recent call last): Traceback (most recent call last):
... ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size`` In order to ensure statically-known output shapes, you can pass a static ``size``
argument: argument:

@ -2176,7 +2176,7 @@ def sinc(x: ArrayLike, /) -> Array:
def _sinc_maclaurin(k, x): def _sinc_maclaurin(k, x):
# compute the kth derivative of x -> sin(x)/x evaluated at zero (since we # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
# compute the monomial term in the jvp rule) # compute the monomial term in the jvp rule)
# TODO(mattjj): see https://github.com/google/jax/issues/10750 # TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750
if k % 2: if k % 2:
return x * 0 return x * 0
else: else:

@ -35,7 +35,7 @@ FRAME_PATTERN = re.compile(
) )
MLIR_ERR_PREFIX = ( MLIR_ERR_PREFIX = (
'Pallas encountered an internal verification error. ' 'Pallas encountered an internal verification error. '
'Please file a bug at https://github.com/google/jax/issues. ' 'Please file a bug at https://github.com/jax-ml/jax/issues. '
'Error details: ' 'Error details: '
) )

@ -833,7 +833,7 @@ def jaxpr_subcomp(
raise NotImplementedError( raise NotImplementedError(
"Unimplemented primitive in Pallas TPU lowering: " "Unimplemented primitive in Pallas TPU lowering: "
f"{eqn.primitive.name}. " f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues.") "Please file an issue on https://github.com/jax-ml/jax/issues.")
if eqn.primitive.multiple_results: if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, ans) map(write_env, eqn.outvars, ans)
else: else:

@ -549,7 +549,7 @@ def lower_jaxpr_to_mosaic_gpu(
raise NotImplementedError( raise NotImplementedError(
"Unimplemented primitive in Pallas Mosaic GPU lowering: " "Unimplemented primitive in Pallas Mosaic GPU lowering: "
f"{eqn.primitive.name}. " f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues." "Please file an issue on https://github.com/jax-ml/jax/issues."
) )
rule = mosaic_lowering_rules[eqn.primitive] rule = mosaic_lowering_rules[eqn.primitive]
rule_ctx = LoweringRuleContext( rule_ctx = LoweringRuleContext(

@ -381,7 +381,7 @@ def lower_jaxpr_to_triton_ir(
raise NotImplementedError( raise NotImplementedError(
"Unimplemented primitive in Pallas GPU lowering: " "Unimplemented primitive in Pallas GPU lowering: "
f"{eqn.primitive.name}. " f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues.") "Please file an issue on https://github.com/jax-ml/jax/issues.")
rule = triton_lowering_rules[eqn.primitive] rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars] avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars] avals_out = [v.aval for v in eqn.outvars]

@ -459,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
# list: if in_axes is not a leaf, it must be a tuple of trees. However, # list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated # in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here # essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367 # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_shardings = tuple(in_shardings) in_shardings = tuple(in_shardings)
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
@ -1276,7 +1276,7 @@ def explain_tracing_cache_miss(
return done() return done()
# we think this is unreachable... # we think this is unreachable...
p("explanation unavailable! please open an issue at https://github.com/google/jax") p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
return done() return done()
@partial(lu.cache, explain=explain_tracing_cache_miss) @partial(lu.cache, explain=explain_tracing_cache_miss)
@ -1701,7 +1701,7 @@ def _pjit_call_impl_python(
"`jit` decorator, at the cost of losing optimizations. " "`jit` decorator, at the cost of losing optimizations. "
"\n\n" "\n\n"
"If you see this error, consider opening a bug report at " "If you see this error, consider opening a bug report at "
"https://github.com/google/jax.") "https://github.com/jax-ml/jax.")
raise FloatingPointError(msg) raise FloatingPointError(msg)

@ -508,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
span = lax.convert_element_type(maxval - minval, unsigned_dtype) span = lax.convert_element_type(maxval - minval, unsigned_dtype)
# Ensure that span=1 when maxval <= minval, so minval is always returned; # Ensure that span=1 when maxval <= minval, so minval is always returned;
# https://github.com/google/jax/issues/222 # https://github.com/jax-ml/jax/issues/222
span = lax.select(maxval <= minval, lax.full_like(span, 1), span) span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
# When maxval is out of range, the span has to be one larger. # When maxval is out of range, the span has to be one larger.
@ -2540,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array:
_btrs(key, count_btrs, q_btrs, shape, dtype, max_iters), _btrs(key, count_btrs, q_btrs, shape, dtype, max_iters),
) )
# ensure nan q always leads to nan output and nan or neg count leads to nan # ensure nan q always leads to nan output and nan or neg count leads to nan
# as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709 # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709
invalid = (q_l_0 | q_is_nan | count_nan_or_neg) invalid = (q_l_0 | q_is_nan | count_nan_or_neg)
samples = lax.select( samples = lax.select(
invalid, invalid,

@ -176,7 +176,7 @@ def map_coordinates(
Note: Note:
Interpolation near boundaries differs from the scipy function, because JAX Interpolation near boundaries differs from the scipy function, because JAX
fixed an outstanding bug; see https://github.com/google/jax/issues/11097. fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
This function interprets the ``mode`` argument as documented by SciPy, but This function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy. not as implemented by SciPy.
""" """

@ -44,7 +44,7 @@ def shard_alike(x, y):
raise ValueError( raise ValueError(
'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:' 'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:'
f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at' f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at'
' https://github.com/google/jax/issues if you want this feature.') ' https://github.com/jax-ml/jax/issues if you want this feature.')
outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)] outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)]
x_out_flat, y_out_flat = zip(*outs) x_out_flat, y_out_flat = zip(*outs)

@ -1208,7 +1208,7 @@ class JaxTestCase(parameterized.TestCase):
y = np.asarray(y) y = np.asarray(y)
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
# See https://github.com/google/jax/issues/17867 # See https://github.com/jax-ml/jax/issues/17867
raise TypeError( raise TypeError(
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
"If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. " "If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "

@ -21,7 +21,7 @@ exported at `jax.typing`. Until then, the contents here should be considered uns
and may change without notice. and may change without notice.
To see the proposal that led to the development of these tools, see To see the proposal that led to the development of these tools, see
https://github.com/google/jax/pull/11859/. https://github.com/jax-ml/jax/pull/11859/.
""" """
from __future__ import annotations from __future__ import annotations

@ -1232,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs):
if library_path is None: if library_path is None:
raise RuntimeError( raise RuntimeError(
"JAX TPU support not installed; cannot generate TPU topology. See" "JAX TPU support not installed; cannot generate TPU topology. See"
" https://github.com/google/jax#installation") " https://github.com/jax-ml/jax#installation")
c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
xla_client.profiler.register_plugin_profiler(c_api) xla_client.profiler.register_plugin_profiler(c_api)
assert xla_client.pjrt_plugin_loaded("tpu") assert xla_client.pjrt_plugin_loaded("tpu")

Some files were not shown because too many files have changed in this diff.