Update references to the GitHub URL in the JAX codebase to reflect the move from google/jax to jax-ml/jax

PiperOrigin-RevId: 676843138
Michael Hudgins 2024-09-20 07:51:48 -07:00 committed by jax authors
parent afaa3bf43c
commit d4d1518c3d
257 changed files with 906 additions and 906 deletions

View File

@ -20,11 +20,11 @@ body:
* If you prefer a non-templated issue report, click [here][Raw report].
[Discussions]: https://github.com/google/jax/discussions
[Discussions]: https://github.com/jax-ml/jax/discussions
[issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues
[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
[Raw report]: http://github.com/google/jax/issues/new
[Raw report]: http://github.com/jax-ml/jax/issues/new
- type: textarea
attributes:
label: Description

View File

@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Have questions or need support?
url: https://github.com/google/jax/discussions
url: https://github.com/jax-ml/jax/discussions
about: Please ask questions on the Discussions tab

View File

@ -84,7 +84,7 @@ jobs:
failure()
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'google/jax'
&& github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: output-${{ matrix.python-version }}-log.jsonl

View File

@ -279,7 +279,7 @@ See the 0.4.33 release notes for more details.
which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
* Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
* Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
@ -401,7 +401,7 @@ See the 0.4.33 release notes for more details.
branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
changed](https://github.com/google/jax/issues/19085) so that
changed](https://github.com/jax-ml/jax/issues/19085) so that
mapping over keys results in random generation only from the first
key in the batch.
* Docs now use `jax.random.key` for construction of PRNG key arrays
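(Editorial note, not part of this commit: the `jax.random.key` entry above refers to JAX's new-style typed PRNG keys. A minimal sketch of the documented pattern, for context only:)

```python
import jax

# New-style typed key construction, as the docs entry above mentions;
# it replaces jax.random.PRNGKey(0) in most documentation examples.
key = jax.random.key(0)
k1, k2 = jax.random.split(key)
x = jax.random.normal(k1, (3,))
```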
@ -433,7 +433,7 @@ See the 0.4.33 release notes for more details.
* JAX export does not support older serialization versions anymore. Version 9
has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
This change could break clients that set a specific
JAX serialization version lower than 9.
@ -506,7 +506,7 @@ See the 0.4.33 release notes for more details.
* added the ability to specify symbolic constraints on the dimension variables.
This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* with the addition of symbolic constraints ({jax-issue}`#19235`) we now
consider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
@ -516,7 +516,7 @@ See the 0.4.33 release notes for more details.
The scope of a symbolic expression `e` can be read with `e.scope` and passed
into the above functions to direct them to construct symbolic expressions in
a given scope.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* simplified and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
@ -535,7 +535,7 @@ See the 0.4.33 release notes for more details.
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
* JAX default native serialization version is now 9. This is relevant
for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
@ -781,19 +781,19 @@ See the 0.4.33 release notes for more details.
* When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See
[here](https://github.com/google/jax/pull/16949) for an example.
[here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
[safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user created devices. `jax.devices()` are
already hashable.
* Breaking changes:
* jax2tf now uses native serialization by default. See
the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
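(Editorial note, not part of this commit: the `JAX_TRACEBACK_FILTERING` entry in the hunk above selects traceback modes via an environment variable. A hedged sketch; the error used to trigger a traceback is arbitrary:)

```python
import os
# Modes mentioned in the entry above: "remove_frames" or "off";
# the default filters JAX-internal frames. Set before importing jax.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.reshape(x, (3, 3))  # wrong size on purpose, to raise an error

try:
    f(jnp.arange(4.0))
except Exception as e:
    # How many internal frames appear in the traceback follows the
    # JAX_TRACEBACK_FILTERING setting above.
    print(type(e).__name__)
```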
@ -922,7 +922,7 @@ See the 0.4.33 release notes for more details.
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
[#16413](https://github.com/google/jax/issues/16413).
[#16413](https://github.com/jax-ml/jax/issues/16413).
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
@ -933,7 +933,7 @@ See the 0.4.33 release notes for more details.
serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
## jaxlib 0.4.14 (July 27, 2023)
@ -1095,14 +1095,14 @@ See the 0.4.33 release notes for more details.
{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs
tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
tracker](https://github.com/google/jax/issues).
tracker](https://github.com/jax-ml/jax/issues).
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
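(Editorial note, not part of this commit: for context on the `jax.debug` APIs mentioned in the hunk above, an illustrative sketch:)

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # jax.debug.print works inside jit/vmap and covers many of the
    # host_callback.id_print use cases referenced in the hunk above.
    jax.debug.print("intermediate value: {}", x)
    return x * 2

f(jnp.arange(3))
```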
@ -1126,7 +1126,7 @@ See the 0.4.33 release notes for more details.
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
* JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
@ -1403,7 +1403,7 @@ Changes:
## jaxlib 0.3.22 (Oct 11, 2022)
## jax 0.3.21 (Sep 30, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21).
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
@ -1417,18 +1417,18 @@ Changes:
* Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`).
## jaxlib 0.3.20 (Sep 28, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
* Bug fixes
* Fixes support for limiting the visible CUDA devices via
`jax_cuda_visible_devices` in distributed jobs. This functionality is needed for
the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
## jax 0.3.19 (Sep 27, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19).
* Fixes required jaxlib version.
## jax 0.3.18 (Sep 26, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18).
* Changes
* Ahead-of-time lowering and compilation functionality (tracked in
{jax-issue}`#7733`) is stable and public. See [the
@ -1446,7 +1446,7 @@ Changes:
would have been provided.
## jax 0.3.17 (Aug 31, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17).
* Bugs
* Fix corner case issue in gradient of `lax.pow` with an exponent of zero
({jax-issue}`12041`)
@ -1462,7 +1462,7 @@ Changes:
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@ -1486,7 +1486,7 @@ Changes:
deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
## jax 0.3.15 (July 22, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15).
* Changes
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
@ -1507,10 +1507,10 @@ Changes:
following a similar deprecation in {func}`scipy.linalg.solve`.
## jaxlib 0.3.15 (July 22, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
## jax 0.3.14 (June 27, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
`max_cache_size_bytes` anymore and will not get that as an input.
@ -1563,22 +1563,22 @@ Changes:
coefficients have leading zeros ({jax-issue}`#11215`).
## jaxlib 0.3.14 (June 27, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14
was released in 2018, so this should not be a very onerous requirement.
* The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
* The Python flatbuffers package is no longer a dependency of jaxlib.
## jax 0.3.13 (May 16, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13).
## jax 0.3.12 (May 15, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12).
* Changes
* Fixes [#10717](https://github.com/google/jax/issues/10717).
* Fixes [#10717](https://github.com/jax-ml/jax/issues/10717).
## jax 0.3.11 (May 15, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11).
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
@ -1592,22 +1592,22 @@ Changes:
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jax 0.3.10 (May 3, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10).
## jaxlib 0.3.10 (May 3, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* Changes
* [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a)
fixes an issue in the MHLO canonicalizer that caused constant folding to
take a long time or crash for certain programs.
## jax 0.3.9 (May 2, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9).
* Changes
* Added support for fully asynchronous checkpointing for GlobalDeviceArray.
## jax 0.3.8 (April 29 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
@ -1666,7 +1666,7 @@ Changes:
## jax 0.3.7 (April 15, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7).
* Changes:
* Fixed a performance problem if the indices passed to
{func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
@ -1684,17 +1684,17 @@ Changes:
## jax 0.3.6 (April 12, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6).
* Changes:
* Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU
pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API.
## jax 0.3.5 (April 7, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5).
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
@ -1717,17 +1717,17 @@ Changes:
## jax 0.3.4 (March 18, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4).
## jax 0.3.3 (March 17, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3).
## jax 0.3.2 (March 16, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2).
* Changes:
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
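(Editorial note, not part of this commit: the hunk above is truncated after "Please use"; per the JAX docs, the replacement for the removed `jax.ops.index_update` / `jax.ops.index_add` is the functional `.at[]` syntax. A minimal sketch for context:)

```python
import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[2].set(10.0)  # was: jax.ops.index_update(x, 2, 10.0)
z = y.at[2].add(1.0)   # was: jax.ops.index_add(y, 2, 1.0)
```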
@ -1751,7 +1751,7 @@ Changes:
## jax 0.3.1 (Feb 18, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1).
commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1).
* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
@ -1774,7 +1774,7 @@ Changes:
## jax 0.3.0 (Feb 10, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0).
* Changes
* jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
@ -1788,7 +1788,7 @@ Changes:
## jax 0.2.28 (Feb 1, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
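(Editorial note, not part of this commit: a hedged sketch of the lowering entries above; the default IR dialect has changed across JAX releases, so the exact output differs by version:)

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

lowered = jax.jit(f).lower(jnp.float32(0.0))
# Prints the lowered MLIR module; the default dialect was MHLO at the time
# of the 0.2.28 entry above and is StableHLO in current releases.
print(lowered.compiler_ir())
```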
@ -1813,7 +1813,7 @@ Changes:
* The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
## jax 0.2.27 (Jan 18 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27).
* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
@ -1858,7 +1858,7 @@ Changes:
## jax 0.2.26 (Dec 8, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26).
* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
@ -1875,7 +1875,7 @@ Changes:
## jax 0.2.25 (Nov 10, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25).
* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
@ -1889,7 +1889,7 @@ Changes:
## jax 0.2.24 (Oct 19, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24).
* New features:
* `jax.random.choice` and `jax.random.permutation` now support
@ -1923,7 +1923,7 @@ Changes:
## jax 0.2.22 (Oct 12, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
@ -1958,13 +1958,13 @@ Changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
* Fixes https://github.com/google/jax/issues/7461, which caused wrong
* Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
@ -1992,7 +1992,7 @@ Changes:
## jax 0.2.20 (Sept 2, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
@ -2005,7 +2005,7 @@ Changes:
## jax 0.2.19 (Aug 12, 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@ -2042,7 +2042,7 @@ Changes:
called in sequence.
## jax 0.2.18 (July 21 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
@ -2065,7 +2065,7 @@ Changes:
* Fix bugs in TFRT CPU backend that results in incorrect results.
## jax 0.2.17 (July 9 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
@ -2082,12 +2082,12 @@ Changes:
## jax 0.2.16 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16).
## jax 0.2.15 (June 23 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
* [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend
* [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`).
@ -2107,7 +2107,7 @@ Changes:
CPU.
## jax 0.2.14 (June 10 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
@ -2165,7 +2165,7 @@ Changes:
{func}`jit` transformed functions.
## jax 0.2.13 (May 3 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
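(Editorial note, not part of this commit: the `static_argnames` entry above is truncated by the hunk boundary; a minimal sketch of the option as documented:)

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("units",))
def make_layer(x, units):
    # `units` is treated as a compile-time constant, so it may determine shapes.
    return jnp.zeros((x.shape[0], units)) + x.mean()

print(make_layer(jnp.ones((4, 3)), units=8).shape)  # (4, 8)
```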
@ -2209,7 +2209,7 @@ Changes:
## jaxlib 0.1.65 (April 7 2021)
## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
@ -2222,7 +2222,7 @@ Changes:
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
@ -2236,23 +2236,23 @@ Changes:
## jax 0.2.11 (March 23 2021)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
* [#6112](https://github.com/google/jax/pull/6112) added context managers:
* [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
* [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
* [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete`
* Bug fixes:
* [#6136](https://github.com/google/jax/pull/6136) generalized
* [#6136](https://github.com/jax-ml/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
* [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
* [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`
* [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with
* [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with
incomplete beta functions
* [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during
* [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during
tracing
* [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when
* [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when
converting some large Python integers to floats
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
@ -2264,13 +2264,13 @@ Changes:
## jax 0.2.10 (March 5 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`#5627`)
and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
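(Editorial note, not part of this commit: for context on the `jax2tf.call_tf` entry above, a hedged sketch; requires TensorFlow to be installed:)

```python
import tensorflow as tf

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

# Wrap a TensorFlow function so it can be called from JAX-transformed code.
tf_sin = jax2tf.call_tf(tf.math.sin)
print(jax.jit(tf_sin)(jnp.float32(1.0)))
```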
@ -2314,7 +2314,7 @@ Changes:
## jax 0.2.9 (January 26 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
error checking and error messages.
@ -2330,7 +2330,7 @@ Changes:
## jax 0.2.8 (January 12 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`#5244`)
@ -2362,7 +2362,7 @@ Changes:
## jax 0.2.7 (Dec 4 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
@ -2382,14 +2382,14 @@ Changes:
## jax 0.2.6 (Nov 18 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
* Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
* Raise an error on non-hashable static arguments for jax.jit and
xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`#4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
@ -2441,15 +2441,15 @@ Changes:
## jax 0.2.5 (October 27 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
* Expanded the set of JAX primitives converted by jax2tf.
See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
## jax 0.2.4 (October 19 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
* Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
* Deprecations
@ -2461,17 +2461,17 @@ Changes:
## jax 0.2.3 (October 14 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is we need to temporarily roll back a
new jit fastpath while we look into a performance degradation
## jax 0.2.2 (October 13 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2).
## jax 0.2.1 (October 6 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
* As a benefit of omnistaging, the host_callback functions are executed (in program
order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
@ -2479,10 +2479,10 @@ Changes:
## jax (0.2.0) (September 23 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
* Omnistaging on by default. See {jax-issue}`#3370` and
[omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
[omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
## jax (0.1.77) (September 15 2020)
@ -2496,11 +2496,11 @@ Changes:
## jax 0.1.76 (September 8, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76).
## jax 0.1.75 (July 30, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
* make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
@ -2508,7 +2508,7 @@ Changes:
## jax 0.1.74 (July 29, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU support for half-precision arithmetic (#3878)
@ -2525,7 +2525,7 @@ Changes:
## jax 0.1.73 (July 22, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
* jax.image.resize. (#3703)
@ -2563,14 +2563,14 @@ Changes:
## jax 0.1.72 (June 28, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
* Fix an odeint bug introduced in the previous release, see
{jax-issue}`#3587`.
## jax 0.1.71 (June 25, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
* Allow `jax.experimental.ode.odeint` dynamics functions to close over
@ -2606,7 +2606,7 @@ Changes:
## jax 0.1.70 (June 8, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
* `lax.switch` introduces indexed conditionals with multiple
branches, together with a generalization of the `cond`
@ -2615,11 +2615,11 @@ Changes:
## jax 0.1.69 (June 3, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69).
## jax 0.1.68 (May 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
* {func}`lax.cond` supports a single-operand form, taken as the argument
to both branches
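(Editorial note, not part of this commit: a minimal sketch of the single-operand `lax.cond` form described above, for context only:)

```python
import jax.numpy as jnp
from jax import lax

x = jnp.float32(3.0)
# One operand is passed to whichever branch is selected.
y = lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)
```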
@ -2630,7 +2630,7 @@ Changes:
## jax 0.1.67 (May 12, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
* Support for reduction over subsets of a pmapped axis using `axis_index_groups`
{jax-issue}`#2382`.
@ -2648,7 +2648,7 @@ Changes:
## jax 0.1.66 (May 5, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
* Support for `in_axes=None` on {func}`pmap`
{jax-issue}`#2896`.
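(Editorial note, not part of this commit: an illustrative sketch of the `in_axes=None` entry above:)

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(float(n * 3)).reshape(n, 3)
scale = jnp.float32(2.0)

# in_axes=None broadcasts `scale` to every device instead of mapping over it.
out = jax.pmap(lambda x, s: x * s, in_axes=(0, None))(xs, scale)
```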
@ -2661,7 +2661,7 @@ Changes:
## jax 0.1.65 (April 30, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
* Differentiation of determinants of singular matrices
{jax-issue}`#2809`.
@ -2679,7 +2679,7 @@ Changes:
## jax 0.1.64 (April 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
* Add syntactic sugar for functional indexed updates
{jax-issue}`#2684`.
@ -2706,7 +2706,7 @@ Changes:
## jax 0.1.63 (April 12, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
@ -2727,7 +2727,7 @@ Changes:
## jax 0.1.62 (March 21, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
convention `0. * nan == 0.`. This change means some programs when
@ -2745,13 +2745,13 @@ Changes:
## jax 0.1.61 (March 17, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
## jax 0.1.60 (March 17, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
* {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows
the user to specify arguments that should be treated as compile-time
@ -2777,7 +2777,7 @@ Changes:
## jax 0.1.59 (February 11, 2020)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59).
* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes
* The minimum jaxlib version is now 0.1.38.
@ -2809,7 +2809,7 @@ Changes:
## jax 0.1.58 (January 28, 2020)
* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58).
* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes
* JAX has dropped Python 2 support, because Python 2 reached its end of life on

View File

@ -1,7 +1,7 @@
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

View File

@ -1,10 +1,10 @@
<div align="center">
<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
<img src="https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png" alt="logo"></img>
</div>
# Transformable numerical computing at scale
![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/jax)
[**Quickstart**](#quickstart-colab-in-the-cloud)
@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come.
This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/google/jax/issues), and letting us know what you
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
think!
```python
@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).
For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).
## Transformations
@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns.
It all composes, so you're free to differentiate through parallel computations:
@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.
See the [SPMD
Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch
example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
for more.
## Current gotchas
@ -349,7 +349,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.
@ -437,7 +437,7 @@ To cite this repository:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

View File

@ -451,7 +451,7 @@
"id": "jC-KIMQ1q-lK"
},
"source": [
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
]
},
{

View File

@ -837,7 +837,7 @@
"id": "f-FBsWeo1AXE"
},
"source": [
"<img src=\"https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>"
"<img src=\"https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>"
]
},
{
@ -847,7 +847,7 @@
"id": "jC-KIMQ1q-lK"
},
"source": [
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
"For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
]
},
{

View File

@ -15,13 +15,13 @@
"id": "sk-3cPGIBTq8"
},
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
"\n",
"This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
"\n",
"**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n",
"\n",
"The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
"The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
]
},
{

View File

@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab:
### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
A guide to getting started with `pmap`, a transform for easily distributing SPMD
computations across devices.
### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb)
### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb)
Contributed by Alex Alemi (alexalemi@)
Solve and plot parallel ODE solutions with `pmap`.
<img src="https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/lorentz.png" width=65%></image>
<img src="https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/lorentz.png" width=65%></image>
### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb)
### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb)
Contributed by Stephan Hoyer (shoyer@)
Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU.
![](https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/wave_movie.gif)
![](https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/wave_movie.gif)
### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb)
### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb)
An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.
## Performance notes
@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud
JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.
\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you!
\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you!
## Running JAX on a Cloud TPU VM
@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU
VM), please email <cloud-tpu-support@google.com>, or <trc-support@google.com> if
you are a [TRC](https://sites.research.google/trc/) member. You can also [file a
JAX issue](https://github.com/google/jax/issues) or [ask a discussion
question](https://github.com/google/jax/discussions) for any issues with these
JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion
question](https://github.com/jax-ml/jax/discussions) for any issues with these
notebooks or using JAX in general.
If you have any other questions or comments regarding JAX on Cloud TPUs, please

View File

@ -571,7 +571,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products
Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
```{code-cell}
# Isolate the function from the weight matrix to the predictions

View File

@ -27,7 +27,7 @@
"metadata": {},
"source": [
"[![Open in\n",
"Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)"
"Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)"
]
},
{
@ -1781,7 +1781,7 @@
"metadata": {},
"source": [
"This is precisely the issue that\n",
"[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n",
"[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n",
"We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n",
"applied, regardless of whether any inputs to `bind` are boxed in corresponding\n",
"`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n",

View File

@ -33,7 +33,7 @@ limitations under the License.
```
[![Open in
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
+++
@ -1399,7 +1399,7 @@ print(jaxpr)
```
This is precisely the issue that
[omnistaging](https://github.com/google/jax/pull/3370) fixed.
[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
applied, regardless of whether any inputs to `bind` are boxed in corresponding
`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`

View File

@ -27,7 +27,7 @@
# ---
# [![Open in
# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)
# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
# # Autodidax: JAX core from scratch
#
@ -1396,7 +1396,7 @@ print(jaxpr)
# This is precisely the issue that
# [omnistaging](https://github.com/google/jax/pull/3370) fixed.
# [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
# We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
# applied, regardless of whether any inputs to `bind` are boxed in corresponding
# `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`

View File

@ -52,4 +52,4 @@ questions answered are:
.. _Flax: https://flax.readthedocs.io/
.. _Haiku: https://dm-haiku.readthedocs.io/
.. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax
.. _JAX GitHub discussions: https://github.com/google/jax/discussions
.. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions

View File

@ -168,7 +168,7 @@ html_theme = 'sphinx_book_theme'
# documentation.
html_theme_options = {
'show_toc_level': 2,
'repository_url': 'https://github.com/google/jax',
'repository_url': 'https://github.com/jax-ml/jax',
'use_repository_button': True, # add a "link to repository" button
'navigation_with_keys': False,
}
@ -345,7 +345,7 @@ def linkcode_resolve(domain, info):
return None
filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__))
lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else ""
return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}"
return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}"
# Generate redirects from deleted files to new sources
rediraffe_redirects = {

View File

@ -5,22 +5,22 @@
Everyone can contribute to JAX, and we value everyone's contributions. There are several
ways to contribute, including:
- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions)
- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions)
- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/)
- Contributing to JAX's [code-base](http://github.com/google/jax/)
- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries)
- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/)
- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries)
The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
## Ways to contribute
We welcome pull requests, in particular for those issues marked with
[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or
[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
[contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or
[good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
For other proposals, we ask that you first open a GitHub
[Issue](https://github.com/google/jax/issues/new/choose) or
[Discussion](https://github.com/google/jax/discussions)
[Issue](https://github.com/jax-ml/jax/issues/new/choose) or
[Discussion](https://github.com/jax-ml/jax/discussions)
to seek feedback on your planned contribution.
## Contributing code using pull requests
@ -33,7 +33,7 @@ Follow these steps to contribute code:
For more information, see the Pull Request Checklist below.
2. Fork the JAX repository by clicking the **Fork** button on the
[repository page](http://www.github.com/google/jax). This creates
[repository page](http://www.github.com/jax-ml/jax). This creates
a copy of the JAX repository in your own account.
3. Install Python >= 3.10 locally in order to run tests.
@ -52,7 +52,7 @@ Follow these steps to contribute code:
changes.
```bash
git remote add upstream https://www.github.com/google/jax
git remote add upstream https://www.github.com/jax-ml/jax
```
6. Create a branch where you will develop from:

View File

@ -6,7 +6,7 @@
First, obtain the JAX source code:
```
git clone https://github.com/google/jax
git clone https://github.com/jax-ml/jax
cd jax
```
@ -26,7 +26,7 @@ If you're only modifying Python portions of JAX, we recommend installing
pip install jaxlib
```
See the [JAX readme](https://github.com/google/jax#installation) for full
See the [JAX readme](https://github.com/jax-ml/jax#installation) for full
guidance on pip installation (e.g., for GPU and TPU support).
### Building `jaxlib` from source
@ -621,7 +621,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
## Type checking
@ -712,7 +712,7 @@ jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```
The jupytext version should match that specified in
[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml).
[.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml).
To check that the markdown and ipynb files are properly synced, you may use the
[pre-commit](https://pre-commit.com/) framework to perform the same check used
@ -740,12 +740,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else
re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py).
See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py).
### Documentation building on `readthedocs.io`
@ -772,7 +772,7 @@ I saw in the Readthedocs logs:
mkvirtualenv jax-docs # A new virtualenv
mkdir jax-docs # A new directory
cd jax-docs
git clone --no-single-branch --depth 50 https://github.com/google/jax
git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f

View File

@ -153,7 +153,7 @@ JAX runtime system that are:
an inference system that is already deployed when the exporting is done.
(The particular compatibility window lengths are the same as those that JAX
[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
[promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow).
The terminology “backward compatibility” is from the perspective of the consumer,
e.g., the inference system.)
@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers:
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522),
available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
and the default since August 12th, 2023 (JAX 0.4.15).
@ -721,7 +721,7 @@ that live in jaxlib):
2. Day “D”, we add the new custom call target `T_NEW`.
We should create a new custom call target, and clean up the old
target roughly after 6 months, rather than updating `T` in place:
* See the example [PR #20997](https://github.com/google/jax/pull/20997)
* See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997)
implementing the steps below.
* We add the custom call target `T_NEW`.
* We change the JAX lowering rules that were previous using `T`,

View File

@ -2,4 +2,4 @@
## Interoperation with TensorFlow
See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).

View File

@ -372,7 +372,7 @@ device.
Jitted functions behave like any other primitive operations—they will follow the
data and will show errors if invoked on data committed on more than one device.
(Before `PR #6002 <https://github.com/google/jax/pull/6002>`_ in March 2021
(Before `PR #6002 <https://github.com/jax-ml/jax/pull/6002>`_ in March 2021
there was some laziness in creation of array constants, so that
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually
create the array of zeros on ``jax.devices()[1]``, instead of creating the
@ -385,7 +385,7 @@ and its use is not recommended.)
For a worked-out example, we recommend reading through
``test_computation_follows_data`` in
`multi_device_test.py <https://github.com/google/jax/blob/main/tests/multi_device_test.py>`_.
`multi_device_test.py <https://github.com/jax-ml/jax/blob/main/tests/multi_device_test.py>`_.
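A quick illustrative sketch of the committed-data behavior described above (assuming at least one local device; with a single device the placement is of course trivial):

```python
import jax
import jax.numpy as jnp

# Committing an array to a device makes computations that consume it run there.
dev = jax.devices()[0]
x = jax.device_put(jnp.arange(4.0), dev)   # x is committed to `dev`
y = jax.jit(lambda a: a * 2.0)(x)          # the jitted computation follows x
print(y.devices())                         # {dev}
```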
.. _faq-benchmark:
@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.::
Additional reading:
* `Issue: gradients through jnp.where when one of branches is nan <https://github.com/google/jax/issues/1052#issuecomment-514083352>`_.
* `Issue: gradients through jnp.where when one of branches is nan <https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352>`_.
* `How to avoid NaN gradients when using where <https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf>`_.
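A minimal sketch of the double-``jnp.where`` pattern discussed above (the function is illustrative):

```python
import jax
import jax.numpy as jnp

def naive(x):
    # The untaken branch still participates in the backward pass:
    # d/dx sqrt(x) at x = 0 is inf, and 0 * inf = nan.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)

def safe(x):
    safe_x = jnp.where(x > 0.0, x, 1.0)               # inner where: sanitize the input
    return jnp.where(x > 0.0, jnp.sqrt(safe_x), 0.0)  # outer where: select the result

print(jax.grad(naive)(0.0))  # nan
print(jax.grad(safe)(0.0))   # 0.0
```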

View File

@ -406,7 +406,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)."
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
]
},
{
@ -492,7 +492,7 @@
"source": [
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n",
"\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",

View File

@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5):
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
```
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues).
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
+++
@ -406,7 +406,7 @@ np.testing.assert_allclose(
At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.
JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing your use case if you hit this limitation in practice.
JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing your use case if you hit this limitation in practice.
One other JAX feature that this example doesn't support is higher-order AD.
It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.

View File

@ -176,7 +176,7 @@ installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation.
Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues)
if you run into any errors or problems with the pre-built wheels.
(docker-containers-nvidia-gpu)=
@ -216,7 +216,7 @@ refer to
**Note:** There are several caveats with the Metal plugin:
* The Metal plugin is new and experimental and has a number of
[known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
[known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker.
* The Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API

View File

@ -9,7 +9,7 @@ Let's first make a JAX issue.
But if you can pinpoint the commit that triggered the regression, it will really help us.
This document explains how we identified the commit that caused a
[15% performance regression](https://github.com/google/jax/issues/17686).
[15% performance regression](https://github.com/jax-ml/jax/issues/17686).
## Steps
@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox).
- test_runner.sh: will start the containers and the test.
- test.sh: will install missing dependencies and run the test
Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686
Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686
- test_runner.sh:
```
for m in 7 8 9; do

View File

@ -14,7 +14,7 @@
## What's going on?
As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
## How can I disable the change, and go back to the old behavior for now?
@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu
As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
tracker](https://github.com/google/jax/issues) so we can help fix it!
tracker](https://github.com/jax-ml/jax/issues) so we can help fix it!
## Why are we doing this?
@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi
### Significantly less Python overhead in some cases
The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
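For readers unfamiliar with the API, here is a minimal usage sketch (the function is illustrative; `jax.checkpoint` changes what is saved versus recomputed on the backward pass, not the values computed):

```python
import jax
import jax.numpy as jnp

def g(x):
    a = jnp.sin(x)     # an intermediate that remat may recompute
    return jnp.sin(a)  # instead of storing it for the backward pass

print(jax.grad(g)(2.0))                  # cos(sin(2)) * cos(2)
print(jax.grad(jax.checkpoint(g))(2.0))  # same value, different memory/compute trade-off
```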
### Enabling new JAX features by simplifying internals

View File

@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to
This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.
Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally.
In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
This document details JAX's goals and recommendations for type annotations within the package.
## Why type annotations?
@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their
### Level 1: Annotations as documentation
When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
```python
Array = Any
@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co
This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.
JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
```python
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
```
In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
![VSCode Intellisense Screenshot](../_static/vscode-completion.png)
@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer
```
Again, there are a couple mechanisms that could be used for this (the first is sketched just after this list):
- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`
- restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`
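A minimal, hypothetical sketch of the first mechanism (the duck-typed test here merely stands in for whatever runtime check a real implementation would use):

```python
import jax
import jax.numpy as jnp

class _ArrayInstanceMeta(type):
    def __instancecheck__(cls, obj) -> bool:
        # Stand-in test: both concrete arrays and tracers expose shape and dtype.
        return hasattr(obj, "shape") and hasattr(obj, "dtype")

class ArrayInstance(metaclass=_ArrayInstanceMeta):
    pass

print(isinstance(jnp.ones(3), ArrayInstance))                        # True
print(jax.jit(lambda x: isinstance(x, ArrayInstance))(jnp.ones(3)))  # True: x is a tracer
```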

View File

@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
(see issue [#9263](https://github.com/google/jax/issues/9263)),
(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and
`random_unwrap`.

View File

@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali
and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`:
as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has
[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637),
and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further
[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637),
and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further
complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we may have chosen not to accept this
contribution to the package.

View File

@ -35,9 +35,9 @@ behavior of their code. This customization
Python control flow and workflows for NaN debugging.
As **JAX developers** we want to write library functions, like
[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
and
[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
that are defined in terms of other primitives, but for the purposes of
differentiation have primitive-like behavior in the sense that we want to define
custom differentiation rules for them, which may be more numerically stable or
@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like
want to be confident we're not going to preclude good solutions to that problem.
That is, our primary goals are
1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and
1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and
2. allow Python in custom VJPs, e.g. to debug NaNs
([#1275](https://github.com/google/jax/issues/1275)).
([#1275](https://github.com/jax-ml/jax/issues/1275)).
Secondary goals are
3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
@ -60,18 +60,18 @@ Secondary goals are
`odeint`, `root`, etc.
Overall, we want to close
[#116](https://github.com/google/jax/issues/116),
[#1097](https://github.com/google/jax/issues/1097),
[#1249](https://github.com/google/jax/issues/1249),
[#1275](https://github.com/google/jax/issues/1275),
[#1366](https://github.com/google/jax/issues/1366),
[#1723](https://github.com/google/jax/issues/1723),
[#1670](https://github.com/google/jax/issues/1670),
[#1875](https://github.com/google/jax/issues/1875),
[#1938](https://github.com/google/jax/issues/1938),
[#116](https://github.com/jax-ml/jax/issues/116),
[#1097](https://github.com/jax-ml/jax/issues/1097),
[#1249](https://github.com/jax-ml/jax/issues/1249),
[#1275](https://github.com/jax-ml/jax/issues/1275),
[#1366](https://github.com/jax-ml/jax/issues/1366),
[#1723](https://github.com/jax-ml/jax/issues/1723),
[#1670](https://github.com/jax-ml/jax/issues/1670),
[#1875](https://github.com/jax-ml/jax/issues/1875),
[#1938](https://github.com/jax-ml/jax/issues/1938),
and replace the custom_transforms machinery (from
[#636](https://github.com/google/jax/issues/636),
[#818](https://github.com/google/jax/issues/818),
[#636](https://github.com/jax-ml/jax/issues/636),
[#818](https://github.com/jax-ml/jax/issues/818),
and others).
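For concreteness, here is the kind of rule these goals concern, sketched with the `jax.custom_jvp` API this design led to (the function is the standard `log1pexp` example; the custom rule swaps in a numerically stable derivative):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # Stable form of the derivative: sigmoid(x) rather than exp(x) / (1 + exp(x)).
    return ans, x_dot / (1.0 + jnp.exp(-x))

print(jax.grad(log1pexp)(100.0))  # 1.0, instead of nan from overflow
```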
## Non-goals
@ -400,7 +400,7 @@ There are some other bells and whistles to the API:
resolved to positions using the `inspect` module. This is a bit of an experiment
with Python 3's improved ability to programmatically inspect argument
signatures. I believe it is sound but not complete, which is a fine place to be.
(See also [#2069](https://github.com/google/jax/issues/2069).)
(See also [#2069](https://github.com/jax-ml/jax/issues/2069).)
* Arguments can be marked non-differentiable using `nondiff_argnums`, and as with
`jit`'s `static_argnums`, these arguments don't have to be JAX types. We need to
set a convention for how these arguments are passed to the rules. For a primal
@ -433,5 +433,5 @@ There are some other bells and whistles to the API:
`custom_lin` to the tangent values; `custom_lin` carries with it the user's
custom backward-pass function, and as a primitive it only has a transpose
rule.
* This mechanism is described more in [#636](https://github.com/google/jax/issues/636).
* This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636).
* To prevent

View File

@ -9,7 +9,7 @@ notebook.
## What to update
After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments
After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s), which basically means to allow for
arbitrarily-transformable code `nondiff_argnums` shouldn't be used for
@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake!
**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure
**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For
@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about
[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after
[#4039](https://github.com/google/jax/pull/4039) those will just work!)
[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!)
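A sketch of that recommendation, in the style of the gradient-clipping example from the custom-derivatives tutorial: the bounds are passed as ordinary arguments and saved as residuals rather than declared via `nondiff_argnums`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save the bounds as residuals

def clip_gradient_bwd(res, g):
    lo, hi = res
    # None marks a zero cotangent for the bound arguments.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

print(jax.grad(clip_gradient, argnums=2)(-0.5, 0.5, 2.0))  # 0.5
```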
Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to

View File

@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc.
### What's going on?
A change to JAX's tracing infrastructure called “omnistaging”
([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in
([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in
jax==0.2.0. This change improves memory performance, trace execution time, and
simplifies jax internals, but may cause some existing code to break. Breakage is
usually a result of buggy code, so long-term it's best to fix the bugs, but
@ -191,7 +191,7 @@ and potentially even fragmenting memory.
(The `broadcast` that corresponds to the construction of the zeros array for
`jnp.zeros_like(x)` is staged out because JAX is lazy about very simple
expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After
expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After
omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of `mask` is not staged out is that, before omnistaging,

View File

@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same
extended dtype mechanism elsewhere internally. For example, the
`jax._src.core.bint` object, a bounded integer type used for experimental work
on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies
the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
### PRNG dtypes
PRNG dtypes are defined as a particular case of extended dtypes. Specifically,

View File

@ -8,7 +8,7 @@
"source": [
"# Design of Type Promotion Semantics for JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n",
"*Jake VanderPlas, December 2021*\n",
"\n",

View File

@ -16,7 +16,7 @@ kernelspec:
# Design of Type Promotion Semantics for JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
*Jake VanderPlas, December 2021*

View File

@ -58,11 +58,11 @@ These constraints imply the following rules for releases:
* If a new `jaxlib` is released, a `jax` release must be made at the same time.
These
[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py)
are currently checked by `jax` at import time, instead of being expressed as
Python package version constraints. `jax` checks the `jaxlib` version at
runtime rather than using a `pip` package version constraint because we
[provide separate `jaxlib` wheels](https://github.com/google/jax#installation)
[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation)
for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want `pip`
to install a `jaxlib` package for us automatically.
@ -119,7 +119,7 @@ no released `jax` version uses that API.
## How is the source to `jaxlib` laid out?
`jaxlib` is split across two main repositories, namely the
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib)
and in the
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the
@ -146,7 +146,7 @@ level.
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of
`jaxlib` from the XLA repository are incorporated into the build
[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overridden on a build-by-build basis.

View File

@ -32,7 +32,7 @@ should be linked to this issue.
Then create a pull request that adds a file named
`%d-{short-title}.md` - with the number being the issue number.
.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP
.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP
.. toctree::
:maxdepth: 1

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-06-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},
{
@ -661,7 +661,7 @@
"source": [
"Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n",
"\n",
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
"Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
{
@ -1003,7 +1003,7 @@
"id": "COjzGBpO4tzL"
},
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by a special array element that we call a __key__:"
]
@ -1349,7 +1349,7 @@
"\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-06-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+++ {"id": "4k5PVzEo2uJO"}
@ -312,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
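A short illustration of the asymmetry described above (values are illustrative):

```python
import jax.numpy as jnp

x = jnp.arange(5.0)
print(x[10])                # out-of-bounds read: the index is clamped, returns x[4]
print(x.at[10].set(99.0))   # out-of-bounds update: silently dropped, x is unchanged
print(x.at[10].get(mode='fill', fill_value=jnp.nan))  # or opt into a sentinel value
```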
+++ {"id": "LwB07Kx5sgHu"}
@ -460,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
+++ {"id": "COjzGBpO4tzL"}
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a __key__:
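A minimal sketch of working with keys, using the `jax.random.key` constructor:

```python
from jax import random

key = random.key(42)                # the "special array element" mentioned above
key, subkey = random.split(key)     # explicitly fork the PRNG state
print(random.normal(subkey, (3,)))  # consume the fork; keep `key` for later use
```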
@ -623,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
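A small sketch of this behavior; the `print` is a Python side effect that only runs while tracing:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing with:", x)  # x is a tracer at the ShapedArray level
    return x * 2.0

f(jnp.array([1., 2., 3.]))  # prints once: traced and compiled for (3,) float32
f(jnp.array([4., 5., 6.]))  # same shape/dtype: reuses the compiled code, no print
```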

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"There are two ways to define differentiation rules in JAX:\n",
"\n",

View File

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
There are two ways to define differentiation rules in JAX:

View File

@ -17,7 +17,7 @@
"id": "pFtQjv4SzHRj"
},
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"\n",
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
]

View File

@ -19,7 +19,7 @@ kernelspec:
+++ {"id": "pFtQjv4SzHRj"}
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n",
"*necula@google.com*, October 2019.\n",
"\n",

View File

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
*necula@google.com*, October 2019.

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n",
"**Copyright 2018 The JAX Authors.**\n",
"\n",
@ -32,9 +32,9 @@
"id": "B_XlLLpcWjkA"
},
"source": [
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
"![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n",
"Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
**Copyright 2018 The JAX Authors.**
@ -35,9 +35,9 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"}
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
]
},
{
@ -79,7 +79,7 @@
"id": "gA8V51wZdsjh"
},
"source": [
"When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README."
"When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README."
]
},
{
@ -320,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
@ -333,7 +333,7 @@
"\n",
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
"\n",
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
]
},
{
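A minimal sketch of the registry idea described above, with a couple of hand-picked primitive/inverse pairs; the notebook's interpreter then walks a jaxpr and looks each primitive up here:

```python
import jax.numpy as jnp
from jax import lax

# Map a few invertible primitives to their inverse functions.
inverse_registry = {
    lax.exp_p: jnp.log,
    lax.tanh_p: jnp.arctanh,
}

# A custom interpreter would consult the registry per equation, e.g.:
assert jnp.allclose(inverse_registry[lax.exp_p](jnp.exp(2.0)), 2.0)
```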

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+++ {"id": "r-3vMiKRYXPJ"}
@ -57,7 +57,7 @@ fast_f = jit(f)
+++ {"id": "gA8V51wZdsjh"}
When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README.
When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README.
+++ {"id": "2Th1vYLVaFBz"}
@ -223,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+++ {"id": "0vb2ZoGrCMM4"}
@ -231,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit
An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234).
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234).
```{code-cell} ipython3
:id: gSMIT2z1vUpO

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics."
]
@ -255,7 +255,7 @@
"id": "cJ2NxiN58bfI"
},
"source": [
"You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
"You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
]
},
{
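A short sketch of the container registration mentioned above, using a hypothetical two-field class:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),          # flatten: children and aux data
    lambda _, children: Point(*children),  # unflatten
)

# Once registered, Point works with grad, jit, vmap, etc.
grads = jax.grad(lambda p: jnp.sum(p.x ** 2 + p.y))(Point(jnp.ones(3), jnp.ones(3)))
```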
@ -1015,7 +1015,7 @@
"source": [
"### Jacobian-Matrix and Matrix-Jacobian products\n",
"\n",
"Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
"Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
]
},
{
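A compact sketch of the matrix-Jacobian product trick described above, pulling back all rows of a matrix at once by vmapping a single-vector `vjp` (toy function and shapes):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(4.0)
_, vjp_fun = jax.vjp(f, x)

M = jnp.eye(4)                 # rows are the covectors to pull back
(mjp,) = jax.vmap(vjp_fun)(M)  # each row is v @ J for the matching row v of M
```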

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
@ -151,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b}))
+++ {"id": "cJ2NxiN58bfI"}
You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
+++ {"id": "PaCHzAtGruBz"}
@ -592,7 +592,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products
Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
```{code-cell} ipython3
:id: asAWvxVaCmsx

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n",
"\n",

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)
JAX provides a number of interfaces to compute convolutions across data, including:
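For instance, the simplest of those interfaces is the NumPy-style one; a small smoothing example (the more general `jax.scipy.signal` and `lax.conv_general_dilated` paths follow the same idea):

```python
import jax.numpy as jnp

x = jnp.linspace(0, 10, 500)
signal = jnp.sin(x)
window = jnp.ones(10) / 10   # simple moving-average kernel

smoothed = jnp.convolve(signal, window, mode='same')
```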

View File

@ -40,11 +40,11 @@
"\n",
"<!--* freshness: { reviewed: '2024-05-03' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n",
"\n",
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
"![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n",
"Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n",

View File

@ -38,11 +38,11 @@ limitations under the License.
<!--* freshness: { reviewed: '2024-05-03' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
_Forked from_ `neural_network_and_data_loading.ipynb`
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."
]

View File

@ -17,7 +17,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.

View File

@ -10,7 +10,7 @@
"\n",
"<!--* freshness: { reviewed: '2024-04-08' } *-->\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",
"\n",

View File

@ -18,7 +18,7 @@ kernelspec:
<!--* freshness: { reviewed: '2024-04-08' } *-->
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
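The core move of that notebook in miniature, with a hypothetical logistic-regression log joint; `vmap` batches it over many parameter vectors without rewriting the single-sample code:

```python
import jax
import jax.numpy as jnp

def log_joint(beta, x, y):
    logits = x @ beta
    return jnp.sum(y * logits - jnp.logaddexp(0.0, logits))

x = jnp.ones((20, 3))      # toy design matrix
y = jnp.ones(20)           # toy binary outcomes
betas = jnp.zeros((8, 3))  # a batch of candidate parameter vectors

batched_log_joint = jax.vmap(log_joint, in_axes=(0, None, None))
values = batched_log_joint(betas, x, y)
```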

View File

@ -21,7 +21,7 @@ software emulation, and can slow down the computation.
If you see unexpected outputs, please compare them against a kernel run with
``interpret=True`` passed in to ``pallas_call``. If the results diverge,
please file a `bug report <https://github.com/google/jax/issues/new/choose>`_.
please file a `bug report <https://github.com/jax-ml/jax/issues/new/choose>`_.
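A minimal sketch of the comparison the note suggests, with a trivial kernel run through the interpreter (kernel, names, and shapes here are illustrative only):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,   # reference interpreter; compare against the compiled run
)(x)
```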
What is a TPU?
--------------

View File

@ -29,7 +29,7 @@ f(x)
### Setting cache directory
The compilation cache is enabled when the
[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python
from jax.experimental.compilation_cache import compilation_cache as cc

View File

@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None,
:jax-issue:`1234`
This will output a hyperlink of the form
`#1234 <http://github.com/google/jax/issues/1234>`_. These links work even
`#1234 <http://github.com/jax-ml/jax/issues/1234>`_. These links work even
for PR numbers.
"""
text = text.lstrip('#')
if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
options = {} if options is None else options
url = f"https://github.com/google/jax/issues/{text}"
url = f"https://github.com/jax-ml/jax/issues/{text}"
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], []

View File

@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b
2) Are we supposed to pipe all these things around manually?
The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples.
The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples.

View File

@ -29,7 +29,7 @@ except Exception as exc:
# Defensively swallow any exceptions to avoid making jax unimportable
from warnings import warn as _warn
_warn(f"cloud_tpu_init failed: {exc!r}\n This is a JAX bug; please report "
f"an issue at https://github.com/google/jax/issues")
f"an issue at https://github.com/jax-ml/jax/issues")
del _warn
del _cloud_tpu_init
@ -38,7 +38,7 @@ import jax.core as _core
del _core
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.basearray import Array as Array
from jax import tree as tree

View File

@ -546,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
# To avoid precision mismatches in fwd and bwd passes due to XLA excess
# precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls
# on producers of any residuals. See https://github.com/google/jax/pull/22244.
# on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244.
jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
# compute known outputs and residuals (hoisted out of remat primitive)

View File

@ -956,7 +956,7 @@ def vmap(fun: F,
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
# rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_axes = tuple(in_axes)
if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
@ -2505,7 +2505,7 @@ class ShapeDtypeStruct:
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
# https://github.com/jax-ml/jax/issues/8182
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
def _sds_aval_mapping(x):

View File

@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
device_assignment = axis_context.device_assignment
if device_assignment is None:
raise AssertionError(
"Please file a bug at https://github.com/google/jax/issues")
"Please file a bug at https://github.com/jax-ml/jax/issues")
try:
device_index = device_assignment.index(device)
except IndexError as e:

View File

@ -1170,7 +1170,7 @@ softmax_custom_jvp = bool_state(
upgrade=True,
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
'improve memory usage and stability. Set True to use new '
'behavior. See https://github.com/google/jax/pull/15677'),
'behavior. See https://github.com/jax-ml/jax/pull/15677'),
update_global_hook=lambda val: _update_global_jit_state(
softmax_custom_jvp=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(

View File

@ -935,7 +935,7 @@ aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
# See comments in https://github.com/google/jax/pull/3370
# See comments in https://github.com/jax-ml/jax/pull/3370
def pure(self, x): return x
lift = sublift = pure
@ -998,7 +998,7 @@ class MainTrace:
return self.trace_type(self, cur_sublevel(), **self.payload)
class TraceStack:
# See comments in https://github.com/google/jax/pull/3370
# See comments in https://github.com/jax-ml/jax/pull/3370
stack: list[MainTrace]
dynamic: MainTrace
@ -1167,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str:
# parent->child jump. We do that by setting `parent` here to be a
# grandparent (or great-grandparent) of `child`, and then handling that case
# in _why_alive_container_info. See example:
# https://github.com/google/jax/pull/13022#discussion_r1008456599
# https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
# To prevent this collapsing behavior, just comment out this code block.
if (isinstance(parent, dict) and
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
@ -1213,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager
def new_main(trace_type: type[Trace], dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
# See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
main = MainTrace(level, trace_type, **payload)
@ -1254,7 +1254,7 @@ def dynamic_level() -> int:
@contextmanager
def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
# See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type, **payload)
prev_dynamic, stack.dynamic = stack.dynamic, main
@ -1319,7 +1319,7 @@ def ensure_compile_time_eval():
else:
return jnp.cos(x)
Here's a real-world example from https://github.com/google/jax/issues/3974::
Here's a real-world example from https://github.com/jax-ml/jax/issues/3974::
import jax
import jax.numpy as jnp
@ -1680,7 +1680,7 @@ class UnshapedArray(AbstractValue):
@property
def shape(self):
msg = ("UnshapedArray has no shape. Please open an issue at "
"https://github.com/google/jax/issues because it's unexpected for "
"https://github.com/jax-ml/jax/issues because it's unexpected for "
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)

View File

@ -191,7 +191,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
# TODO(frostig): assert these also equal:
# treedef_tuple((in_tree, in_tree))
# once https://github.com/google/jax/issues/9066 is fixed
# once https://github.com/jax-ml/jax/issues/9066 is fixed
assert tree_ps_ts == tree_ps_ts2
del tree_ps_ts2

View File

@ -1144,7 +1144,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
x = core.full_lower(x)
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.

View File

@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
devices = axis_context.device_assignment
if devices is None:
raise AssertionError(
'Please file a bug at https://github.com/google/jax/issues')
'Please file a bug at https://github.com/jax-ml/jax/issues')
if axis_context.mesh_shape is not None:
ma, ms = list(zip(*axis_context.mesh_shape))
mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)

View File

@ -389,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
devices = axis_context.device_assignment
if devices is None:
raise AssertionError(
'Please file a bug at https://github.com/google/jax/issues')
'Please file a bug at https://github.com/jax-ml/jax/issues')
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple
else:

View File

@ -793,7 +793,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. "
"See https://github.com/google/jax#current-gotchas for more.")
"See https://github.com/jax-ml/jax#current-gotchas for more.")
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = canonicalize_dtype(np_dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)

View File

@ -61,7 +61,7 @@ def _ravel_list(lst):
if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
# See https://github.com/google/jax/issues/7809.
# See https://github.com/jax-ml/jax/issues/7809.
del from_dtypes, to_dtype
raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)

View File

@ -30,7 +30,7 @@ Instead of writing this information as conditions inside one
particular test, we write them as `Limitation` objects that can be reused in
multiple tests and can also be used to generate documentation, e.g.,
the report of [unsupported and partially-implemented JAX
primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
The limitations are used to filter out from tests the harnesses that are known
to fail. A Limitation is specific to a harness.
@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name,
for old_dtype in jtu.dtypes.all:
# TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
# point numbers and new_dtype is an unsigned integer. See issue
# https://github.com/google/jax/issues/5082 for details.
# https://github.com/jax-ml/jax/issues/5082 for details.
for new_dtype in (jtu.dtypes.all
if not (dtypes.issubdtype(old_dtype, np.floating) or
dtypes.issubdtype(old_dtype, np.complexfloating))
@ -2336,7 +2336,7 @@ _make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)
# Validate padding
for padding in [
# TODO(bchetioui): commented out the test based on
# https://github.com/google/jax/issues/4690
# https://github.com/jax-ml/jax/issues/4690
# ((1, 2), (2, 3), (3, 4)) # non-zero padding
((1, 1), (1, 1), (1, 1)) # non-zero padding
]:

View File

@ -1262,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
msg = (f'custom-policy remat rule not implemented for {name}, '
'open a feature request at https://github.com/google/jax/issues!')
'open a feature request at https://github.com/jax-ml/jax/issues!')
raise NotImplementedError(msg)
@ -2688,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
# See https://github.com/jax-ml/jax/pull/9498
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
pvals: Sequence[PartialVal]):

View File

@ -500,7 +500,7 @@ class MapTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
"Please open an issue at https://github.com/google/jax/issues !")
"Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros # always base main, can drop jvp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@ -513,7 +513,7 @@ class MapTrace(core.Trace):
out_trees, symbolic_zeros):
if symbolic_zeros:
msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
"Please open an issue at https://github.com/google/jax/issues !")
"Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg)
del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@ -1869,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
"See https://github.com/google/jax/issues/2926.")
"See https://github.com/jax-ml/jax/issues/2926.")
if nreps > xb.device_count(backend):
raise ValueError(

View File

@ -389,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
branch_outs = []
for i, jaxpr in enumerate(branches_batched):
# Perform a select on the inputs for safety of reverse-mode autodiff; see
# https://github.com/google/jax/issues/1052
# https://github.com/jax-ml/jax/issues/1052
predicate = lax.eq(index, lax._const(index, i))
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))

View File

@ -715,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError
if not all(init_lin):
pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963
consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires])
@ -1169,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
if discharged_consts:
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
"please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/jax-ml/jax/issues")
def wrapped(*wrapped_args):
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
@ -1838,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
if body_jaxpr_consts:
raise NotImplementedError("Body jaxpr has consts. If you see this error, "
"please open an issue at "
"https://github.com/google/jax/issues")
"https://github.com/jax-ml/jax/issues")
# body_jaxpr has the signature (*body_consts, *carry) -> carry.
# Some of these body_consts are actually `Ref`s so when we discharge
# them, they also turn into outputs, effectively turning those consts into

View File

@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths):
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients
# https://github.com/google/jax/issues/6223#issuecomment-807740707
# https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707
return lax.conj(out)
def _fft_transpose_rule(t, operand, fft_type, fft_lengths):

View File

@ -1077,7 +1077,7 @@ def _reduction_jaxpr(computation, aval):
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
"at https://github.com/google/jax.")
"at https://github.com/jax-ml/jax.")
return jaxpr, tuple(consts)
@cache()
@ -1090,7 +1090,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
"at https://github.com/google/jax.")
"at https://github.com/jax-ml/jax.")
return core.ClosedJaxpr(jaxpr, consts), out_tree()
def _get_monoid_reducer(monoid_op: Callable,
@ -4911,7 +4911,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
return tree_util.tree_unflatten(p.out_tree(), out_flat)
# TODO(https://github.com/google/jax/issues/13552): Look into making this a
# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a
# method on jax.Array so that we can bypass the XLA compilation here.
def _copy_impl(prim, *args, **kwargs):
a, = args

View File

@ -781,7 +781,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only '
'eigenvalues. See '
'https://github.com/google/jax/issues/2748 for discussion.')
'https://github.com/jax-ml/jax/issues/2748 for discussion.')
# Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
# https://arxiv.org/abs/1701.00392
a, = primals

View File

@ -27,7 +27,7 @@ try:
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See '
'https://github.com/google/jax#installation for installation instructions.'
'https://github.com/jax-ml/jax#installation for installation instructions.'
) from err
import jax.version
@ -92,7 +92,7 @@ pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
# XLA garbage collection: see https://github.com/google/jax/issues/14882
# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
def _xla_gc_callback(*args):
xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback)

View File

@ -333,7 +333,7 @@ class AbstractMesh:
should use this as an input to the sharding passed to with_sharding_constraint
and mesh passed to shard_map to avoid tracing and lowering cache misses when
your mesh shape and names stay the same but the devices change.
See the description of https://github.com/google/jax/pull/23022 for more
See the description of https://github.com/jax-ml/jax/pull/23022 for more
details.
"""

View File

@ -3953,7 +3953,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653).
# (https://github.com/jax-ml/jax/issues/653).
k = 16
while len(arrays_out) > 1:
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
@ -4645,7 +4645,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
if all(not isinstance(leaf, Array) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
# https://github.com/google/jax/pull/6047. More correct would be to call
# https://github.com/jax-ml/jax/pull/6047. More correct would be to call
# coerce_to_array on each leaf, but this may have performance implications.
out = np.asarray(object, dtype=dtype)
elif isinstance(object, Array):
@ -10150,11 +10150,11 @@ def _eliminate_deprecated_list_indexing(idx):
if any(_should_unpack_list_index(i) for i in idx):
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[tuple(seq)]` instead of `arr[seq]`. "
"See https://github.com/google/jax/issues/4564 for more information.")
"See https://github.com/jax-ml/jax/issues/4564 for more information.")
else:
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[array(seq)]` instead of `arr[seq]`. "
"See https://github.com/google/jax/issues/4564 for more information.")
"See https://github.com/jax-ml/jax/issues/4564 for more information.")
raise TypeError(msg)
else:
idx = (idx,)
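To illustrate the fix the message above recommends (toy array and index sequence):

```python
import jax.numpy as jnp

arr = jnp.arange(12).reshape(3, 4)
seq = [0, slice(None)]   # first row, every column
row = arr[tuple(seq)]    # equivalent to arr[0, :]
# arr[seq] would raise the TypeError quoted above, because a non-tuple
# sequence is ambiguous as a multidimensional index.
```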

View File

@ -968,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is "
"on https://github.com/jax-ml/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
computation_dtype = dtype

View File

@ -134,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size``
argument:
@ -217,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size``
argument:
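For example, a minimal sketch of the static-`size` workaround under `jit` (toy inputs; `fill_value` pads any unused output slots):

```python
import jax
import jax.numpy as jnp

ar1 = jnp.array([1, 2, 3, 4])
ar2 = jnp.array([2, 4])

diff = jax.jit(
    lambda a, b: jnp.setdiff1d(a, b, size=4, fill_value=0)
)(ar1, ar2)   # -> [1, 3, 0, 0]
```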

View File

@ -2176,7 +2176,7 @@ def sinc(x: ArrayLike, /) -> Array:
def _sinc_maclaurin(k, x):
# compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
# compute the monomial term in the jvp rule)
# TODO(mattjj): see https://github.com/google/jax/issues/10750
# TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750
if k % 2:
return x * 0
else:

View File

@ -35,7 +35,7 @@ FRAME_PATTERN = re.compile(
)
MLIR_ERR_PREFIX = (
'Pallas encountered an internal verification error.'
'Please file a bug at https://github.com/google/jax/issues. '
'Please file a bug at https://github.com/jax-ml/jax/issues. '
'Error details: '
)

View File

@ -833,7 +833,7 @@ def jaxpr_subcomp(
raise NotImplementedError(
"Unimplemented primitive in Pallas TPU lowering: "
f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues.")
"Please file an issue on https://github.com/jax-ml/jax/issues.")
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, ans)
else:

View File

@ -549,7 +549,7 @@ def lower_jaxpr_to_mosaic_gpu(
raise NotImplementedError(
"Unimplemented primitive in Pallas Mosaic GPU lowering: "
f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues."
"Please file an issue on https://github.com/jax-ml/jax/issues."
)
rule = mosaic_lowering_rules[eqn.primitive]
rule_ctx = LoweringRuleContext(

View File

@ -381,7 +381,7 @@ def lower_jaxpr_to_triton_ir(
raise NotImplementedError(
"Unimplemented primitive in Pallas GPU lowering: "
f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/google/jax/issues.")
"Please file an issue on https://github.com/jax-ml/jax/issues.")
rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]

View File

@ -459,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
# rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_shardings = tuple(in_shardings)
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
@ -1276,7 +1276,7 @@ def explain_tracing_cache_miss(
return done()
# we think this is unreachable...
p("explanation unavailable! please open an issue at https://github.com/google/jax")
p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
return done()
@partial(lu.cache, explain=explain_tracing_cache_miss)
@ -1701,7 +1701,7 @@ def _pjit_call_impl_python(
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/google/jax.")
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg)

View File

@ -508,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
span = lax.convert_element_type(maxval - minval, unsigned_dtype)
# Ensure that span=1 when maxval <= minval, so minval is always returned;
# https://github.com/google/jax/issues/222
# https://github.com/jax-ml/jax/issues/222
span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
# When maxval is out of range, the span has to be one larger.
@ -2540,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array:
_btrs(key, count_btrs, q_btrs, shape, dtype, max_iters),
)
# ensure nan q always leads to nan output and nan or neg count leads to nan
# as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709
# as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709
invalid = (q_l_0 | q_is_nan | count_nan_or_neg)
samples = lax.select(
invalid,

View File

@ -176,7 +176,7 @@ def map_coordinates(
Note:
Interpolation near boundaries differs from the scipy function, because JAX
fixed an outstanding bug; see https://github.com/google/jax/issues/11097.
fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
This function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy.
"""

View File

@ -44,7 +44,7 @@ def shard_alike(x, y):
raise ValueError(
'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:'
f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at'
' https://github.com/google/jax/issues if you want this feature.')
' https://github.com/jax-ml/jax/issues if you want this feature.')
outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)]
x_out_flat, y_out_flat = zip(*outs)

View File

@ -1208,7 +1208,7 @@ class JaxTestCase(parameterized.TestCase):
y = np.asarray(y)
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
# See https://github.com/google/jax/issues/17867
# See https://github.com/jax-ml/jax/issues/17867
raise TypeError(
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
"If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "

View File

@ -21,7 +21,7 @@ exported at `jax.typing`. Until then, the contents here should be considered uns
and may change without notice.
To see the proposal that led to the development of these tools, see
https://github.com/google/jax/pull/11859/.
https://github.com/jax-ml/jax/pull/11859/.
"""
from __future__ import annotations

View File

@ -1232,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs):
if library_path is None:
raise RuntimeError(
"JAX TPU support not installed; cannot generate TPU topology. See"
" https://github.com/google/jax#installation")
" https://github.com/jax-ml/jax#installation")
c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
xla_client.profiler.register_plugin_profiler(c_api)
assert xla_client.pjrt_plugin_loaded("tpu")

Some files were not shown because too many files have changed in this diff