diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
index c19832e63..628310519 100644
--- a/.github/ISSUE_TEMPLATE/bug-report.yml
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -20,11 +20,11 @@ body:
* If you prefer a non-templated issue report, click [here][Raw report].
- [Discussions]: https://github.com/google/jax/discussions
+ [Discussions]: https://github.com/jax-ml/jax/discussions
- [issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues
+ [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
- [Raw report]: http://github.com/google/jax/issues/new
+ [Raw report]: http://github.com/jax-ml/jax/issues/new
- type: textarea
attributes:
label: Description
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index cabbed589..f078e8e94 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Have questions or need support?
- url: https://github.com/google/jax/discussions
+ url: https://github.com/jax-ml/jax/discussions
about: Please ask questions on the Discussions tab
diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml
index 1e345954d..74cb45920 100644
--- a/.github/workflows/upstream-nightly.yml
+++ b/.github/workflows/upstream-nightly.yml
@@ -84,7 +84,7 @@ jobs:
failure()
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
- && github.repository == 'google/jax'
+ && github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
name: output-${{ matrix.python-version }}-log.jsonl
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee782d04a..43db6e197 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -279,7 +279,7 @@ See the 0.4.33 release notes for more details.
which manifested as an incorrect output for cumulative reductions (#21403).
* Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(https://github.com/openxla/xla/pull/13301).
- * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
+ * Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
* Deprecations
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
@@ -401,7 +401,7 @@ See the 0.4.33 release notes for more details.
branch consistent with that of NumPy 2.0.
* The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'`
and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has
- changed](https://github.com/google/jax/issues/19085) so that
+ changed](https://github.com/jax-ml/jax/issues/19085) so that
mapping over keys results in random generation only from the first
key in the batch.
* Docs now use `jax.random.key` for construction of PRNG key arrays
@@ -433,7 +433,7 @@ See the 0.4.33 release notes for more details.
* JAX export does not support older serialization versions anymore. Version 9
has been supported since October 27th, 2023 and has become the default
since February 1, 2024.
- See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
+ See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
This change could break clients that set a specific
JAX serialization version lower than 9.
@@ -506,7 +506,7 @@ See the 0.4.33 release notes for more details.
* added the ability to specify symbolic constraints on the dimension variables.
This makes shape polymorphism more expressive, and gives a way to workaround
limitations in the reasoning about inequalities.
- See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
+ See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* with the addition of symbolic constraints ({jax-issue}`#19235`) we now
consider dimension variables from different scopes to be different, even
if they have the same name. Symbolic expressions from different scopes
@@ -516,7 +516,7 @@ See the 0.4.33 release notes for more details.
The scope of a symbolic expression `e` can be read with `e.scope` and passed
into the above functions to direct them to construct symbolic expressions in
a given scope.
- See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
+ See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
* simplified and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
@@ -535,7 +535,7 @@ See the 0.4.33 release notes for more details.
strings for polymorphic shapes specifications ({jax-issue}`#19284`).
* JAX default native serialization version is now 9. This is relevant
for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`.
- See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
+ See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
@@ -781,19 +781,19 @@ See the 0.4.33 release notes for more details.
* When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See
- [here](https://github.com/google/jax/pull/16949) for an example.
+ [here](https://github.com/jax-ml/jax/pull/16949) for an example.
This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two
separate unfiltered/filtered tracebacks, which was the old behavior) or
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
- [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
+ [safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` should be hashable. This specifically
applies to mock devices or user created devices. `jax.devices()` are
already hashable.
* Breaking changes:
* jax2tf now uses native serialization by default. See
- the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
+ the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for details and for mechanisms to override the default.
* The option `--jax_coordination_service` has been removed. It is now always
`True`.
@@ -922,7 +922,7 @@ See the 0.4.33 release notes for more details.
arguments will always resolve to the "common operands" `cond`
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
- [#16413](https://github.com/google/jax/issues/16413).
+ [#16413](https://github.com/jax-ml/jax/issues/16413).
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
@@ -933,7 +933,7 @@ See the 0.4.33 release notes for more details.
serialization version ({jax-issue}`#16746`).
* jax2tf in presence of shape polymorphism now generates code that checks
certain shape constraints, if the serialization version is at least 7.
- See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
+ See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
## jaxlib 0.4.14 (July 27, 2023)
@@ -1095,14 +1095,14 @@ See the 0.4.33 release notes for more details.
{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
- tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs
+ tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
- tracker](https://github.com/google/jax/issues).
+ tracker](https://github.com/jax-ml/jax/issues).
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
@@ -1126,7 +1126,7 @@ See the 0.4.33 release notes for more details.
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
- See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+ See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
* JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
@@ -1403,7 +1403,7 @@ Changes:
## jaxlib 0.3.22 (Oct 11, 2022)
## jax 0.3.21 (Sep 30, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21).
* Changes
* The persistent compilation cache will now warn instead of raising an
exception on error ({jax-issue}`#12582`), so program execution can continue
@@ -1417,18 +1417,18 @@ Changes:
* Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`).
## jaxlib 0.3.20 (Sep 28, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20).
* Bug fixes
* Fixes support for limiting the visible CUDA devices via
`jax_cuda_visible_devices` in distributed jobs. This functionality is needed for
the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
## jax 0.3.19 (Sep 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19).
* Fixes required jaxlib version.
## jax 0.3.18 (Sep 26, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18).
* Changes
* Ahead-of-time lowering and compilation functionality (tracked in
{jax-issue}`#7733`) is stable and public. See [the
@@ -1446,7 +1446,7 @@ Changes:
would have been provided.
## jax 0.3.17 (Aug 31, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17).
* Bugs
* Fix corner case issue in gradient of `lax.pow` with an exponent of zero
({jax-issue}`12041`)
@@ -1462,7 +1462,7 @@ Changes:
* `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
## jax 0.3.16
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main).
* Breaking changes
* Support for NumPy 1.19 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@@ -1486,7 +1486,7 @@ Changes:
deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
## jax 0.3.15 (July 22, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15).
* Changes
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
@@ -1507,10 +1507,10 @@ Changes:
following a similar deprecation in {func}`scipy.linalg.solve`.
## jaxlib 0.3.15 (July 22, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15).
## jax 0.3.14 (June 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14).
* Breaking changes
* {func}`jax.experimental.compilation_cache.initialize_cache` does not support
`max_cache_size_bytes` anymore and will not get that as an input.
@@ -1563,22 +1563,22 @@ Changes:
coefficients have leading zeros ({jax-issue}`#11215`).
## jaxlib 0.3.14 (June 27, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14).
* x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14
was released in 2018, so this should not be a very onerous requirement.
* The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
* The Python flatbuffers package is no longer a dependency of jaxlib.
## jax 0.3.13 (May 16, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13).
## jax 0.3.12 (May 15, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12).
* Changes
- * Fixes [#10717](https://github.com/google/jax/issues/10717).
+ * Fixes [#10717](https://github.com/jax-ml/jax/issues/10717).
## jax 0.3.11 (May 15, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11).
* Changes
* {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
that allows users to opt out of eigenvalue sorting on TPU.
@@ -1592,22 +1592,22 @@ Changes:
scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
## jax 0.3.10 (May 3, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10).
## jaxlib 0.3.10 (May 3, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* Changes
* [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a)
fixes an issue in the MHLO canonicalizer that caused constant folding to
take a long time or crash for certain programs.
## jax 0.3.9 (May 2, 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9).
* Changes
* Added support for fully asynchronous checkpointing for GlobalDeviceArray.
## jax 0.3.8 (April 29 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
@@ -1666,7 +1666,7 @@ Changes:
## jax 0.3.7 (April 15, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7).
* Changes:
* Fixed a performance problem if the indices passed to
{func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
@@ -1684,17 +1684,17 @@ Changes:
## jax 0.3.6 (April 12, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6).
* Changes:
* Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU
- pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
+ pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API.
## jax 0.3.5 (April 7, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5).
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
@@ -1717,17 +1717,17 @@ Changes:
## jax 0.3.4 (March 18, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4).
## jax 0.3.3 (March 17, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3).
## jax 0.3.2 (March 16, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2).
* Changes:
* The functions `jax.ops.index_update`, `jax.ops.index_add`, which were
deprecated in 0.2.22, have been removed. Please use
@@ -1751,7 +1751,7 @@ Changes:
## jax 0.3.1 (Feb 18, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1).
* Changes:
* `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated.
@@ -1774,7 +1774,7 @@ Changes:
## jax 0.3.0 (Feb 10, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0).
* Changes
* jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html)
@@ -1788,7 +1788,7 @@ Changes:
## jax 0.2.28 (Feb 1, 2022)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
@@ -1813,7 +1813,7 @@ Changes:
* The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
## jax 0.2.27 (Jan 18 2022)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27).
* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
@@ -1858,7 +1858,7 @@ Changes:
## jax 0.2.26 (Dec 8, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26).
* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
@@ -1875,7 +1875,7 @@ Changes:
## jax 0.2.25 (Nov 10, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25).
* New features:
* (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
@@ -1889,7 +1889,7 @@ Changes:
## jax 0.2.24 (Oct 19, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24).
* New features:
* `jax.random.choice` and `jax.random.permutation` now support
@@ -1923,7 +1923,7 @@ Changes:
## jax 0.2.22 (Oct 12, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
@@ -1958,13 +1958,13 @@ Changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
- * Fixes https://github.com/google/jax/issues/7461, which caused wrong
+ * Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
@@ -1992,7 +1992,7 @@ Changes:
## jax 0.2.20 (Sept 2, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
@@ -2005,7 +2005,7 @@ Changes:
## jax 0.2.19 (Aug 12, 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19).
* Breaking changes:
* Support for NumPy 1.17 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
@@ -2042,7 +2042,7 @@ Changes:
called in sequence.
## jax 0.2.18 (July 21 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18).
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
@@ -2065,7 +2065,7 @@ Changes:
* Fix bugs in the TFRT CPU backend that result in incorrect results.
## jax 0.2.17 (July 9 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17).
* Bug fixes:
* Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68
to work around #7229, which caused wrong outputs on CPU due to a concurrency
@@ -2082,12 +2082,12 @@ Changes:
## jax 0.2.16 (June 23 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16).
## jax 0.2.15 (June 23 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15).
* New features:
- * [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend
+ * [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend
with significant dispatch performance improvements on CPU.
* The {func}`jax2tf.convert` supports inequalities and min/max for booleans
({jax-issue}`#6956`).
@@ -2107,7 +2107,7 @@ Changes:
CPU.
## jax 0.2.14 (June 10 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters
@@ -2165,7 +2165,7 @@ Changes:
{func}`jit` transformed functions.
## jax 0.2.13 (May 3 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
@@ -2209,7 +2209,7 @@ Changes:
## jaxlib 0.1.65 (April 7 2021)
## jax 0.2.12 (April 1 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12).
* New features
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
@@ -2222,7 +2222,7 @@ Changes:
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
- * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
+ * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
@@ -2236,23 +2236,23 @@ Changes:
## jax 0.2.11 (March 23 2021)
* [GitHub
- commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
+ commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
- * [#6112](https://github.com/google/jax/pull/6112) added context managers:
+ * [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
- * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
+ * [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete`
* Bug fixes:
- * [#6136](https://github.com/google/jax/pull/6136) generalized
+ * [#6136](https://github.com/jax-ml/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
- * [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
+ * [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnums`
- * [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with
+ * [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with
incomplete beta functions
- * [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during
+ * [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during
tracing
- * [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when
+ * [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when
converting some large Python integers to floats
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
@@ -2264,13 +2264,13 @@ Changes:
## jax 0.2.10 (March 5 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`#5627`)
- and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
+ and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
@@ -2314,7 +2314,7 @@ Changes:
## jax 0.2.9 (January 26 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved
error checking and error messages.
@@ -2330,7 +2330,7 @@ Changes:
## jax 0.2.8 (January 12 2021)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
* Add {func}`jax.closure_convert` for use with higher-order custom
derivative functions. ({jax-issue}`#5244`)
@@ -2362,7 +2362,7 @@ Changes:
## jax 0.2.7 (Dec 4 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
@@ -2382,14 +2382,14 @@ Changes:
## jax 0.2.6 (Nov 18 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New Features:
* Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter.
- See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+ See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup
* Raise an error on non-hashable static arguments for jax.jit and
- xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
+ xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42).
* Improve consistency of type promotion behavior ({jax-issue}`#4744`):
* Adding a complex Python scalar to a JAX floating point number respects the precision of
the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
@@ -2441,15 +2441,15 @@ Changes:
## jax 0.2.5 (October 27 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
* Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
* Expanded the set of JAX primitives converted by jax2tf.
- See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
+ See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).
## jax 0.2.4 (October 19 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
* Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
* Deprecations
@@ -2461,17 +2461,17 @@ Changes:
## jax 0.2.3 (October 14 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is we need to temporarily roll back a
new jit fastpath while we look into a performance degradation
## jax 0.2.2 (October 13 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2).
## jax 0.2.1 (October 6 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1).
* Improvements:
* As a benefit of omnistaging, the host_callback functions are executed (in program
order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/
@@ -2479,10 +2479,10 @@ Changes:
## jax (0.2.0) (September 23 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0).
* Improvements:
* Omnistaging on by default. See {jax-issue}`#3370` and
- [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
+ [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md)
## jax (0.1.77) (September 15 2020)
@@ -2496,11 +2496,11 @@ Changes:
## jax 0.1.76 (September 8, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76).
## jax 0.1.75 (July 30, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75).
* Bug Fixes:
* make jnp.abs() work for unsigned inputs (#3914)
* Improvements:
@@ -2508,7 +2508,7 @@ Changes:
## jax 0.1.74 (July 29, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU support for half-precision arithmetic (#3878)
@@ -2525,7 +2525,7 @@ Changes:
## jax 0.1.73 (July 22, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73).
* The minimum jaxlib version is now 0.1.51.
* New Features:
* jax.image.resize. (#3703)
@@ -2563,14 +2563,14 @@ Changes:
## jax 0.1.72 (June 28, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72).
* Bug fixes:
* Fix an odeint bug introduced in the previous release, see
{jax-issue}`#3587`.
## jax 0.1.71 (June 25, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71).
* The minimum jaxlib version is now 0.1.48.
* Bug fixes:
* Allow `jax.experimental.ode.odeint` dynamics functions to close over
@@ -2606,7 +2606,7 @@ Changes:
## jax 0.1.70 (June 8, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70).
* New features:
* `lax.switch` introduces indexed conditionals with multiple
branches, together with a generalization of the `cond`
@@ -2615,11 +2615,11 @@ Changes:
## jax 0.1.69 (June 3, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69).
## jax 0.1.68 (May 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68).
* New features:
* {func}`lax.cond` supports a single-operand form, taken as the argument
to both branches
@@ -2630,7 +2630,7 @@ Changes:
## jax 0.1.67 (May 12, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67).
* New features:
* Support for reduction over subsets of a pmapped axis using `axis_index_groups`
{jax-issue}`#2382`.
@@ -2648,7 +2648,7 @@ Changes:
## jax 0.1.66 (May 5, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66).
* New features:
* Support for `in_axes=None` on {func}`pmap`
{jax-issue}`#2896`.
@@ -2661,7 +2661,7 @@ Changes:
## jax 0.1.65 (April 30, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65).
* New features:
* Differentiation of determinants of singular matrices
{jax-issue}`#2809`.
@@ -2679,7 +2679,7 @@ Changes:
## jax 0.1.64 (April 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64).
* New features:
* Add syntactic sugar for functional indexed updates
{jax-issue}`#2684`.
@@ -2706,7 +2706,7 @@ Changes:
## jax 0.1.63 (April 12, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63).
* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
* Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
* Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
@@ -2727,7 +2727,7 @@ Changes:
## jax 0.1.62 (March 21, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62).
* JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
* Removed the internal function `lax._safe_mul`, which implemented the
convention `0. * nan == 0.`. This change means some programs when
@@ -2745,13 +2745,13 @@ Changes:
## jax 0.1.61 (March 17, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61).
* Fixes Python 3.5 support. This will be the last JAX or jaxlib release that
supports Python 3.5.
## jax 0.1.60 (March 17, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60).
* New features:
* {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows
the user to specify arguments that should be treated as compile-time
@@ -2777,7 +2777,7 @@ Changes:
## jax 0.1.59 (February 11, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59).
* Breaking changes
* The minimum jaxlib version is now 0.1.38.
@@ -2809,7 +2809,7 @@ Changes:
## jax 0.1.58 (January 28, 2020)
-* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58).
+* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58).
* Breaking changes
* JAX has dropped Python 2 support, because Python 2 reached its end of life on
diff --git a/CITATION.bib b/CITATION.bib
index 88049a146..777058b5a 100644
--- a/CITATION.bib
+++ b/CITATION.bib
@@ -1,7 +1,7 @@
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
- url = {http://github.com/google/jax},
+ url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
diff --git a/README.md b/README.md
index 35307bee3..d67bdac82 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
-

+
# Transformable numerical computing at scale
-
+

[**Quickstart**](#quickstart-colab-in-the-cloud)
@@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come.
This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
-bugs](https://github.com/google/jax/issues), and letting us know what you
+bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
think!
```python
@@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
-- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
-Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs).
+Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).
For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
-notebooks](https://github.com/google/jax/tree/main/docs/notebooks).
+notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).
## Transformations
@@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
-You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
+You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns.
It all composes, so you're free to differentiate through parallel computations:
@@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.
See the [SPMD
-Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
+Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch
-example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
+example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
for more.
## Current gotchas
@@ -349,7 +349,7 @@ Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
- different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md).
+ different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
they're in the `jax.lax` package.
@@ -437,7 +437,7 @@ To cite this repository:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
- url = {http://github.com/google/jax},
+ url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
index cb5a42ced..edaa71b93 100644
--- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
+++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb
@@ -451,7 +451,7 @@
"id": "jC-KIMQ1q-lK"
},
"source": [
- "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
+ "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
]
},
{
diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb
index 4952cdbe9..d7ba5ed33 100644
--- a/cloud_tpu_colabs/JAX_demo.ipynb
+++ b/cloud_tpu_colabs/JAX_demo.ipynb
@@ -837,7 +837,7 @@
"id": "f-FBsWeo1AXE"
},
"source": [
- "
"
+ "
"
]
},
{
@@ -847,7 +847,7 @@
"id": "jC-KIMQ1q-lK"
},
"source": [
- "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
+ "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
]
},
{
diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb
index 981f0a9e8..ea126ac4f 100644
--- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb
+++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb
@@ -15,13 +15,13 @@
"id": "sk-3cPGIBTq8"
},
"source": [
- "[](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n",
"\n",
"This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n",
"\n",
"**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n",
"\n",
- "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
+ "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)."
]
},
{
diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md
index 4a795f718..db3dc5f30 100644
--- a/cloud_tpu_colabs/README.md
+++ b/cloud_tpu_colabs/README.md
@@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab:
-### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
+### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
A guide to getting started with `pmap`, a transform for easily distributing SPMD
computations across devices.
-### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb)
+### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb)
Contributed by Alex Alemi (alexalemi@)
Solve and plot parallel ODE solutions with `pmap`.
-
+
-### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb)
+### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb)
Contributed by Stephan Hoyer (shoyer@)
Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU.
-
+
-### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb)
+### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb)
An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.
## Performance notes
@@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud
JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.
-\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you!
+\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you!
## Running JAX on a Cloud TPU VM
@@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU
VM), please email , or if
you are a [TRC](https://sites.research.google/trc/) member. You can also [file a
-JAX issue](https://github.com/google/jax/issues) or [ask a discussion
-question](https://github.com/google/jax/discussions) for any issues with these
+JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion
+question](https://github.com/jax-ml/jax/discussions) for any issues with these
notebooks or using JAX in general.
If you have any other questions or comments regarding JAX on Cloud TPUs, please
diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md
index 180f65f5d..287487ad4 100644
--- a/docs/_tutorials/advanced-autodiff.md
+++ b/docs/_tutorials/advanced-autodiff.md
@@ -571,7 +571,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products
-Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
+Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
```{code-cell}
# Isolate the function from the weight matrix to the predictions
diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb
index ed242ecc5..9a956670c 100644
--- a/docs/autodidax.ipynb
+++ b/docs/autodidax.ipynb
@@ -27,7 +27,7 @@
"metadata": {},
"source": [
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)"
+ "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)"
]
},
{
@@ -1781,7 +1781,7 @@
"metadata": {},
"source": [
"This is precisely the issue that\n",
- "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n",
+ "[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n",
"We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n",
"applied, regardless of whether any inputs to `bind` are boxed in corresponding\n",
"`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n",
diff --git a/docs/autodidax.md b/docs/autodidax.md
index 471dd7c63..937e1012a 100644
--- a/docs/autodidax.md
+++ b/docs/autodidax.md
@@ -33,7 +33,7 @@ limitations under the License.
```
[![Open in
-Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)
+Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
+++
@@ -1399,7 +1399,7 @@ print(jaxpr)
```
This is precisely the issue that
-[omnistaging](https://github.com/google/jax/pull/3370) fixed.
+[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
applied, regardless of whether any inputs to `bind` are boxed in corresponding
`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`
diff --git a/docs/autodidax.py b/docs/autodidax.py
index 6d295fc50..c10e6365e 100644
--- a/docs/autodidax.py
+++ b/docs/autodidax.py
@@ -27,7 +27,7 @@
# ---
# [![Open in
-# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)
+# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)
# # Autodidax: JAX core from scratch
#
@@ -1396,7 +1396,7 @@ print(jaxpr)
# This is precisely the issue that
-# [omnistaging](https://github.com/google/jax/pull/3370) fixed.
+# [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.
# We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always
# applied, regardless of whether any inputs to `bind` are boxed in corresponding
# `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`
diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst
index 204659ec2..783d3b49a 100644
--- a/docs/beginner_guide.rst
+++ b/docs/beginner_guide.rst
@@ -52,4 +52,4 @@ questions answered are:
.. _Flax: https://flax.readthedocs.io/
.. _Haiku: https://dm-haiku.readthedocs.io/
.. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax
-.. _JAX GitHub discussions: https://github.com/google/jax/discussions
\ No newline at end of file
+.. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index ed6fcfd0d..e77916e26 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -168,7 +168,7 @@ html_theme = 'sphinx_book_theme'
# documentation.
html_theme_options = {
'show_toc_level': 2,
- 'repository_url': 'https://github.com/google/jax',
+ 'repository_url': 'https://github.com/jax-ml/jax',
'use_repository_button': True, # add a "link to repository" button
'navigation_with_keys': False,
}
@@ -345,7 +345,7 @@ def linkcode_resolve(domain, info):
return None
filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__))
lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else ""
- return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}"
+ return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}"
# Generate redirects from deleted files to new sources
rediraffe_redirects = {
diff --git a/docs/contributing.md b/docs/contributing.md
index d7fa6e9da..99d78453c 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -5,22 +5,22 @@
Everyone can contribute to JAX, and we value everyone's contributions. There are several
ways to contribute, including:
-- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions)
+- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions)
- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/)
-- Contributing to JAX's [code-base](http://github.com/google/jax/)
-- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries)
+- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/)
+- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries)
The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
## Ways to contribute
We welcome pull requests, in particular for those issues marked with
-[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or
-[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
+[contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or
+[good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
For other proposals, we ask that you first open a GitHub
-[Issue](https://github.com/google/jax/issues/new/choose) or
-[Discussion](https://github.com/google/jax/discussions)
+[Issue](https://github.com/jax-ml/jax/issues/new/choose) or
+[Discussion](https://github.com/jax-ml/jax/discussions)
to seek feedback on your planned contribution.
## Contributing code using pull requests
@@ -33,7 +33,7 @@ Follow these steps to contribute code:
For more information, see the Pull Request Checklist below.
2. Fork the JAX repository by clicking the **Fork** button on the
- [repository page](http://www.github.com/google/jax). This creates
+ [repository page](http://www.github.com/jax-ml/jax). This creates
a copy of the JAX repository in your own account.
3. Install Python >= 3.10 locally in order to run tests.
@@ -52,7 +52,7 @@ Follow these steps to contribute code:
changes.
```bash
- git remote add upstream https://www.github.com/google/jax
+ git remote add upstream https://www.github.com/jax-ml/jax
```
6. Create a branch where you will develop from:
diff --git a/docs/developer.md b/docs/developer.md
index 40ad51e87..4f3361413 100644
--- a/docs/developer.md
+++ b/docs/developer.md
@@ -6,7 +6,7 @@
First, obtain the JAX source code:
```
-git clone https://github.com/google/jax
+git clone https://github.com/jax-ml/jax
cd jax
```
@@ -26,7 +26,7 @@ If you're only modifying Python portions of JAX, we recommend installing
pip install jaxlib
```
-See the [JAX readme](https://github.com/google/jax#installation) for full
+See the [JAX readme](https://github.com/jax-ml/jax#installation) for full
guidance on pip installation (e.g., for GPU and TPU support).
### Building `jaxlib` from source
@@ -621,7 +621,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
-[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
+[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
## Type checking
@@ -712,7 +712,7 @@ jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```
The jupytext version should match that specified in
-[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml).
+[.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml).
To check that the markdown and ipynb files are properly synced, you may use the
[pre-commit](https://pre-commit.com/) framework to perform the same check used
@@ -740,12 +740,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build.
The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
-or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
+or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else
re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
-See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py).
+See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py).
### Documentation building on `readthedocs.io`
@@ -772,7 +772,7 @@ I saw in the Readthedocs logs:
mkvirtualenv jax-docs # A new virtualenv
mkdir jax-docs # A new directory
cd jax-docs
-git clone --no-single-branch --depth 50 https://github.com/google/jax
+git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax
cd jax
git checkout --force origin/test-docs
git clean -d -f -f
diff --git a/docs/export/export.md b/docs/export/export.md
index 0ca1a6480..4e4d50556 100644
--- a/docs/export/export.md
+++ b/docs/export/export.md
@@ -153,7 +153,7 @@ JAX runtime system that are:
an inference system that is already deployed when the exporting is done.
(The particular compatibility window lengths are the same that JAX
-[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
+[promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow).
The terminology “backward compatibility” is from the perspective of the consumer,
e.g., the inference system.)
@@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers:
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
- See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
+ See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522),
available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
and the default since August 12th, 2023 (JAX 0.4.15).
@@ -721,7 +721,7 @@ that live in jaxlib):
2. Day “D”, we add the new custom call target `T_NEW`.
We should create a new custom call target, and clean up the old
target roughly after 6 months, rather than updating `T` in place:
- * See the example [PR #20997](https://github.com/google/jax/pull/20997)
+ * See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997)
implementing the steps below.
* We add the custom call target `T_NEW`.
 * We change the JAX lowering rules that were previously using `T`,
diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md
index 498a0418f..9c0ee90a0 100644
--- a/docs/export/jax2tf.md
+++ b/docs/export/jax2tf.md
@@ -2,4 +2,4 @@
## Interoperation with TensorFlow
-See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md).
diff --git a/docs/faq.rst b/docs/faq.rst
index 3ac7d89fb..af14f382b 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -372,7 +372,7 @@ device.
Jitted functions behave like any other primitive operations—they will follow the
data and will show errors if invoked on data committed on more than one device.
-(Before `PR #6002 <https://github.com/google/jax/pull/6002>`_ in March 2021
+(Before `PR #6002 <https://github.com/jax-ml/jax/pull/6002>`_ in March 2021
there was some laziness in creation of array constants, so that
``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually
create the array of zeros on ``jax.devices()[1]``, instead of creating the
@@ -385,7 +385,7 @@ and its use is not recommended.)
For a worked-out example, we recommend reading through
``test_computation_follows_data`` in
-`multi_device_test.py <https://github.com/google/jax/blob/main/tests/multi_device_test.py>`_.
+`multi_device_test.py <https://github.com/jax-ml/jax/blob/main/tests/multi_device_test.py>`_.
.. _faq-benchmark:
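As a quick illustration of the committed-data behaviour described in this hunk, here is a minimal sketch (it assumes at least one visible device and is illustrative only, not text from the FAQ itself):

```python
import jax
import jax.numpy as jnp

# Uncommitted: placed on the default device, free to move.
x = jnp.zeros(3)

# Committed: explicitly placed on a specific device via device_put.
dev = jax.devices()[0]
y = jax.device_put(jnp.ones(3), dev)

# Jitted functions follow the data: the computation runs where its inputs live.
z = jax.jit(lambda a, b: a + b)(x, y)
print(z.devices())  # shows the device the computation ran on
```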
@@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.::
Additional reading:
- * `Issue: gradients through jnp.where when one of branches is nan `_.
+ * `Issue: gradients through jnp.where when one of branches is nan `_.
* `How to avoid NaN gradients when using where `_.
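A minimal sketch of the double-`jnp.where` trick these links describe (the function below is made up purely for the example):

```python
import jax
import jax.numpy as jnp

def safe_ratio(x):
  # Inner where: keep the untaken branch NaN-free so its gradient is well defined.
  safe_x = jnp.where(x > 0., x, 1.)
  return jnp.where(x > 0., jnp.log(safe_x) / safe_x, 0.)

print(jax.grad(safe_ratio)(0.))  # 0.0, no NaN leaking in from the untaken log branch
```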
diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb
index 7f7bcc07c..04ae80cbf 100644
--- a/docs/ffi.ipynb
+++ b/docs/ffi.ipynb
@@ -406,7 +406,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)."
+ "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
]
},
{
@@ -492,7 +492,7 @@
"source": [
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
- "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n",
+ "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n",
"\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",
diff --git a/docs/ffi.md b/docs/ffi.md
index d96d9ff8c..03acf876b 100644
--- a/docs/ffi.md
+++ b/docs/ffi.md
@@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5):
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
```
-If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues).
+If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
+++
@@ -406,7 +406,7 @@ np.testing.assert_allclose(
At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.
-JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.
+JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing your use case if you hit this limitation in practice.
One other JAX feature that this example doesn't support is higher-order AD.
It would be possible to work around this by wrapping the `rms_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.
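For orientation, the `custom_vjp` pattern this section relies on looks roughly like the sketch below, with a pure-JAX stand-in where the tutorial calls out to the foreign `rms_norm` kernel (the names and the epsilon constant here are illustrative, not the tutorial's actual FFI bindings):

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # illustrative constant

@jax.custom_vjp
def rms_norm_sketch(x):
  # Pure-JAX stand-in for the foreign-function call.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + EPS)

def _fwd(x):
  scale = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + EPS)
  return x * scale, (x, scale)

def _bwd(res, g):
  x, scale = res
  dx = scale * g - x * scale**3 * jnp.mean(g * x, axis=-1, keepdims=True)
  return (dx,)

rms_norm_sketch.defvjp(_fwd, _bwd)

x = jnp.ones((4, 8))
jax.grad(lambda v: rms_norm_sketch(v).sum())(x)   # reverse-mode: works
# jax.jvp(rms_norm_sketch, (x,), (x,))            # forward-mode: not supported by custom_vjp
```

The commented-out `jax.jvp` call is exactly the limitation the paragraph above describes: a `custom_vjp` function only supplies a reverse-mode rule.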
diff --git a/docs/installation.md b/docs/installation.md
index 93df4a240..acb802ea9 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -176,7 +176,7 @@ installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation.
-Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
+Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues)
if you run into any errors or problems with the pre-built wheels.
(docker-containers-nvidia-gpu)=
@@ -216,7 +216,7 @@ refer to
**Note:** There are several caveats with the Metal plugin:
* The Metal plugin is new and experimental and has a number of
- [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
+ [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker.
* The Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API
diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md
index 389cc0b5a..61d219d1b 100644
--- a/docs/investigating_a_regression.md
+++ b/docs/investigating_a_regression.md
@@ -9,7 +9,7 @@ Let's first make a JAX issue.
But if you can pinpoint the commit that triggered the regression, it will really help us.
This document explains how we identified the commit that caused a
-[15% performance regression](https://github.com/google/jax/issues/17686).
+[15% performance regression](https://github.com/jax-ml/jax/issues/17686).
## Steps
@@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox).
- test_runner.sh: will start the containers and the test.
- test.sh: will install missing dependencies and run the test
-Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686
+Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686
- test_runner.sh:
```
for m in 7 8 9; do
diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md
index da0adaf18..019188349 100644
--- a/docs/jep/11830-new-remat-checkpoint.md
+++ b/docs/jep/11830-new-remat-checkpoint.md
@@ -14,7 +14,7 @@
## What’s going on?
-As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
+As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
## How can I disable the change, and go back to the old behavior for now?
@@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu
As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
-tracker](https://github.com/google/jax/issues) so we can help fix it!
+tracker](https://github.com/jax-ml/jax/issues) so we can help fix it!
## Why are we doing this?
@@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi
### Significantly less Python overhead in some cases
-The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
+The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
### Enabling new JAX features by simplifying internals
diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md
index 9137e3e71..7a20958c5 100644
--- a/docs/jep/12049-type-annotations.md
+++ b/docs/jep/12049-type-annotations.md
@@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to
This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.
Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally.
-In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
+In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
This document details JAX's goals and recommendations for type annotations within the package.
## Why type annotations?
@@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their
### Level 1: Annotations as documentation
-When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
+When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
```python
Array = Any
@@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co
This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.
-JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
+JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
```python
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
```
-In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
+In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:

@@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer
```
Again, there are a couple mechanisms that could be used for this:
-- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
+- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`
- restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`
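For reference, the instance-check behaviour weighed in this list is what `jax.Array` provides today; a minimal sketch, assuming a recent JAX release where tracers satisfy the check:

```python
import jax
import jax.numpy as jnp

def f(x):
  # Holds both for concrete arrays and for tracers under jit.
  assert isinstance(x, jax.Array)
  return x * 2

f(jnp.arange(3))            # concrete jax.Array
jax.jit(f)(jnp.arange(3))   # tracer inside jit also satisfies the check
```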
diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md
index bec060001..a5625abf8 100644
--- a/docs/jep/15856-jex.md
+++ b/docs/jep/15856-jex.md
@@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
-(see issue [#9263](https://github.com/google/jax/issues/9263)),
+(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and
`random_unwrap`.
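The key-internals helpers mentioned here now have public counterparts in `jax.random`; a minimal sketch, assuming a JAX version with typed PRNG keys:

```python
import jax

key = jax.random.key(0)                # typed (new-style) PRNG key
raw = jax.random.key_data(key)         # view the uint32 key internals
key2 = jax.random.wrap_key_data(raw)   # wrap raw data back into a typed key
print(key.dtype, raw.dtype)
```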
diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md
index 2371e11ee..eaebe8fb8 100644
--- a/docs/jep/18137-numpy-scipy-scope.md
+++ b/docs/jep/18137-numpy-scipy-scope.md
@@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali
and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`:
as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has
-[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637),
-and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further
+[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637),
+and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further
complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we may have chosen not to accept this
contribution to the package.
diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md
index aa568adc0..ce149fa6f 100644
--- a/docs/jep/2026-custom-derivatives.md
+++ b/docs/jep/2026-custom-derivatives.md
@@ -35,9 +35,9 @@ behavior of their code. This customization
Python control flow and workflows for NaN debugging.
As **JAX developers** we want to write library functions, like
-[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
+[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
and
-[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
+[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
that are defined in terms of other primitives, but for the purposes of
differentiation have primitive-like behavior in the sense that we want to define
custom differentiation rules for them, which may be more numerically stable or
@@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like
want to be confident we’re not going to preclude good solutions to that problem.
That is, our primary goals are
-1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and
+1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and
2. allow Python in custom VJPs, e.g. to debug NaNs
- ([#1275](https://github.com/google/jax/issues/1275)).
+ ([#1275](https://github.com/jax-ml/jax/issues/1275)).
Secondary goals are
3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
@@ -60,18 +60,18 @@ Secondary goals are
`odeint`, `root`, etc.
Overall, we want to close
-[#116](https://github.com/google/jax/issues/116),
-[#1097](https://github.com/google/jax/issues/1097),
-[#1249](https://github.com/google/jax/issues/1249),
-[#1275](https://github.com/google/jax/issues/1275),
-[#1366](https://github.com/google/jax/issues/1366),
-[#1723](https://github.com/google/jax/issues/1723),
-[#1670](https://github.com/google/jax/issues/1670),
-[#1875](https://github.com/google/jax/issues/1875),
-[#1938](https://github.com/google/jax/issues/1938),
+[#116](https://github.com/jax-ml/jax/issues/116),
+[#1097](https://github.com/jax-ml/jax/issues/1097),
+[#1249](https://github.com/jax-ml/jax/issues/1249),
+[#1275](https://github.com/jax-ml/jax/issues/1275),
+[#1366](https://github.com/jax-ml/jax/issues/1366),
+[#1723](https://github.com/jax-ml/jax/issues/1723),
+[#1670](https://github.com/jax-ml/jax/issues/1670),
+[#1875](https://github.com/jax-ml/jax/issues/1875),
+[#1938](https://github.com/jax-ml/jax/issues/1938),
and replace the custom_transforms machinery (from
-[#636](https://github.com/google/jax/issues/636),
-[#818](https://github.com/google/jax/issues/818),
+[#636](https://github.com/jax-ml/jax/issues/636),
+[#818](https://github.com/jax-ml/jax/issues/818),
and others).
## Non-goals
@@ -400,7 +400,7 @@ There are some other bells and whistles to the API:
resolved to positions using the `inspect` module. This is a bit of an experiment
with Python 3’s improved ability to programmatically inspect argument
signatures. I believe it is sound but not complete, which is a fine place to be.
- (See also [#2069](https://github.com/google/jax/issues/2069).)
+ (See also [#2069](https://github.com/jax-ml/jax/issues/2069).)
* Arguments can be marked non-differentiable using `nondiff_argnums`, and as with
`jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to
set a convention for how these arguments are passed to the rules. For a primal
@@ -433,5 +433,5 @@ There are some other bells and whistles to the API:
`custom_lin` to the tangent values; `custom_lin` carries with it the user’s
custom backward-pass function, and as a primitive it only has a transpose
rule.
- * This mechanism is described more in [#636](https://github.com/google/jax/issues/636).
+ * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636).
* To prevent
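The API that grew out of this design is `jax.custom_jvp`/`jax.custom_vjp`; the standard `log1pexp` example gives a feel for it (a minimal sketch, not code taken from this JEP):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  # Same math as autodiff would produce, but numerically stable for large x.
  return log1pexp(x), x_dot * jax.nn.sigmoid(x)

print(jax.grad(log1pexp)(100.))  # ~1.0 instead of nan
```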
diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md
index 65235dc64..1e2270e05 100644
--- a/docs/jep/4008-custom-vjp-update.md
+++ b/docs/jep/4008-custom-vjp-update.md
@@ -9,7 +9,7 @@ notebook.
## What to update
-After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments
+After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s), which basically means to allow for
arbitrarily-transformable code `nondiff_argnums` shouldn't be used for
@@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake!
-**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure
+**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For
@@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
-[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about
+[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after
-[#4039](https://github.com/google/jax/pull/4039) those will just work!)
+[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!)
Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to
diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md
index eb68ee5f0..f95c15f40 100644
--- a/docs/jep/4410-omnistaging.md
+++ b/docs/jep/4410-omnistaging.md
@@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc.
### What's going on?
A change to JAX's tracing infrastructure called “omnistaging”
-([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in
+([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in
jax==0.2.0. This change improves memory performance, trace execution time, and
simplifies jax internals, but may cause some existing code to break. Breakage is
usually a result of buggy code, so long-term it’s best to fix the bugs, but
@@ -191,7 +191,7 @@ and potentially even fragmenting memory.
(The `broadcast` that corresponds to the construction of the zeros array for
`jnp.zeros_like(x)` is staged out because JAX is lazy about very simple
-expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After
+expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After
omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of `mask` is not staged out is that, before omnistaging,
diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md
index 828b95e8c..d520f6f63 100644
--- a/docs/jep/9263-typed-keys.md
+++ b/docs/jep/9263-typed-keys.md
@@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same
extended dtype mechanism elsewhere internally. For example, the
`jax._src.core.bint` object, a bounded integer type used for experimental work
on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies
-the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
+the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
### PRNG dtypes
PRNG dtypes are defined as a particular case of extended dtypes. Specifically,
diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb
index 3e99daabe..a1ede3177 100644
--- a/docs/jep/9407-type-promotion.ipynb
+++ b/docs/jep/9407-type-promotion.ipynb
@@ -8,7 +8,7 @@
"source": [
"# Design of Type Promotion Semantics for JAX\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n",
"*Jake VanderPlas, December 2021*\n",
"\n",
diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md
index cdb1f7805..ff67a8c21 100644
--- a/docs/jep/9407-type-promotion.md
+++ b/docs/jep/9407-type-promotion.md
@@ -16,7 +16,7 @@ kernelspec:
# Design of Type Promotion Semantics for JAX
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
*Jake VanderPlas, December 2021*
diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md
index 759a9be86..b964aa2af 100644
--- a/docs/jep/9419-jax-versioning.md
+++ b/docs/jep/9419-jax-versioning.md
@@ -58,11 +58,11 @@ These constraints imply the following rules for releases:
* If a new `jaxlib` is released, a `jax` release must be made at the same time.
These
-[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
+[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py)
are currently checked by `jax` at import time, instead of being expressed as
Python package version constraints. `jax` checks the `jaxlib` version at
runtime rather than using a `pip` package version constraint because we
-[provide separate `jaxlib` wheels](https://github.com/google/jax#installation)
+[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation)
for a variety of hardware and software versions (e.g., GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want `pip`
to install a `jaxlib` package for us automatically.
@@ -119,7 +119,7 @@ no released `jax` version uses that API.
## How is the source to `jaxlib` laid out?
`jaxlib` is split across two main repositories, namely the
-[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
+[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib)
and in the
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the
@@ -146,7 +146,7 @@ level.
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of
`jaxlib` from the XLA repository are incorporated into the build
-[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
+[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overridden on a build-by-build basis.
diff --git a/docs/jep/index.rst b/docs/jep/index.rst
index 194eb0cb9..f9dda2657 100644
--- a/docs/jep/index.rst
+++ b/docs/jep/index.rst
@@ -32,7 +32,7 @@ should be linked to this issue.
Then create a pull request that adds a file named
`%d-{short-title}.md` - with the number being the issue number.
-.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP
+.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP
.. toctree::
:maxdepth: 1
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index 0cffc22f1..71bd45276 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},
{
@@ -661,7 +661,7 @@
"source": [
"Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n",
"\n",
- "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
+ "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
{
@@ -1003,7 +1003,7 @@
"id": "COjzGBpO4tzL"
},
"source": [
- "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
+ "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by a special array element that we call a __key__:"
]
@@ -1349,7 +1349,7 @@
"\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n",
- "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
+ "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index 543d9ecb1..741fa3af0 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+++ {"id": "4k5PVzEo2uJO"}
@@ -312,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.
-Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
+Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
+++ {"id": "LwB07Kx5sgHu"}
@@ -460,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
+++ {"id": "COjzGBpO4tzL"}
-JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
+JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a __key__:
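A minimal sketch of the explicit-key workflow described above (using the current `jax.random.key` constructor; `jax.random.PRNGKey` works equivalently):

```python
import jax

key = jax.random.key(0)               # explicit PRNG state
key, subkey = jax.random.split(key)   # fork the state instead of mutating it
x = jax.random.normal(subkey, (3,))
```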
@@ -623,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
-To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
+To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
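A quick way to see the `ShapedArray`-level tracing described above is to inspect the jaxpr; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) ** 2

# One trace at ShapedArray((3,), float32) serves every array of that shape and dtype.
print(jax.make_jaxpr(f)(jnp.array([1., 2., 3.], jnp.float32)))
```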
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
index dd7a36e57..5c09a0a4f 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"There are two ways to define differentiation rules in JAX:\n",
"\n",
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
index 930887af1..8a9b93155 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
@@ -17,7 +17,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
There are two ways to define differentiation rules in JAX:
diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
index 8bc0e0a52..32d332d9a 100644
--- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
+++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
@@ -17,7 +17,7 @@
"id": "pFtQjv4SzHRj"
},
"source": [
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"\n",
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
]
diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
index 97b07172b..2142db986 100644
--- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
+++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
@@ -19,7 +19,7 @@ kernelspec:
+++ {"id": "pFtQjv4SzHRj"}
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.
diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb
index 0c20fc47d..e9924e18d 100644
--- a/docs/notebooks/How_JAX_primitives_work.ipynb
+++ b/docs/notebooks/How_JAX_primitives_work.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
"\n",
"*necula@google.com*, October 2019.\n",
"\n",
diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md
index b926c22ea..7c24ac11a 100644
--- a/docs/notebooks/How_JAX_primitives_work.md
+++ b/docs/notebooks/How_JAX_primitives_work.md
@@ -17,7 +17,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
*necula@google.com*, October 2019.
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
index a4a4d7d16..a7ef2a017 100644
--- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
+++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n",
"**Copyright 2018 The JAX Authors.**\n",
"\n",
@@ -32,9 +32,9 @@
"id": "B_XlLLpcWjkA"
},
"source": [
- "\n",
+ "\n",
"\n",
- "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
+ "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md
index 03b8415fc..cd98022e7 100644
--- a/docs/notebooks/Neural_Network_and_Data_Loading.md
+++ b/docs/notebooks/Neural_Network_and_Data_Loading.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
**Copyright 2018 The JAX Authors.**
@@ -35,9 +35,9 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"}
-
+
-Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
+Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
index 2c231bf99..00ba9186e 100644
--- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
+++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
]
},
{
@@ -79,7 +79,7 @@
"id": "gA8V51wZdsjh"
},
"source": [
- "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README."
+ "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README."
]
},
{
@@ -320,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
- "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
+ "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
@@ -333,7 +333,7 @@
"\n",
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
"\n",
- "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
+ "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
]
},
{
diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md
index 41d7a7e51..10c4e7cb6 100644
--- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md
+++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+++ {"id": "r-3vMiKRYXPJ"}
@@ -57,7 +57,7 @@ fast_f = jit(f)
+++ {"id": "gA8V51wZdsjh"}
-When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README.
+When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README.
+++ {"id": "2Th1vYLVaFBz"}
@@ -223,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
-Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+++ {"id": "0vb2ZoGrCMM4"}
@@ -231,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit
An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
-It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234).
+It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234).
```{code-cell} ipython3
:id: gSMIT2z1vUpO
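The registry-based `inverse` interpreter described in the hunks above can be sketched roughly as follows; this is a simplified reading of the approach (the `inverse_registry` name and the restriction to unary, single-output primitives are assumptions), not the notebook's exact implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax

# primitive -> inverse function (only unary, invertible primitives here)
inverse_registry = {lax.exp_p: jnp.log, lax.tanh_p: jnp.arctanh}

def inverse(fun):
  def wrapped(*args):
    closed_jaxpr = jax.make_jaxpr(fun)(*args)
    return inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
  return wrapped

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}
  # Seed the environment with the *outputs*, then walk the equations backwards.
  for var, val in zip(jaxpr.outvars, args):
    env[var] = val
  for var, val in zip(jaxpr.constvars, consts):
    env[var] = val
  for eqn in reversed(jaxpr.eqns):
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(f"no inverse registered for {eqn.primitive}")
    (outval,) = [env[v] for v in eqn.outvars]
    (invar,) = eqn.invars
    env[invar] = inverse_registry[eqn.primitive](outval)
  return [env[v] for v in jaxpr.invars]

f = lambda x: jnp.exp(jnp.tanh(x))
(x_recovered,) = inverse(f)(f(jnp.ones(5)))
print(x_recovered)   # approximately [1. 1. 1. 1. 1.]
```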
diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb
index 3f2f0fd56..5538b70da 100644
--- a/docs/notebooks/autodiff_cookbook.ipynb
+++ b/docs/notebooks/autodiff_cookbook.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics."
]
@@ -255,7 +255,7 @@
"id": "cJ2NxiN58bfI"
},
"source": [
- "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
+ "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
]
},
{
@@ -1015,7 +1015,7 @@
"source": [
"### Jacobian-Matrix and Matrix-Jacobian products\n",
"\n",
- "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
+ "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
]
},
{
diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md
index 496d676f7..db6fde805 100644
--- a/docs/notebooks/autodiff_cookbook.md
+++ b/docs/notebooks/autodiff_cookbook.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
@@ -151,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b}))
+++ {"id": "cJ2NxiN58bfI"}
-You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
+You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
+++ {"id": "PaCHzAtGruBz"}
@@ -592,7 +592,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products
-Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
+Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
```{code-cell} ipython3
:id: asAWvxVaCmsx
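A small sketch of the matrix-Jacobian product idea described above, composing `jax.vjp` with `jax.vmap` (the function `f` and the identity cotangent matrix are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x ** 2            # arbitrary R^3 -> R^3 function

x = jnp.arange(1.0, 4.0)
M = jnp.eye(3)                          # rows are cotangent vectors

_, vjp_fun = jax.vjp(f, x)
mjp = jax.vmap(lambda ct: vjp_fun(ct)[0])(M)   # matrix-Jacobian product, row by row
print(jnp.allclose(mjp, jax.jacrev(f)(x)))     # True, since M is the identity
```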
diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb
index f628625bd..9d91804b6 100644
--- a/docs/notebooks/convolutions.ipynb
+++ b/docs/notebooks/convolutions.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n",
"\n",
diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md
index 83ab2d9fd..b98099aa9 100644
--- a/docs/notebooks/convolutions.md
+++ b/docs/notebooks/convolutions.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)
JAX provides a number of interfaces to compute convolutions across data, including:
diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb
index 91f2ee571..c31a99746 100644
--- a/docs/notebooks/neural_network_with_tfds_data.ipynb
+++ b/docs/notebooks/neural_network_with_tfds_data.ipynb
@@ -40,11 +40,11 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n",
"\n",
- "\n",
+ "\n",
"\n",
"Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n",
diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md
index 480e7477b..53b7d4735 100644
--- a/docs/notebooks/neural_network_with_tfds_data.md
+++ b/docs/notebooks/neural_network_with_tfds_data.md
@@ -38,11 +38,11 @@ limitations under the License.
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
_Forked from_ `neural_network_and_data_loading.ipynb`
-
+
Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb
index e4f9d888e..b5f8074c0 100644
--- a/docs/notebooks/thinking_in_jax.ipynb
+++ b/docs/notebooks/thinking_in_jax.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."
]
diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md
index dd0c73ec7..b3672b90e 100644
--- a/docs/notebooks/thinking_in_jax.md
+++ b/docs/notebooks/thinking_in_jax.md
@@ -17,7 +17,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.
diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb
index 9aef1a8eb..dccc83168 100644
--- a/docs/notebooks/vmapped_log_probs.ipynb
+++ b/docs/notebooks/vmapped_log_probs.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
+ "[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",
"\n",
diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md
index 5989a87bc..3f836e680 100644
--- a/docs/notebooks/vmapped_log_probs.md
+++ b/docs/notebooks/vmapped_log_probs.md
@@ -18,7 +18,7 @@ kernelspec:
-[](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
+[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst
index 93d7e5547..4a2d4daa6 100644
--- a/docs/pallas/tpu/details.rst
+++ b/docs/pallas/tpu/details.rst
@@ -21,7 +21,7 @@ software emulation, and can slow down the computation.
If you see unexpected outputs, please compare them against a kernel run with
``interpret=True`` passed in to ``pallas_call``. If the results diverge,
- please file a `bug report `_.
+ please file a `bug report `_.
What is a TPU?
--------------
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index 1d6a5d9b7..47a7587b6 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -29,7 +29,7 @@ f(x)
### Setting cache directory
The compilation cache is enabled when the
-[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
+[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
@@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
-(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
+(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python
from jax.experimental.compilation_cache import compilation_cache as cc
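Putting the pieces above together, a minimal sketch of enabling the persistent compilation cache via option (2) could look like this (the `/tmp/jax_cache` path and the toy jitted function are placeholders):

```python
import jax
import jax.numpy as jnp

# Must run before the first compilation; the path is only an example.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

@jax.jit
def f(x):
  return x @ x.T

f(jnp.ones((256, 256)))   # first call compiles and populates the cache
```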
diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py
index 3a7855763..7cce8b882 100644
--- a/docs/sphinxext/jax_extensions.py
+++ b/docs/sphinxext/jax_extensions.py
@@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None,
:jax-issue:`1234`
This will output a hyperlink of the form
- `#1234 <https://github.com/google/jax/issues/1234>`_. These links work even
+ `#1234 <https://github.com/jax-ml/jax/issues/1234>`_. These links work even
for PR numbers.
"""
text = text.lstrip('#')
if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
options = {} if options is None else options
- url = f"https://github.com/google/jax/issues/{text}"
+ url = f"https://github.com/jax-ml/jax/issues/{text}"
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], []
diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md
index 4eb6e7a66..2ff82e043 100644
--- a/docs/stateful-computations.md
+++ b/docs/stateful-computations.md
@@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b
2) Are we supposed to pipe all these things around manually?
-The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples.
+The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples.
diff --git a/jax/__init__.py b/jax/__init__.py
index e2e302adb..c6e073699 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -29,7 +29,7 @@ except Exception as exc:
# Defensively swallow any exceptions to avoid making jax unimportable
from warnings import warn as _warn
_warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
- f"an issue at https://github.com/google/jax/issues")
+ f"an issue at https://github.com/jax-ml/jax/issues")
del _warn
del _cloud_tpu_init
@@ -38,7 +38,7 @@ import jax.core as _core
del _core
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.basearray import Array as Array
from jax import tree as tree
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 8c7fe2f48..39df07359 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -546,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
# To avoid precision mismatches in fwd and bwd passes due to XLA excess
# precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls
- # on producers of any residuals. See https://github.com/google/jax/pull/22244.
+ # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244.
jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res)
# compute known outputs and residuals (hoisted out of remat primitive)
diff --git a/jax/_src/api.py b/jax/_src/api.py
index aae99a28b..bd8a95195 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -956,7 +956,7 @@ def vmap(fun: F,
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
- # rather than raising an error. https://github.com/google/jax/issues/2367
+ # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_axes = tuple(in_axes)
if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}):
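The canonicalization noted in the comment above means a list `in_axes` behaves like a tuple; a small sketch (with an arbitrary toy function) of the user-facing behavior:

```python
import jax
import jax.numpy as jnp

f = lambda x, y: x + y
xs = jnp.arange(6.0).reshape(3, 2)
y = jnp.ones(2)

out_tuple = jax.vmap(f, in_axes=(0, None))(xs, y)
out_list = jax.vmap(f, in_axes=[0, None])(xs, y)    # a list is accepted too
print(jnp.allclose(out_tuple, out_list))            # True
```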
@@ -2505,7 +2505,7 @@ class ShapeDtypeStruct:
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
- # https://github.com/google/jax/issues/8182
+ # https://github.com/jax-ml/jax/issues/8182
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
def _sds_aval_mapping(x):
diff --git a/jax/_src/callback.py b/jax/_src/callback.py
index 453a4eba4..3a18dcdfa 100644
--- a/jax/_src/callback.py
+++ b/jax/_src/callback.py
@@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
device_assignment = axis_context.device_assignment
if device_assignment is None:
raise AssertionError(
- "Please file a bug at https://github.com/google/jax/issues")
+ "Please file a bug at https://github.com/jax-ml/jax/issues")
try:
device_index = device_assignment.index(device)
except IndexError as e:
diff --git a/jax/_src/config.py b/jax/_src/config.py
index fe56ec68f..b21d2f35f 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1170,7 +1170,7 @@ softmax_custom_jvp = bool_state(
upgrade=True,
help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should '
'improve memory usage and stability. Set True to use new '
- 'behavior. See https://github.com/google/jax/pull/15677'),
+ 'behavior. See https://github.com/jax-ml/jax/pull/15677'),
update_global_hook=lambda val: _update_global_jit_state(
softmax_custom_jvp=val),
update_thread_local_hook=lambda val: update_thread_local_jit_state(
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 057a79925..bff59625b 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -935,7 +935,7 @@ aval_method = namedtuple("aval_method", ["fun"])
class EvalTrace(Trace):
- # See comments in https://github.com/google/jax/pull/3370
+ # See comments in https://github.com/jax-ml/jax/pull/3370
def pure(self, x): return x
lift = sublift = pure
@@ -998,7 +998,7 @@ class MainTrace:
return self.trace_type(self, cur_sublevel(), **self.payload)
class TraceStack:
- # See comments in https://github.com/google/jax/pull/3370
+ # See comments in https://github.com/jax-ml/jax/pull/3370
stack: list[MainTrace]
dynamic: MainTrace
@@ -1167,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str:
# parent->child jump. We do that by setting `parent` here to be a
# grandparent (or great-grandparent) of `child`, and then handling that case
# in _why_alive_container_info. See example:
- # https://github.com/google/jax/pull/13022#discussion_r1008456599
+ # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
# To prevent this collapsing behavior, just comment out this code block.
if (isinstance(parent, dict) and
getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
@@ -1213,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager
def new_main(trace_type: type[Trace], dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
- # See comments in https://github.com/google/jax/pull/3370
+ # See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
main = MainTrace(level, trace_type, **payload)
@@ -1254,7 +1254,7 @@ def dynamic_level() -> int:
@contextmanager
def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]:
- # See comments in https://github.com/google/jax/pull/3370
+ # See comments in https://github.com/jax-ml/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type, **payload)
prev_dynamic, stack.dynamic = stack.dynamic, main
@@ -1319,7 +1319,7 @@ def ensure_compile_time_eval():
else:
return jnp.cos(x)
- Here's a real-world example from https://github.com/google/jax/issues/3974::
+ Here's a real-world example from https://github.com/jax-ml/jax/issues/3974::
import jax
import jax.numpy as jnp
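A compact sketch of the `jax.ensure_compile_time_eval` pattern that the docstring above is describing (the particular branch condition is made up for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  with jax.ensure_compile_time_eval():
    y = jnp.sin(3.0)                 # evaluated eagerly, at trace time
    is_positive = bool(y > 0)        # a concrete Python bool, so this works
  return jnp.cos(x) if is_positive else jnp.sin(x)

print(f(1.0))
```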
@@ -1680,7 +1680,7 @@ class UnshapedArray(AbstractValue):
@property
def shape(self):
msg = ("UnshapedArray has no shape. Please open an issue at "
- "https://github.com/google/jax/issues because it's unexpected for "
+ "https://github.com/jax-ml/jax/issues because it's unexpected for "
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)
diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index 07df2321b..35e7d3343 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -191,7 +191,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
# TODO(frostig): assert these also equal:
# treedef_tuple((in_tree, in_tree))
- # once https://github.com/google/jax/issues/9066 is fixed
+ # once https://github.com/jax-ml/jax/issues/9066 is fixed
assert tree_ps_ts == tree_ps_ts2
del tree_ps_ts2
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 88be655a0..f5ecdfcda 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -1144,7 +1144,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
- # See https://github.com/google/jax/issues/6415 for motivation.
+ # See https://github.com/jax-ml/jax/issues/6415 for motivation.
x = core.full_lower(x)
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index 8f48746dd..984d55fe2 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
devices = axis_context.device_assignment
if devices is None:
raise AssertionError(
- 'Please file a bug at https://github.com/google/jax/issues')
+ 'Please file a bug at https://github.com/jax-ml/jax/issues')
if axis_context.mesh_shape is not None:
ma, ms = list(zip(*axis_context.mesh_shape))
mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 337349694..465dc90e2 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -389,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
devices = axis_context.device_assignment
if devices is None:
raise AssertionError(
- 'Please file a bug at https://github.com/google/jax/issues')
+ 'Please file a bug at https://github.com/jax-ml/jax/issues')
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = axis_context.mesh._flat_devices_tuple
else:
diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py
index d76b80ad3..352a3e550 100644
--- a/jax/_src/dtypes.py
+++ b/jax/_src/dtypes.py
@@ -793,7 +793,7 @@ def check_user_dtype_supported(dtype, fun_name=None):
"and will be truncated to dtype {}. To enable more dtypes, set the "
"jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
"environment variable. "
- "See https://github.com/google/jax#current-gotchas for more.")
+ "See https://github.com/jax-ml/jax#current-gotchas for more.")
fun_name = f"requested in {fun_name}" if fun_name else ""
truncated_dtype = canonicalize_dtype(np_dtype).name
warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3)
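The warning above points users at the x64 switch; as a quick illustration (assuming the standard `jax_enable_x64` config option):

```python
import jax

# Enable 64-bit types before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.array(1.0).dtype)   # float64 once x64 is enabled
```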
diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py
index e18ad1f6e..11a9dda66 100644
--- a/jax/_src/flatten_util.py
+++ b/jax/_src/flatten_util.py
@@ -61,7 +61,7 @@ def _ravel_list(lst):
if all(dt == to_dtype for dt in from_dtypes):
# Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
- # See https://github.com/google/jax/issues/7809.
+ # See https://github.com/jax-ml/jax/issues/7809.
del from_dtypes, to_dtype
raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)
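For context on the dtype-polymorphic `unravel` discussed above, a minimal usage sketch of `jax.flatten_util.ravel_pytree` with same-dtype leaves (the toy parameter dict is arbitrary):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
flat, unravel = ravel_pytree(params)
print(flat.shape)                 # (9,)
print(unravel(flat)['w'].shape)   # (2, 3)
```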
diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py
index 4bf6d1ceb..2c9490756 100644
--- a/jax/_src/internal_test_util/test_harnesses.py
+++ b/jax/_src/internal_test_util/test_harnesses.py
@@ -30,7 +30,7 @@ Instead of writing this information as conditions inside one
particular test, we write them as `Limitation` objects that can be reused in
multiple tests and can also be used to generate documentation, e.g.,
the report of [unsupported and partially-implemented JAX
-primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
+primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md)
The limitations are used to filter out from tests the harnesses that are known
to fail. A Limitation is specific to a harness.
@@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name,
for old_dtype in jtu.dtypes.all:
# TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating
# point numbers and new_dtype is an unsigned integer. See issue
- # https://github.com/google/jax/issues/5082 for details.
+ # https://github.com/jax-ml/jax/issues/5082 for details.
for new_dtype in (jtu.dtypes.all
if not (dtypes.issubdtype(old_dtype, np.floating) or
dtypes.issubdtype(old_dtype, np.complexfloating))
@@ -2336,7 +2336,7 @@ _make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p)
# Validate padding
for padding in [
# TODO(bchetioui): commented out the test based on
- # https://github.com/google/jax/issues/4690
+ # https://github.com/jax-ml/jax/issues/4690
# ((1, 2), (2, 3), (3, 4)) # non-zero padding
((1, 1), (1, 1), (1, 1)) # non-zero padding
]:
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 6bc6539b9..6bc3cceb7 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1262,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult:
msg = (f'custom-policy remat rule not implemented for {name}, '
- 'open a feature request at https://github.com/google/jax/issues!')
+ 'open a feature request at https://github.com/jax-ml/jax/issues!')
raise NotImplementedError(msg)
@@ -2688,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
# TODO(mattjj): the following are deprecated; update callers to _nounits version
-# See https://github.com/google/jax/pull/9498
+# See https://github.com/jax-ml/jax/pull/9498
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
pvals: Sequence[PartialVal]):
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 944e20fa7..de668090e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -500,7 +500,7 @@ class MapTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
- "Please open an issue at https://github.com/google/jax/issues !")
+ "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros # always base main, can drop jvp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@@ -513,7 +513,7 @@ class MapTrace(core.Trace):
out_trees, symbolic_zeros):
if symbolic_zeros:
msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
- "Please open an issue at https://github.com/google/jax/issues !")
+ "Please open an issue at https://github.com/jax-ml/jax/issues !")
raise NotImplementedError(msg)
del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
@@ -1869,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"does not preserve sharded data representations and instead collects "
"input and output arrays onto a single device. "
"Consider removing the outer jit unless you know what you're doing. "
- "See https://github.com/google/jax/issues/2926.")
+ "See https://github.com/jax-ml/jax/issues/2926.")
if nreps > xb.device_count(backend):
raise ValueError(
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 4cb38d28c..d3065d0f9 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -389,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args,
branch_outs = []
for i, jaxpr in enumerate(branches_batched):
# Perform a select on the inputs for safety of reverse-mode autodiff; see
- # https://github.com/google/jax/issues/1052
+ # https://github.com/jax-ml/jax/issues/1052
predicate = lax.eq(index, lax._const(index, i))
ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 41d809f8d..7a9596bf2 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -715,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts,
if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
raise NotImplementedError
if not all(init_lin):
- pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963
+ pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963
consts, _, xs = split_list(args, [num_consts, num_carry])
ires, _ = split_list(consts, [num_ires])
@@ -1169,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
if discharged_consts:
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
"please open an issue at "
- "https://github.com/google/jax/issues")
+ "https://github.com/jax-ml/jax/issues")
def wrapped(*wrapped_args):
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
@@ -1838,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
if body_jaxpr_consts:
raise NotImplementedError("Body jaxpr has consts. If you see this error, "
"please open an issue at "
- "https://github.com/google/jax/issues")
+ "https://github.com/jax-ml/jax/issues")
# body_jaxpr has the signature (*body_consts, *carry) -> carry.
# Some of these body_consts are actually `Ref`s so when we discharge
# them, they also turn into outputs, effectively turning those consts into
diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py
index 0cbee6d2b..36553e512 100644
--- a/jax/_src/lax/fft.py
+++ b/jax/_src/lax/fft.py
@@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths):
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients
- # https://github.com/google/jax/issues/6223#issuecomment-807740707
+ # https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707
return lax.conj(out)
def _fft_transpose_rule(t, operand, fft_type, fft_lengths):
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 48af9c64f..394a54c35 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1077,7 +1077,7 @@ def _reduction_jaxpr(computation, aval):
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
- "at https://github.com/google/jax.")
+ "at https://github.com/jax-ml/jax.")
return jaxpr, tuple(consts)
@cache()
@@ -1090,7 +1090,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
- "at https://github.com/google/jax.")
+ "at https://github.com/jax-ml/jax.")
return core.ClosedJaxpr(jaxpr, consts), out_tree()
def _get_monoid_reducer(monoid_op: Callable,
@@ -4911,7 +4911,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
return tree_util.tree_unflatten(p.out_tree(), out_flat)
-# TODO(https://github.com/google/jax/issues/13552): Look into making this a
+# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a
# method on jax.Array so that we can bypass the XLA compilation here.
def _copy_impl(prim, *args, **kwargs):
a, = args
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index ec0a075da..453e79a5c 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -781,7 +781,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
raise NotImplementedError(
'The derivatives of eigenvectors are not implemented, only '
'eigenvalues. See '
- 'https://github.com/google/jax/issues/2748 for discussion.')
+ 'https://github.com/jax-ml/jax/issues/2748 for discussion.')
# Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
# https://arxiv.org/abs/1701.00392
a, = primals
diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py
index b2bcc53a5..e8fcb4334 100644
--- a/jax/_src/lib/__init__.py
+++ b/jax/_src/lib/__init__.py
@@ -27,7 +27,7 @@ try:
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See '
- 'https://github.com/google/jax#installation for installation instructions.'
+ 'https://github.com/jax-ml/jax#installation for installation instructions.'
) from err
import jax.version
@@ -92,7 +92,7 @@ pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
-# XLA garbage collection: see https://github.com/google/jax/issues/14882
+# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882
def _xla_gc_callback(*args):
xla_client._xla.collect_garbage()
gc.callbacks.append(_xla_gc_callback)
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 20234b678..08c8bfcb3 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -333,7 +333,7 @@ class AbstractMesh:
should use this as an input to the sharding passed to with_sharding_constraint
and mesh passed to shard_map to avoid tracing and lowering cache misses when
your mesh shape and names stay the same but the devices change.
- See the description of https://github.com/google/jax/pull/23022 for more
+ See the description of https://github.com/jax-ml/jax/pull/23022 for more
details.
"""
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index a1601e920..5b9362685 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3953,7 +3953,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
- # (https://github.com/google/jax/issues/653).
+ # (https://github.com/jax-ml/jax/issues/653).
k = 16
while len(arrays_out) > 1:
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
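For reference, the op-by-op wide concatenation that the workaround above targets looks like this from the user's side (the number and size of chunks are arbitrary):

```python
import jax.numpy as jnp

chunks = [jnp.arange(3) + 3 * i for i in range(64)]   # many small arrays
out = jnp.concatenate(chunks)
print(out.shape)   # (192,)
```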
@@ -4645,7 +4645,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
if all(not isinstance(leaf, Array) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
- # https://github.com/google/jax/pull/6047. More correct would be to call
+ # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
# coerce_to_array on each leaf, but this may have performance implications.
out = np.asarray(object, dtype=dtype)
elif isinstance(object, Array):
@@ -10150,11 +10150,11 @@ def _eliminate_deprecated_list_indexing(idx):
if any(_should_unpack_list_index(i) for i in idx):
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[tuple(seq)]` instead of `arr[seq]`. "
- "See https://github.com/google/jax/issues/4564 for more information.")
+ "See https://github.com/jax-ml/jax/issues/4564 for more information.")
else:
msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
"use `arr[array(seq)]` instead of `arr[seq]`. "
- "See https://github.com/google/jax/issues/4564 for more information.")
+ "See https://github.com/jax-ml/jax/issues/4564 for more information.")
raise TypeError(msg)
else:
idx = (idx,)
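A short sketch of the indexing rule behind the error messages above (the array and index values are arbitrary):

```python
import jax.numpy as jnp

arr = jnp.arange(12).reshape(3, 4)
seq = [1, 2]                      # one index per dimension
print(arr[tuple(seq)])            # OK: equivalent to arr[1, 2]
# arr[seq] (a bare list) raises a TypeError like the one above.
```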
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index b6aea9e19..043c976ef 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -968,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
- "on https://github.com/google/jax/issues/2283 if this behavior is "
+ "on https://github.com/jax-ml/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
computation_dtype = dtype
diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py
index 7e8acb090..6491a7617 100644
--- a/jax/_src/numpy/setops.py
+++ b/jax/_src/numpy/setops.py
@@ -134,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
- The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
+ The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size``
argument:
@@ -217,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
Traceback (most recent call last):
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
- The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
+ The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
In order to ensure statically-known output shapes, you can pass a static ``size``
argument:
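A minimal sketch of the static `size` escape hatch described in the docstrings above, here with `jnp.setdiff1d` under `jit` (the concrete inputs and `fill_value` are arbitrary):

```python
import jax
import jax.numpy as jnp

@jax.jit
def diff(ar1, ar2):
  # size fixes the output shape; fill_value pads when fewer elements remain.
  return jnp.setdiff1d(ar1, ar2, size=3, fill_value=-1)

print(diff(jnp.array([1, 2, 3, 4]), jnp.array([2, 4])))   # [ 1  3 -1]
```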
diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py
index b45b3370f..00b5311b8 100644
--- a/jax/_src/numpy/ufuncs.py
+++ b/jax/_src/numpy/ufuncs.py
@@ -2176,7 +2176,7 @@ def sinc(x: ArrayLike, /) -> Array:
def _sinc_maclaurin(k, x):
# compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
# compute the monomial term in the jvp rule)
- # TODO(mattjj): see https://github.com/google/jax/issues/10750
+ # TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750
if k % 2:
return x * 0
else:
diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py
index 5340eb3fa..f8231f5b2 100644
--- a/jax/_src/pallas/mosaic/error_handling.py
+++ b/jax/_src/pallas/mosaic/error_handling.py
@@ -35,7 +35,7 @@ FRAME_PATTERN = re.compile(
)
MLIR_ERR_PREFIX = (
'Pallas encountered an internal verification error.'
- 'Please file a bug at https://github.com/google/jax/issues. '
+ 'Please file a bug at https://github.com/jax-ml/jax/issues. '
'Error details: '
)
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index f120bbabf..d4dc534d0 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -833,7 +833,7 @@ def jaxpr_subcomp(
raise NotImplementedError(
"Unimplemented primitive in Pallas TPU lowering: "
f"{eqn.primitive.name}. "
- "Please file an issue on https://github.com/google/jax/issues.")
+ "Please file an issue on https://github.com/jax-ml/jax/issues.")
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, ans)
else:
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 5eaf6e523..4b8199b36 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -549,7 +549,7 @@ def lower_jaxpr_to_mosaic_gpu(
raise NotImplementedError(
"Unimplemented primitive in Pallas Mosaic GPU lowering: "
f"{eqn.primitive.name}. "
- "Please file an issue on https://github.com/google/jax/issues."
+ "Please file an issue on https://github.com/jax-ml/jax/issues."
)
rule = mosaic_lowering_rules[eqn.primitive]
rule_ctx = LoweringRuleContext(
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index ac28bd21a..6a1615627 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -381,7 +381,7 @@ def lower_jaxpr_to_triton_ir(
raise NotImplementedError(
"Unimplemented primitive in Pallas GPU lowering: "
f"{eqn.primitive.name}. "
- "Please file an issue on https://github.com/google/jax/issues.")
+ "Please file an issue on https://github.com/jax-ml/jax/issues.")
rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0abaa3fd0..ac1318ed7 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -459,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
- # rather than raising an error. https://github.com/google/jax/issues/2367
+ # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
in_shardings = tuple(in_shardings)
in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
@@ -1276,7 +1276,7 @@ def explain_tracing_cache_miss(
return done()
# we think this is unreachable...
- p("explanation unavailable! please open an issue at https://github.com/google/jax")
+ p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
return done()
@partial(lu.cache, explain=explain_tracing_cache_miss)
@@ -1701,7 +1701,7 @@ def _pjit_call_impl_python(
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
- "https://github.com/google/jax.")
+ "https://github.com/jax-ml/jax.")
raise FloatingPointError(msg)
diff --git a/jax/_src/random.py b/jax/_src/random.py
index d889713f6..203f72d40 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -508,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
span = lax.convert_element_type(maxval - minval, unsigned_dtype)
# Ensure that span=1 when maxval <= minval, so minval is always returned;
- # https://github.com/google/jax/issues/222
+ # https://github.com/jax-ml/jax/issues/222
span = lax.select(maxval <= minval, lax.full_like(span, 1), span)
# When maxval is out of range, the span has to be one larger.
@@ -2540,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array:
_btrs(key, count_btrs, q_btrs, shape, dtype, max_iters),
)
# ensure nan q always leads to nan output and nan or neg count leads to nan
- # as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709
+ # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709
invalid = (q_l_0 | q_is_nan | count_nan_or_neg)
samples = lax.select(
invalid,
diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py
index d81008308..ee144eaf9 100644
--- a/jax/_src/scipy/ndimage.py
+++ b/jax/_src/scipy/ndimage.py
@@ -176,7 +176,7 @@ def map_coordinates(
Note:
Interpolation near boundaries differs from the scipy function, because JAX
- fixed an outstanding bug; see https://github.com/google/jax/issues/11097.
+ fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
This function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy.
"""
diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py
index 2361eaf64..574d725c4 100644
--- a/jax/_src/shard_alike.py
+++ b/jax/_src/shard_alike.py
@@ -44,7 +44,7 @@ def shard_alike(x, y):
raise ValueError(
'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:'
f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at'
- ' https://github.com/google/jax/issues if you want this feature.')
+ ' https://github.com/jax-ml/jax/issues if you want this feature.')
outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)]
x_out_flat, y_out_flat = zip(*outs)
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 5afcd5e3a..81737f275 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1208,7 +1208,7 @@ class JaxTestCase(parameterized.TestCase):
y = np.asarray(y)
if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
- # See https://github.com/google/jax/issues/17867
+ # See https://github.com/jax-ml/jax/issues/17867
raise TypeError(
"assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. "
"If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. "
diff --git a/jax/_src/typing.py b/jax/_src/typing.py
index 0caa6e7c6..010841b45 100644
--- a/jax/_src/typing.py
+++ b/jax/_src/typing.py
@@ -21,7 +21,7 @@ exported at `jax.typing`. Until then, the contents here should be considered uns
and may change without notice.
To see the proposal that led to the development of these tools, see
-https://github.com/google/jax/pull/11859/.
+https://github.com/jax-ml/jax/pull/11859/.
"""
from __future__ import annotations
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 1d3c50403..796093b62 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -1232,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs):
if library_path is None:
raise RuntimeError(
"JAX TPU support not installed; cannot generate TPU topology. See"
- " https://github.com/google/jax#installation")
+ " https://github.com/jax-ml/jax#installation")
c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
xla_client.profiler.register_plugin_profiler(c_api)
assert xla_client.pjrt_plugin_loaded("tpu")
diff --git a/jax/core.py b/jax/core.py
index 9857fcf88..cdf8d7655 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.core import (
AbstractToken as AbstractToken,
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index 8e517f5d4..ea1ef4f02 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.custom_derivatives import (
_initial_style_jaxpr,
diff --git a/jax/dtypes.py b/jax/dtypes.py
index f2071fd4f..a6f1b7645 100644
--- a/jax/dtypes.py
+++ b/jax/dtypes.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.dtypes import (
bfloat16 as bfloat16,
diff --git a/jax/errors.py b/jax/errors.py
index 15a6654fa..2a811661d 100644
--- a/jax/errors.py
+++ b/jax/errors.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.errors import (
JAXTypeError as JAXTypeError,
diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py
index caf27ec7a..1b22c2c2a 100644
--- a/jax/experimental/__init__.py
+++ b/jax/experimental/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.experimental.x64_context import (
enable_x64 as enable_x64,
diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py
index 69b25d0b6..3ac1d4246 100644
--- a/jax/experimental/array_api/__init__.py
+++ b/jax/experimental/array_api/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
import sys as _sys
import warnings as _warnings
diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py
index 0b6b51f71..8e11d4173 100644
--- a/jax/experimental/checkify.py
+++ b/jax/experimental/checkify.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.checkify import (
Error as Error,
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index 3c7bfac40..6da3ad7c5 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.custom_partitioning import (
custom_partitioning as custom_partitioning,
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 63c3299c5..49162809a 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -17,7 +17,7 @@
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
This module introduces the host callback functions :func:`call`,
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@@ -363,11 +363,11 @@ using the :func:`jax.custom_vjp` mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP
and the TensorFlow autodiff mechanisms.
The code for how this can be done is shown in the ``call_tf_full_ad``
-function in `host_callback_to_tf_test.py `_.
+function in `host_callback_to_tf_test.py `_.
This example supports arbitrary higher-order differentiation as well.
Note that if you just want to call TensorFlow functions from JAX, you can also
-use the `jax2tf.call_tf function <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
+use the `jax2tf.call_tf function <https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support
------------------------------------------------------------------------------------------------
@@ -378,7 +378,7 @@ the host, and then to the outside device on which the JAX host
computation will run, and then the results are sent back to the original accelerator.
The code for how this can be done is shown in the ``call_jax_other_device`` function
-in `host_callback_test.py `_.
+in `host_callback_test.py `_.
Low-level details and debugging
-------------------------------
@@ -572,7 +572,7 @@ _HOST_CALLBACK_LEGACY = config.bool_flag(
help=(
'Use old implementation of host_callback, documented in the module docstring.'
'If False, use the jax.experimental.io_callback implementation. '
- 'See https://github.com/google/jax/issues/20385.'
+ 'See https://github.com/jax-ml/jax/issues/20385.'
)
)
@@ -592,7 +592,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
"See https://jax.readthedocs.io/en/latest/debugging/index.html and "
"https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html"
" for alternatives. Please file a feature request at "
- "https://github.com/google/jax/issues if none of the alternatives are "
+ "https://github.com/jax-ml/jax/issues if none of the alternatives are "
"sufficient.")
@@ -608,7 +608,7 @@ DType = Any
class CallbackFlavor(enum.Enum):
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
"""
IO_CALLBACK = 1 # uses jax.experimental.io_callback
PURE = 2 # uses jax.pure_callback
@@ -629,7 +629,7 @@ def _deprecated_id_tap(tap_func,
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
``id_tap`` behaves semantically like the identity function but has the
side-effect that a user-defined Python function is called with the runtime
@@ -655,7 +655,7 @@ def _deprecated_id_tap(tap_func,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
Returns:
``arg``, or ``result`` if given.
@@ -712,7 +712,7 @@ def _deprecated_id_print(arg,
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
On each invocation of the printing tap, the ``kwargs`` if present
will be printed first (sorted by keys). Then arg will be printed,
@@ -730,7 +730,7 @@ def _deprecated_id_print(arg,
* ``threshold`` is passed to ``numpy.array2string``.
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
For more details see the :mod:`jax.experimental.host_callback` module documentation.
"""
@@ -757,7 +757,7 @@ def _deprecated_call(callback_func: Callable, arg, *,
The host_callback APIs are deprecated as of March 20, 2024.
The functionality is subsumed by the
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
Args:
callback_func: The Python function to invoke on the host as
@@ -787,7 +787,7 @@ def _deprecated_call(callback_func: Callable, arg, *,
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
the flavor of callback to use.
- See https://github.com/google/jax/issues/20385.
+ See https://github.com/jax-ml/jax/issues/20385.
Returns:
the result of the ``callback_func`` invocation.
@@ -800,7 +800,7 @@ def _deprecated_call(callback_func: Callable, arg, *,
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
"flavor of callback only when the `result_shape` is None. "
- "See https://github.com/google/jax/issues/20385."
+ "See https://github.com/jax-ml/jax/issues/20385."
)
return _call(callback_func, arg, result_shape=result_shape,
call_with_device=call_with_device, identity=False,
@@ -819,7 +819,7 @@ class _CallbackWrapper:
raise NotImplementedError(
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
" do not support `tap_with_device` and `call_with_device`. "
- "See https://github.com/google/jax/issues/20385.")
+ "See https://github.com/jax-ml/jax/issues/20385.")
def __hash__(self):
return hash((self.callback_func, self.identity, self.call_with_device))
@@ -2121,7 +2121,7 @@ def _deprecated_stop_outfeed_receiver():
_deprecation_msg = (
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
"is subsumed by the new JAX external callbacks. "
- "See https://github.com/google/jax/issues/20385.")
+ "See https://github.com/jax-ml/jax/issues/20385.")
_deprecations = {
# Added March 20, 2024
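
The host_callback hunks above repeatedly point users at the external-callback APIs tracked in issue #20385 (`jax.experimental.io_callback`, `jax.pure_callback`, `jax.debug.callback`). As a minimal, illustrative sketch of that migration path (not taken from the repository; the host-side function, shapes, and values are invented for the example):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_log(x):
  # Runs on the host; io_callback hands the argument over as a NumPy array.
  print("on host:", np.asarray(x))
  return x


@jax.jit
def f(x):
  y = jnp.sin(x)
  # io_callback replaces host_callback.call/id_print for side-effecting host
  # calls; the second argument describes the shape/dtype of the returned value.
  return io_callback(host_log, jax.ShapeDtypeStruct(y.shape, y.dtype), y)


f(jnp.arange(3.0))
```
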
diff --git a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb
index 4f23d88e0..3613dba0e 100644
--- a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb
+++ b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb
@@ -26,7 +26,7 @@
"Link: go/jax2tf-colab\n",
"\n",
"The JAX2TF colab has been deprecated, and the example code has\n",
- "been moved to [jax2tf/examples](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples). \n"
+ "been moved to [jax2tf/examples](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples). \n"
]
}
]
diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md
index dbdc4f563..b77474c03 100644
--- a/jax/experimental/jax2tf/README.md
+++ b/jax/experimental/jax2tf/README.md
@@ -103,10 +103,10 @@ For more involved examples, please see examples involving:
* SavedModel for archival ([examples below](#usage-saved-model)), including
saving [batch-polymorphic functions](#shape-polymorphic-conversion),
- * TensorFlow Lite ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)),
- * TensorFlow.js ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)),
+ * TensorFlow Lite ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)),
+ * TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)),
* TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)),
- * TensorFlow Hub and Keras ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)).
+ * TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)).
[TOC]
@@ -249,7 +249,7 @@ graph (they will be saved in a `variables` area of the model, which is not
subject to the 2GB limitation).
For examples of how to save a Flax model as a SavedModel see the
-[examples directory](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md).
+[examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md).
### Saved model and differentiation
@@ -619,7 +619,7 @@ Cannot solve for values of dimension variables {'a', 'b'}. "
We can only solve linear uni-variate constraints. "
Using the following polymorphic shapes specifications: args[0].shape = (a + b,).
Unprocessed specifications: 'a + b' for dimension size args[0].shape[0]. "
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.
```
### Shape assertion errors
@@ -645,7 +645,7 @@ Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.
```
When using native serialization these are checked by the `tf.XlaCallModule`
@@ -869,7 +869,7 @@ leads to errors for the following expressions `b == a or b == b` or `b in [a, b]
even though the error is avoided if we change the order of the comparisons.
We attempted to retain soundness and hashability by creating both hashable and unhashable
-kinds of symbolic dimensions [PR #14200](https://github.com/google/jax/pull/14200),
+kinds of symbolic dimensions [PR #14200](https://github.com/jax-ml/jax/pull/14200),
but it turned out to be very hard to diagnose hashing failures in user programs because
often hashing is implicit when using sets or memo tables.
@@ -989,7 +989,7 @@ We list here a history of the serialization version numbers:
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
- See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
+ See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522),
available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
and the default since August 12th, 2023 (JAX 0.4.15).
@@ -1164,7 +1164,7 @@ self.assertAllClose(grad_jax.b, grad_tf[1])
Applies to both native and non-native serialization.
When JAX differentiates functions with integer or boolean arguments, the gradients will
-be zero-vectors with a special `float0` type (see PR 4039](https://github.com/google/jax/pull/4039)).
+be zero-vectors with a special `float0` type (see [PR 4039](https://github.com/jax-ml/jax/pull/4039)).
This type is translated to `int32` when lowering to TF.
For example,
@@ -1441,7 +1441,7 @@ Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based
on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)
operation, which has an efficient implementation for the cases when the
reduction function is associative. For CPU and GPU, JAX uses an alternative
-lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
+lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
jax2tf uses the TPU lowering (because it does not support backend-specific lowering)
and hence it can be slow in some cases on CPU and GPU.
@@ -1502,7 +1502,7 @@ before conversion. (This is a hypothesis, we have not yet verified it extensivel
There is one known case when the performance of the lowered code will be different.
JAX programs use a [stateless
-deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md)
+deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md)
and it has an internal JAX primitive for it.
This primitive is at the moment lowered to a soup of tf.bitwise operations,
which has a clear performance penalty. We plan to look into using the
@@ -1589,7 +1589,7 @@ Applies to non-native serialization only.
There are a number of cases when the TensorFlow ops that are used by the
`jax2tf` are not supported by TensorFlow for the same data types as in JAX.
There is an
-[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
+[up-to-date list of unimplemented cases](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
If you try to lower and run in TensorFlow a program with partially supported primitives,
you may see TensorFlow errors that
@@ -1626,7 +1626,7 @@ the function to a SavedModel, knowing that upon restore the
jax2tf-lowered code will be compiled.
For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable`
-in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).
+in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).
# Calling TensorFlow functions from JAX
@@ -1704,7 +1704,7 @@ For a more elaborate example, including round-tripping from JAX
to TensorFlow and back through a SavedModel, with support for
custom gradients,
see the test `test_round_trip_custom_grad_saved_model`
-in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).
+in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py).
All the metadata inserted by TF during tracing and compilation, e.g.,
source location information and op names, is carried through to the
@@ -1901,7 +1901,7 @@ As of today, the tests are run using `tf_nightly==2.14.0.dev20230720`.
To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support
for CUDA. One must be mindful to install a version of CUDA that is compatible
-with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and
+with both [jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation) and
[TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations).
## Updating the limitations documentation
@@ -1913,9 +1913,9 @@ JAX primitive, data type, device type, and TensorFlow execution mode (`eager`,
`graph`, or `compiled`). These limitations are also used
to generate tables of limitations, e.g.,
- * [List of primitives not supported in JAX](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md),
+ * [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md),
e.g., due to unimplemented cases in the XLA compiler, and
- * [List of primitives not supported in jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md),
+ * [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md),
e.g., due to unimplemented cases in TensorFlow. This list is incremental
on top of the unsupported JAX primitives.
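
Much of the README text touched above concerns `jax2tf.convert` and its shape-polymorphism errors. For orientation only, a minimal conversion sketch with a symbolic batch dimension, using the documented `polymorphic_shapes` argument (the function and shapes are invented for the example):

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf


def f(x):  # x: f32[b, 8]
  return jnp.sum(x, axis=0)


# A symbolic batch dimension `b` lets the resulting tf.function accept any
# batch size; operations whose result cannot be expressed in terms of `b`
# raise the InconclusiveDimensionOperation errors quoted elsewhere in this diff.
f_tf = tf.function(
    jax2tf.convert(f, polymorphic_shapes=["(b, 8)"]),
    autograph=False,
    input_signature=[tf.TensorSpec([None, 8], tf.float32)])
print(f_tf(tf.ones([3, 8])))
```
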
diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 037f8bbc2..baae52403 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -19,7 +19,7 @@ This module introduces the function :func:`call_tf` that allows JAX to call
TensorFlow functions.
For examples and details, see
-https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
+https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
"""
@@ -93,7 +93,7 @@ def call_tf(
For an example and more details see the
`README
- <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
+ <https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
Args:
callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
@@ -460,7 +460,7 @@ def _call_tf_abstract_eval(
msg = ("call_tf cannot call functions whose output has dynamic shape. "
f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
"Consider using the `output_shape_dtype` argument to call_tf. "
- "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
+ "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
" for a discussion.")
raise ValueError(msg)
@@ -499,7 +499,7 @@ def _call_tf_lowering(
msg = (
"call_tf works best with a TensorFlow function that does not capture "
"variables or tensors from the context. "
- "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
+ "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
logging.warning(msg)
for inp in concrete_function_flat_tf.captured_inputs:
@@ -544,7 +544,7 @@ def _call_tf_lowering(
"\ncall_tf can used " +
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
"compilable functions with static output shapes.\n" +
- "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
+ "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
"\n\nCaught TensorFlow exception: " + str(e))
raise ValueError(msg) from e
@@ -557,7 +557,7 @@ def _call_tf_lowering(
f"{res_shape}. call_tf can used " +
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
"compilable functions with static output shapes. " +
- "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
+ "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
raise ValueError(msg)
res_dtype = res_shape.numpy_dtype()
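
For context on the `call_tf` docstrings edited above, a minimal usage sketch, assuming TensorFlow is installed (the wrapped function and inputs are illustrative, not from the patch):

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf


def tf_sin(x):
  return tf.math.sin(x)


# call_tf wraps a TensorFlow callable so JAX can call it; because the output
# shape here is static, the wrapped function also works under jax.jit.
jax_sin = jax2tf.call_tf(tf_sin)
print(jax.jit(jax_sin)(jnp.arange(3.0)))
```
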
diff --git a/jax/experimental/jax2tf/examples/README.md b/jax/experimental/jax2tf/examples/README.md
index b049798e7..8869a226b 100644
--- a/jax/experimental/jax2tf/examples/README.md
+++ b/jax/experimental/jax2tf/examples/README.md
@@ -4,7 +4,7 @@ jax2tf Examples
Link: go/jax2tf-examples.
This directory contains a number of examples of using the
-[jax2tf converter](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) to:
+[jax2tf converter](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) to:
* save SavedModel from trained MNIST models, using both Flax and pure JAX.
* reuse the feature-extractor part of the trained MNIST model
@@ -19,12 +19,12 @@ You can also find usage examples in other projects:
The functions generated by `jax2tf.convert` are standard TensorFlow functions
and you can save them in a SavedModel using standard TensorFlow code, as shown
-in the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model).
+in the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model).
This decoupling of jax2tf from SavedModel is important, because it **allows
the user to have full control over what metadata is saved in the SavedModel**.
As an example, we provide the function `convert_and_save_model`
-(see [saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).)
+(see [saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).)
For serious uses, you will probably want to copy and expand this
function as needed.
@@ -65,7 +65,7 @@ If you are using Flax, then the recipe to obtain this pair is as follows:
```
You can see in
-[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py)
+[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py)
how this can be done for two
implementations of MNIST, one using pure JAX (`PureJaxMNIST`) and a CNN
one using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly,
@@ -91,7 +91,7 @@ embed all parameters in the graph:
```
(The MNIST Flax examples from
-[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py)
+[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py)
normally has a GraphDef of 150k and a variables section of 3Mb. If we embed the
parameters as constants in the GraphDef as shown above, the variables section
becomes empty and the GraphDef becomes 13Mb. This embedding may allow
@@ -112,7 +112,7 @@ If you are using Haiku, then the recipe is along these lines:
Once you have the model in this form, you can use the
`saved_model_lib.save_model` function from
-[saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py)
+[saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py)
to generate the SavedModel. There is very little in that function that is
specific to jax2tf. The goal of jax2tf is to convert JAX functions
into functions that behave as if they had been written with TensorFlow.
@@ -120,7 +120,7 @@ Therefore, if you are familiar with how to generate SavedModel, you can most
likely just use your own code for this.
The file
-[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
+[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
is an executable that shows how to perform the
following sequence of steps:
@@ -147,9 +147,9 @@ batch sizes: 1, 16, 128. You can see this in the dumped SavedModel.
The SavedModel produced by the example in `saved_model_main.py` already
implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models).
The executable
-[keras_reuse_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py)
+[keras_reuse_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py)
extends
-[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
+[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
with code
to include a jax2tf SavedModel into a larger TensorFlow Keras model.
@@ -174,7 +174,7 @@ In particular, you can select the Flax MNIST model: `--model=mnist_flax`.
It is also possible to use jax2tf-generated SavedModel with TensorFlow serving.
At the moment, the open-source TensorFlow model server is missing XLA support,
but the Google version can be used, as shown in the
-[serving examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md).
+[serving examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md).
# Using jax2tf with TensorFlow Lite and TensorFlow JavaScript
@@ -186,6 +186,6 @@ can pass the `enable_xla=False` parameter to `jax2tf.convert` to direct
`jax2tf` to avoid problematic ops. This will increase the coverage, and in fact
most, but not all, Flax examples can be converted this way.
-Check out the [MNIST TensorFlow Lite](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)
+Check out the [MNIST TensorFlow Lite](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)
and the
-[Quickdraw TensorFlow.js example](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md).
+[Quickdraw TensorFlow.js example](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md).
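
The examples README above walks through saving a jax2tf-converted function as a SavedModel; `saved_model_lib.convert_and_save_model` in the repository is the fuller version. A rough, hedged sketch of that recipe (the model, shapes, and output path are invented for illustration):

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf


def predict(params, x):
  return jnp.dot(x, params)


params = np.ones((4, 2), dtype=np.float32)
predict_tf = jax2tf.convert(predict)

module = tf.Module()
# Keeping parameters as variables stores them in the SavedModel's `variables`
# area instead of embedding them as constants in the graph.
module.params = tf.Variable(params)
module.f = tf.function(
    lambda x: predict_tf(module.params, x),
    autograph=False,
    input_signature=[tf.TensorSpec([16, 4], tf.float32)])  # fixed batch size

tf.saved_model.save(module, "/tmp/jax2tf_example_saved_model")
```
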
diff --git a/jax/experimental/jax2tf/examples/serving/README.md b/jax/experimental/jax2tf/examples/serving/README.md
index 0d8f49e45..299923109 100644
--- a/jax/experimental/jax2tf/examples/serving/README.md
+++ b/jax/experimental/jax2tf/examples/serving/README.md
@@ -2,7 +2,7 @@ Using jax2tf with TensorFlow serving
====================================
This is a supplement to the
-[examples/README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)
+[examples/README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)
with example code and
instructions for using `jax2tf` with the open source TensorFlow model server.
Specific instructions for Google-internal versions of model server are in the `internal` subdirectory.
@@ -15,16 +15,16 @@ SavedModel**.
The only difference in the SavedModel produced with jax2tf is that the
function graphs may contain
-[XLA TF ops](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#caveats)
+[XLA TF ops](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#caveats)
that require enabling CPU/GPU XLA for execution in the model server. This
is achieved using a command-line flag. There are no other differences compared
to using SavedModel produced by TensorFlow.
This serving example uses
-[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
+[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py)
for saving the SavedModel and adds code specific to interacting with the
model server:
-[model_server_request.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py).
+[model_server_request.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py).
0. *Set up JAX and TensorFlow serving*.
@@ -36,7 +36,7 @@ We also need to install TensorFlow for the `jax2tf` feature and the rest of this
We use the `tf_nightly` package to get an up-to-date version.
```shell
- git clone https://github.com/google/jax
+ git clone https://github.com/jax-ml/jax
JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples
pip install -e jax
pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly
diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/README.md b/jax/experimental/jax2tf/examples/tflite/mnist/README.md
index 9c889e647..f39bd9c7e 100644
--- a/jax/experimental/jax2tf/examples/tflite/mnist/README.md
+++ b/jax/experimental/jax2tf/examples/tflite/mnist/README.md
@@ -65,7 +65,7 @@ TensorFlow ops that are only available with the XLA compiler, and which
are not understood (yet) by the TFLite converter to be used below.
-Check out [more details about this limitation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md),
+Check out [more details about this limitation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md),
including to which JAX primitives it applies.
### Convert the trained model to the TF Lite format
diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md b/jax/experimental/jax2tf/g3doc/convert_models_results.md
index 545f1faee..24e2539a3 100644
--- a/jax/experimental/jax2tf/g3doc/convert_models_results.md
+++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md
@@ -48,13 +48,13 @@ details on the different converters.
## `flax/actor_critic_[(_, 4*b, 4*b, _)]`
### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -62,13 +62,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -78,13 +78,13 @@ for more details.
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -94,13 +94,13 @@ for more details.
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -122,13 +122,13 @@ RuntimeError('third_party/tensorflow/lite/kernels/concatenation.cc:159 t->dims->
## `flax/bilstm_[(b, _), (_,)]`
### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -141,7 +141,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -156,7 +156,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -171,7 +171,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -180,13 +180,13 @@ for more details.
## `flax/bilstm_[(_, _), (b,)]`
### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -199,7 +199,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -214,7 +214,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -229,7 +229,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -238,13 +238,13 @@ for more details.
## `flax/cnn_[(_, b, b, _)]`
### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -253,13 +253,13 @@ InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_
Conversion error
InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.
Details: Cannot divide 'b + -2' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
.
@@ -267,7 +267,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -278,13 +278,13 @@ for more details.
Conversion error
InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.
Details: Cannot divide 'b + -2' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
.
@@ -292,7 +292,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -303,13 +303,13 @@ for more details.
Conversion error
InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.
Details: Cannot divide 'b + -2' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
.
@@ -317,7 +317,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -395,13 +395,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 8 but expected 1 for dime
## `flax/resnet50_[(_, 4*b, 4*b, _)]`
### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -409,13 +409,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -425,13 +425,13 @@ for more details.
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -441,13 +441,13 @@ for more details.
```
Conversion error
InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.
-See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
+See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.
This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -613,13 +613,13 @@ IndexError('Cannot use NumPy slice indexing on an array dimension whose size is
## `flax/lm1b_[(b, _)]`
### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -632,7 +632,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -647,7 +647,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -662,7 +662,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -684,13 +684,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 2 but expected 1 for dime
## `flax/wmt_[(b, _), (b, _)]`
### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_xla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_noxla`
```
-InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
+InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n")
```
[Back to top](#summary-table)
@@ -703,7 +703,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -718,7 +718,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -733,7 +733,7 @@ This error arises for arithmetic or comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a polynomial of dimension variables, or a boolean constant (for comparisons).
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
+Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
for more details.
")
```
@@ -798,14 +798,14 @@ This converter simply converts the forward function of a JAX model to a
Tensorflow function with XLA support linked in. This is considered the baseline
converter and has the largest coverage, because we expect nearly all ops to be
convertible. However, please see
-[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)
+[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)
for a list of known problems.
### `jax2tf_noxla`
This converter converts a JAX model to a Tensorflow function without XLA
support. This means the Tensorflow XLA ops aren't used. See
-[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops)
+[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops)
for more details.
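To make the two conversion paths above concrete, here is a minimal sketch (not part of the patch itself) of converting a JAX function with and without XLA TF ops; the `enable_xla` flag is assumed to be available in the jax2tf version in use, and the function is arbitrary.
```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(jnp.cos(x))

# Baseline converter: XLA TF ops may be used (largest coverage).
f_tf_xla = tf.function(jax2tf.convert(f_jax), autograph=False)
# "noxla" mode: avoid XLA TF ops (assumes `enable_xla` is supported here).
f_tf_noxla = tf.function(jax2tf.convert(f_jax, enable_xla=False), autograph=False)

print(f_tf_xla(tf.constant([0.1, 0.2])))
print(f_tf_noxla(tf.constant([0.1, 0.2])))
```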
### `jax2tfjs`
diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template
index b54c57503..54e1d2135 100644
--- a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template
+++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template
@@ -29,14 +29,14 @@ This converter simply converts the forward function of a JAX model to a
Tensorflow function with XLA support linked in. This is considered the baseline
converter and has the largest coverage, because we expect nearly all ops to be
convertible. However, please see
-[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)
+[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)
for a list of known problems.
### `jax2tf_noxla`
This converter converts a JAX model to a Tensorflow function without XLA
support. This means the Tensorflow XLA ops aren't used. See
-[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops)
+[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops)
for more details.
### `jax2tfjs`
diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
index 457dc998a..24a1d62ee 100644
--- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
+++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
@@ -1,6 +1,6 @@
# jax2tf Limitations for `enable_xla=False`
-*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)*
+*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)*
For most JAX primitives there is a natural TF op that fits the needed semantics
(e.g., `jax.lax.abs` is equivalent to `tf.abs`). However, there are a number of
diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
index dabbcca4d..b36b004a9 100644
--- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
+++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
@@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops
(`enable_xla=False`). In this mode, some primitives have additional limitations.
This table only shows errors for cases that are working in JAX (see [separate
-list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
+list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
We do not yet have support for `pmap` (with its collective primitives),
nor for `sharded_jit` (SPMD partitioning).
@@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes:
* `all` = `integer`, `inexact`, `bool`
More detailed information can be found in the
-[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py).
+[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py).
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template
index bf5dc41d8..219802f53 100644
--- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template
+++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template
@@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops
(`enable_xla=False`). In this mode, some primitives have additional limitations.
This table only shows errors for cases that are working in JAX (see [separate
-list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
+list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
We do not yet have support for `pmap` (with its collective primitives),
nor for `sharded_jit` (SPMD partitioning).
@@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes:
* `all` = `integer`, `inexact`, `bool`
More detailed information can be found in the
-[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py).
+[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py).
{{tf_error_table}}
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index 5ecde602c..310cbaab6 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -591,7 +591,7 @@ def _padding_reduce_window(operand, operand_shape, computation_name,
padding_type = pads_to_padtype(operand_shape, window_dimensions,
window_strides, padding)
- # https://github.com/google/jax/issues/11874.
+ # https://github.com/jax-ml/jax/issues/11874.
needs_manual_padding = (
padding_type == "SAME" and computation_name == "add" and
window_dimensions != [1] * len(operand_shape))
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 24dee390f..8a90c491e 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -198,7 +198,7 @@ class _ThreadLocalState(threading.local):
# A cache for the tf.convert_to_tensor for constants. We try to preserve
# sharing for constants, to enable tf.Graph to take advantage of it.
- # See https://github.com/google/jax/issues/7992.
+ # See https://github.com/jax-ml/jax/issues/7992.
self.constant_cache = None # None means that we don't use a cache. We
# may be outside a conversion scope.
@@ -249,7 +249,7 @@ def convert(fun_jax: Callable,
"""Allows calling a JAX function from a TensorFlow program.
See
- [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
+ [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
for more details about usage and common problems.
Args:
@@ -291,12 +291,12 @@ def convert(fun_jax: Callable,
polymorphic_shapes are only supported for positional arguments; shape
polymorphism is not supported for keyword arguments.
- See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
+ See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of
the form `e1 >= e2` or `e1 <= e2`.
- See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
+ See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
TensorFlow AD is supported for the output TensorFlow function, and the
@@ -3536,7 +3536,7 @@ def _shard_value(val: TfVal,
if tf_context.executing_eagerly():
raise ValueError(
"A jit function with sharded arguments or results must be used under a `tf.function` context. "
- "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")
+ "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")
return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
val, use_sharding_op=True)
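As a hedged illustration of the `polymorphic_shapes` argument documented in the hunks above: a single converted function can be traced once with a symbolic batch dimension and reused for several concrete batch sizes. The spec string and `tf.TensorSpec` below are illustrative assumptions, not taken from this patch.
```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):            # works for any leading batch size b
  return jnp.sum(x, axis=0)

f_tf = tf.function(
    jax2tf.convert(f_jax, polymorphic_shapes=["(b, 3)"]),  # b is a dimension variable
    autograph=False,
    input_signature=[tf.TensorSpec([None, 3], tf.float32)])

print(f_tf(tf.ones([5, 3])))   # the same traced function handles
print(f_tf(tf.ones([8, 3])))   # different concrete batch sizes
```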
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index 2760efea8..e10c3fbfd 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -304,7 +304,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
def test_with_var_different_shape(self):
- # See https://github.com/google/jax/issues/6050
+ # See https://github.com/jax-ml/jax/issues/6050
v = tf.Variable((4., 2.), dtype=tf.float32)
def tf_func(x):
@@ -428,7 +428,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(g_jax, g_tf)
def test_grad_int_argument(self):
- # Similar to https://github.com/google/jax/issues/6975
+ # Similar to https://github.com/jax-ml/jax/issues/6975
# state is a pytree that contains an integer and a boolean.
# The function returns an integer and a boolean.
def f(param, state, x):
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index 896d0436e..c3b9e96dc 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -1066,9 +1066,9 @@ class Jax2TfLimitation(test_harnesses.Limitation):
@classmethod
def qr(cls, harness: test_harnesses.Harness):
- # See https://github.com/google/jax/pull/3775#issuecomment-659407824;
+ # See https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824;
# # jit_compile=True breaks for complex types.
- # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
+ # TODO: see https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824.
# - for now, the performance of the HLO QR implementation called when
# compiling with TF is expected to have worse performance than the
# custom calls made in JAX.
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 64d461fe9..ef7a5ee2c 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -595,7 +595,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@jtu.sample_product(with_function=[False, True])
def test_gradients_int_argument(self, with_function=False):
- # https://github.com/google/jax/issues/6975
+ # https://github.com/jax-ml/jax/issues/6975
# Also issue #6975.
# An expanded version of test_gradients_unused_argument
state = dict(
@@ -969,7 +969,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def)
def test_bfloat16_constant(self):
- # Re: https://github.com/google/jax/issues/3942
+ # Re: https://github.com/jax-ml/jax/issues/3942
def jax_fn_scalar(x):
x = x.astype(jnp.bfloat16)
x *= 2.
@@ -990,7 +990,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants(self):
# Check that the constants are shared properly in converted functions
- # See https://github.com/google/jax/issues/7992.
+ # See https://github.com/jax-ml/jax/issues/7992.
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
@@ -1002,7 +1002,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants_under_cond(self):
# Check that the constants are shared properly in converted functions
- # See https://github.com/google/jax/issues/7992.
+ # See https://github.com/jax-ml/jax/issues/7992.
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
@@ -1018,7 +1018,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.assertLen(f2_consts, len(f1_consts))
def test_shared_constants_under_scan(self):
- # See https://github.com/google/jax/issues/7992.
+ # See https://github.com/jax-ml/jax/issues/7992.
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
@@ -1092,7 +1092,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@jtu.sample_product(with_function=[False, True])
def test_kwargs(self, with_function=False):
- # Re: https://github.com/google/jax/issues/6791
+ # Re: https://github.com/jax-ml/jax/issues/6791
def f_jax(*, x):
return jnp.sum(x)
f_tf = jax2tf.convert(f_jax)
@@ -1104,7 +1104,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
@jtu.sample_product(with_function=[False, True])
def test_grad_kwargs(self, with_function=False):
- # Re: https://github.com/google/jax/issues/6791
+ # Re: https://github.com/jax-ml/jax/issues/6791
x = (np.zeros(3, dtype=np.float32),
np.zeros(4, dtype=np.float32))
def f_jax(*, x=(1., 2.)):
diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py
index 485fa6e58..78c24b7ea 100644
--- a/jax/experimental/jax2tf/tests/primitives_test.py
+++ b/jax/experimental/jax2tf/tests/primitives_test.py
@@ -30,7 +30,7 @@ in Tensorflow errors (for some devices and compilation modes). These limitations
are captured as jax2tf_limitations.Jax2TfLimitation objects.
From the limitations objects, we generate a
-[report](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
+[report](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md).
The report has instructions for how to re-generate it.
If a harness run fails with error, and a limitation that matches the device
diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py
index 8b71de7db..bc19915d1 100644
--- a/jax/experimental/jax2tf/tests/savedmodel_test.py
+++ b/jax/experimental/jax2tf/tests/savedmodel_test.py
@@ -175,7 +175,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase):
def test_save_grad_integers(self):
- # https://github.com/google/jax/issues/7123
+ # https://github.com/jax-ml/jax/issues/7123
# In the end this is a test that does not involve JAX at all
batch_size = 5
state = np.array([1], dtype=np.int32) # Works if float32
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 2475a062f..a9ee17762 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -933,7 +933,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
kwargs=[dict(with_function=v) for v in [True, False]]
)
def test_grad_int(self, with_function=False):
- # https://github.com/google/jax/issues/7093
+ # https://github.com/jax-ml/jax/issues/7093
# Also issue #6975.
x_shape = (2, 3, 4)
xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape)
@@ -2172,7 +2172,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
(2, x.shape[0]), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]),
- # https://github.com/google/jax/issues/11804
+ # https://github.com/jax-ml/jax/issues/11804
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
# (See test "conv_general_dilated.1d_1" above for more details.)
PolyHarness("reduce_window", "add_monoid_strides_window_size=static",
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index d6349b487..9009c1586 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -437,7 +437,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
def test_grad_sharding_different_mesh(self):
# Convert with two similar meshes, the only difference being
# the order of the devices. grad should not fail.
- # https://github.com/google/jax/issues/21314
+ # https://github.com/jax-ml/jax/issues/21314
devices = jax.local_devices()[:2]
if len(devices) < 2:
raise unittest.SkipTest("Test requires 2 local devices")
diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py
index 1ed6183b1..ffe362974 100644
--- a/jax/experimental/jet.py
+++ b/jax/experimental/jet.py
@@ -45,11 +45,11 @@ r"""Jet is an experimental module for higher-order automatic differentiation
and can thus be used for high-order
automatic differentiation of :math:`f`.
Details are explained in
- `these notes `__.
+ `these notes `__.
Note:
Help improve :func:`jet` by contributing
- `outstanding primitive rules `__.
+ `outstanding primitive rules `__.
"""
from collections.abc import Callable
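For readers unfamiliar with the module, a small sketch of calling `jet` (the numbers are arbitrary; the argument layout follows the public `jax.experimental.jet.jet` signature of primals plus per-input series coefficients):
```python
import jax.numpy as jnp
from jax.experimental import jet

# Propagate a truncated Taylor series of sin through order 3 at x0,
# along the input perturbation h (higher-order input terms set to zero).
x0, h = 0.5, 1.0
primal_out, series_out = jet.jet(jnp.sin, (x0,), ((h, 0.0, 0.0),))
print(primal_out)   # sin(x0)
print(series_out)   # higher-order terms of the output series
```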
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index fabd45ca0..f19401525 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -639,7 +639,7 @@ def _rule_missing(prim: core.Primitive, *_, **__):
raise NotImplementedError(
f"No replication rule for {prim}. As a workaround, pass the "
"`check_rep=False` argument to `shard_map`. To get this fixed, open an "
- "issue at https://github.com/google/jax/issues")
+ "issue at https://github.com/jax-ml/jax/issues")
# Lowering
@@ -845,20 +845,20 @@ class ShardMapTrace(core.Trace):
f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
"yet supported. Put a `jax.jit` around the `shard_map`-decorated "
"function, and open a feature request at "
- "https://github.com/google/jax/issues !")
+ "https://github.com/jax-ml/jax/issues !")
def process_map(self, map_primitive, fun, tracers, params):
raise NotImplementedError(
"Eager evaluation of `pmap` inside a `shard_map` isn't yet supported."
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
- "a feature request at https://github.com/google/jax/issues !")
+ "a feature request at https://github.com/jax-ml/jax/issues !")
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
# Since ShardMapTrace is only used as a base main, we can drop the jvp.
if symbolic_zeros:
msg = ("custom_jvp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
- "https://github.com/google/jax/issues")
+ "https://github.com/jax-ml/jax/issues")
raise NotImplementedError(msg)
del prim, jvp, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@@ -876,7 +876,7 @@ class ShardMapTrace(core.Trace):
if symbolic_zeros:
msg = ("custom_vjp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
- "https://github.com/google/jax/issues")
+ "https://github.com/jax-ml/jax/issues")
raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers)
@@ -1042,7 +1042,7 @@ def _standard_check(prim, mesh, *in_rep, **__):
if in_rep_ and not in_rep_[:-1] == in_rep_[1:]:
raise Exception(f"Primitive {prim} requires argument replication types "
f"to match, but got {in_rep}. Please open an issue at "
- "https://github.com/google/jax/issues and as a temporary "
+ "https://github.com/jax-ml/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return in_rep_[0] if in_rep_ else None
@@ -1057,7 +1057,7 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
raise Exception(f"Collective {prim} must be applied to a device-varying "
f"replication type, but got {x_rep} for collective acting "
f"over axis name {axis_name}. Please open an issue at "
- "https://github.com/google/jax/issues and as a temporary "
+ "https://github.com/jax-ml/jax/issues and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
return x_rep
@@ -1114,7 +1114,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
raise Exception("Collective psum must be applied to a device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
- "https://github.com/google/jax/issues, and as a temporary "
+ "https://github.com/jax-ml/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r | set(axes) for r in in_rep]
@@ -1129,7 +1129,7 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups):
"non-device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
- "https://github.com/google/jax/issues, and as a temporary "
+ "https://github.com/jax-ml/jax/issues, and as a temporary "
"workaround pass the check_rep=False argument to shard_map")
in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
return [r - set(axes) for r in in_rep]
@@ -1216,7 +1216,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
if not carry_rep_in == carry_rep_out:
raise Exception("Scan carry input and output got mismatched replication "
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
- "issue at https://github.com/google/jax/issues, and as a "
+ "issue at https://github.com/jax-ml/jax/issues, and as a "
"temporary workaround pass the check_rep=False argument to "
"shard_map")
return out_rep
@@ -1267,7 +1267,7 @@ def _custom_vjp_call_jaxpr_rewrite(
mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
symbolic_zeros):
if symbolic_zeros:
- msg = ("Please open an issue at https://github.com/google/jax/issues and as"
+ msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as"
" a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
@@ -1303,7 +1303,7 @@ def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs):
assert in_rep
if not in_rep_[:-1] == in_rep_[1:]:
msg = ("shard_map check_rep rewrite failed. Please open an issue at "
- "https://github.com/google/jax/issues and as a workaround pass the "
+ "https://github.com/jax-ml/jax/issues and as a workaround pass the "
"check_rep=False argument to shard_map")
raise Exception(msg)
return [in_rep_[0]] * len(jaxprs.solve.out_avals)
@@ -1878,7 +1878,7 @@ class RewriteTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
- msg = ("Please open an issue at https://github.com/google/jax/issues and "
+ msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
@@ -1899,7 +1899,7 @@ class RewriteTrace(core.Trace):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
if symbolic_zeros:
- msg = ("Please open an issue at https://github.com/google/jax/issues and "
+ msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and "
"as a temporary workaround pass the check_rep=False argument to "
"shard_map")
raise NotImplementedError(msg)
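The error messages in the hunks above point to two workarounds: keep the `shard_map`-decorated function under `jax.jit`, and pass `check_rep=False` when a replication rule is missing. A minimal sketch under those assumptions (single-axis mesh and trivial body chosen purely for illustration):
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))

@jax.jit   # eager evaluation inside shard_map is not supported, so stay under jit
def f(v):
  body = lambda s: s * 2.0
  return shard_map(body, mesh=mesh, in_specs=P("x"), out_specs=P("x"),
                   check_rep=False)(v)   # workaround when a replication rule is missing

print(f(jnp.arange(8.0)))
```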
diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py
index 8ab8cd887..f388cd527 100644
--- a/jax/experimental/sparse/__init__.py
+++ b/jax/experimental/sparse/__init__.py
@@ -189,7 +189,7 @@ To fit the same model on sparse data, we can apply the :func:`sparsify` transfor
"""
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.experimental.sparse.ad import (
jacfwd as jacfwd,
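A brief sketch of the `sparsify` transform mentioned in the docstring hunk above; the `predict` function and data are made up for illustration:
```python
import jax.numpy as jnp
from jax.experimental import sparse

def predict(params, X):
  return X @ params            # written against dense arrays

X_dense = jnp.eye(4)
X_sparse = sparse.BCOO.fromdense(X_dense)   # sparse representation of the data
params = jnp.arange(4.0)

dense_out = predict(params, X_dense)
sparse_out = sparse.sparsify(predict)(params, X_sparse)   # same code, sparse input
print(jnp.allclose(dense_out, sparse_out))
```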
diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py
index d200577c2..9f2f0f69b 100644
--- a/jax/experimental/sparse/bcoo.py
+++ b/jax/experimental/sparse/bcoo.py
@@ -105,7 +105,7 @@ def _bcoo_set_nse(mat: BCOO, nse: int) -> BCOO:
unique_indices=mat.unique_indices)
# TODO(jakevdp) this can be problematic when used with autodiff; see
-# https://github.com/google/jax/issues/10163. Should this be a primitive?
+# https://github.com/jax-ml/jax/issues/10163. Should this be a primitive?
# Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument.
def bcoo_eliminate_zeros(mat: BCOO, nse: int | None = None) -> BCOO:
data, indices, shape = mat.data, mat.indices, mat.shape
@@ -1140,7 +1140,7 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices,
out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :])
out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1])
# Note: we do not eliminate zeros here, because it can cause issues with autodiff.
- # See https://github.com/google/jax/issues/10163.
+ # See https://github.com/jax-ml/jax/issues/10163.
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse)
@bcoo_spdot_general_p.def_impl
@@ -1537,7 +1537,7 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype)
data_dot_out = data_out
# This check is because scatter-add on zero-sized arrays has poorly defined
- # semantics; see https://github.com/google/jax/issues/13656.
+ # semantics; see https://github.com/jax-ml/jax/issues/13656.
if data_out.size:
permute = lambda x, i, y: x.at[i].add(y, mode='drop')
else:
diff --git a/jax/extend/backend.py b/jax/extend/backend.py
index 66fd149d7..b1e471133 100644
--- a/jax/extend/backend.py
+++ b/jax/extend/backend.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.api import (
clear_backends as clear_backends,
diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py
index 2732b1984..9f1632fb3 100644
--- a/jax/extend/core/__init__.py
+++ b/jax/extend/core/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.abstract_arrays import (
array_types as array_types
diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py
index e37287180..feb70b517 100644
--- a/jax/extend/core/primitives.py
+++ b/jax/extend/core/primitives.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py
index 3a26030c1..b2d480adc 100644
--- a/jax/extend/ffi.py
+++ b/jax/extend/ffi.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.extend.ffi import (
ffi_call as ffi_call,
diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py
index d5fb9245a..715dfd435 100644
--- a/jax/extend/ifrt_programs.py
+++ b/jax/extend/ifrt_programs.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.lib import xla_extension as _xe
diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py
index 1706f8c8c..74c52dddb 100644
--- a/jax/extend/linear_util.py
+++ b/jax/extend/linear_util.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.linear_util import (
StoreException as StoreException,
diff --git a/jax/extend/random.py b/jax/extend/random.py
index a055c7575..d6e0cfaab 100644
--- a/jax/extend/random.py
+++ b/jax/extend/random.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.extend.random import (
define_prng_impl as define_prng_impl,
diff --git a/jax/extend/source_info_util.py b/jax/extend/source_info_util.py
index f74df2cab..f031dabef 100644
--- a/jax/extend/source_info_util.py
+++ b/jax/extend/source_info_util.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.source_info_util import (
NameStack as NameStack,
diff --git a/jax/image/__init__.py b/jax/image/__init__.py
index c7ee8ffa9..993395f50 100644
--- a/jax/image/__init__.py
+++ b/jax/image/__init__.py
@@ -21,7 +21,7 @@ JAX, such as `PIX`_.
"""
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.image.scale import (
resize as resize,
diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py
index 6bfc3473f..28816afb0 100644
--- a/jax/interpreters/ad.py
+++ b/jax/interpreters/ad.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from __future__ import annotations
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index 98fad903c..607fc6fa5 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.interpreters.batching import (
Array as Array,
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
index e2bcd5de9..293bd0244 100644
--- a/jax/lax/__init__.py
+++ b/jax/lax/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.lax.lax import (
DotDimensionNumbers as DotDimensionNumbers,
diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py
index 230aacb76..496d03261 100644
--- a/jax/nn/__init__.py
+++ b/jax/nn/__init__.py
@@ -15,7 +15,7 @@
"""Common functions for neural network libraries."""
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.numpy import tanh as tanh
from jax.nn import initializers as initializers
diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py
index 6c73356ce..019f3e179 100644
--- a/jax/nn/initializers.py
+++ b/jax/nn/initializers.py
@@ -18,7 +18,7 @@ used in Keras and Sonnet.
"""
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.nn.initializers import (
constant as constant,
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index da79f7859..20c37c559 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.numpy import fft as fft
from jax.numpy import linalg as linalg
diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py
index 24a271487..c268c2d65 100644
--- a/jax/numpy/fft.py
+++ b/jax/numpy/fft.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.numpy.fft import (
ifft as ifft,
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index c342fde0a..05b5ff6db 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.numpy.linalg import (
cholesky as cholesky,
diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py
index c61a44fd1..5e1f3d682 100644
--- a/jax/ops/__init__.py
+++ b/jax/ops/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.ops.scatter import (
segment_sum as segment_sum,
diff --git a/jax/profiler.py b/jax/profiler.py
index 01ea6e222..77157dc02 100644
--- a/jax/profiler.py
+++ b/jax/profiler.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.profiler import (
StepTraceAnnotation as StepTraceAnnotation,
diff --git a/jax/random.py b/jax/random.py
index 5c2eaf81f..29a625389 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -103,7 +103,7 @@ Design and background
**TLDR**: JAX PRNG = `Threefry counter PRNG `_
+ a functional array-oriented `splitting model `_
-See `docs/jep/263-prng.md <https://github.com/google/jax/blob/main/docs/jep/263-prng.md>`_
+See `docs/jep/263-prng.md <https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md>`_
for more details.
To summarize, among other requirements, the JAX PRNG aims to:
@@ -201,7 +201,7 @@ https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_
"""
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.random import (
PRNGKey as PRNGKey,
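A short sketch of the functional splitting model referenced in the docstring hunk above, using only public `jax.random` calls (the shapes are arbitrary):
```python
import jax

key = jax.random.PRNGKey(0)              # Threefry counter-based key
key, subkey = jax.random.split(key)      # functional splitting, no hidden global state
x = jax.random.normal(subkey, (3,))      # draws from distinct keys are independent
y = jax.random.normal(key, (3,))
```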
diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py
index c0746910d..cf44b6e17 100644
--- a/jax/scipy/__init__.py
+++ b/jax/scipy/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from typing import TYPE_CHECKING
diff --git a/jax/scipy/cluster/__init__.py b/jax/scipy/cluster/__init__.py
index 5a01ea0ee..ea35467f6 100644
--- a/jax/scipy/cluster/__init__.py
+++ b/jax/scipy/cluster/__init__.py
@@ -13,6 +13,6 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.scipy.cluster import vq as vq
diff --git a/jax/scipy/cluster/vq.py b/jax/scipy/cluster/vq.py
index 3a46ce52f..eeeabb722 100644
--- a/jax/scipy/cluster/vq.py
+++ b/jax/scipy/cluster/vq.py
@@ -13,6 +13,6 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.cluster.vq import vq as vq
diff --git a/jax/scipy/fft.py b/jax/scipy/fft.py
index b8005b72f..d3c2de099 100644
--- a/jax/scipy/fft.py
+++ b/jax/scipy/fft.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.fft import (
dct as dct,
diff --git a/jax/scipy/integrate.py b/jax/scipy/integrate.py
index b19aa054c..3335f12fd 100644
--- a/jax/scipy/integrate.py
+++ b/jax/scipy/integrate.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.integrate import (
trapezoid as trapezoid
diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py
index 059f927ec..64bc05440 100644
--- a/jax/scipy/linalg.py
+++ b/jax/scipy/linalg.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.linalg import (
block_diag as block_diag,
diff --git a/jax/scipy/ndimage.py b/jax/scipy/ndimage.py
index 2f63e2366..81d7e3ef2 100644
--- a/jax/scipy/ndimage.py
+++ b/jax/scipy/ndimage.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.ndimage import (
map_coordinates as map_coordinates,
diff --git a/jax/scipy/optimize/__init__.py b/jax/scipy/optimize/__init__.py
index 8a2248733..f1c7167c3 100644
--- a/jax/scipy/optimize/__init__.py
+++ b/jax/scipy/optimize/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.optimize.minimize import (
minimize as minimize,
diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py
index 7e39da3f9..c46b2fce3 100644
--- a/jax/scipy/signal.py
+++ b/jax/scipy/signal.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.signal import (
fftconvolve as fftconvolve,
diff --git a/jax/scipy/sparse/__init__.py b/jax/scipy/sparse/__init__.py
index f2e305e82..2968a26b4 100644
--- a/jax/scipy/sparse/__init__.py
+++ b/jax/scipy/sparse/__init__.py
@@ -13,6 +13,6 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.scipy.sparse import linalg as linalg
diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py
index d475ddff8..d22e5ec43 100644
--- a/jax/scipy/sparse/linalg.py
+++ b/jax/scipy/sparse/linalg.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.sparse.linalg import (
cg as cg,
diff --git a/jax/scipy/spatial/transform.py b/jax/scipy/spatial/transform.py
index 4b532d5f3..63e8dd373 100644
--- a/jax/scipy/spatial/transform.py
+++ b/jax/scipy/spatial/transform.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.spatial.transform import (
Rotation as Rotation,
diff --git a/jax/scipy/special.py b/jax/scipy/special.py
index 5d72339ea..431617d36 100644
--- a/jax/scipy/special.py
+++ b/jax/scipy/special.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.special import (
bernoulli as bernoulli,
diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py
index 7aa73f7b5..7719945f2 100644
--- a/jax/scipy/stats/__init__.py
+++ b/jax/scipy/stats/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.scipy.stats import bernoulli as bernoulli
from jax.scipy.stats import beta as beta
diff --git a/jax/scipy/stats/bernoulli.py b/jax/scipy/stats/bernoulli.py
index 46c1e4825..1623f7113 100644
--- a/jax/scipy/stats/bernoulli.py
+++ b/jax/scipy/stats/bernoulli.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.bernoulli import (
logpmf as logpmf,
diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py
index 5c57dda6b..2a4e7f12f 100644
--- a/jax/scipy/stats/beta.py
+++ b/jax/scipy/stats/beta.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.beta import (
cdf as cdf,
diff --git a/jax/scipy/stats/betabinom.py b/jax/scipy/stats/betabinom.py
index 48f955d9e..f8adf68f4 100644
--- a/jax/scipy/stats/betabinom.py
+++ b/jax/scipy/stats/betabinom.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.betabinom import (
logpmf as logpmf,
diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py
index 4ff79f5f9..34c9972d0 100644
--- a/jax/scipy/stats/cauchy.py
+++ b/jax/scipy/stats/cauchy.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.cauchy import (
cdf as cdf,
diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py
index e17a2e331..47fcb76db 100644
--- a/jax/scipy/stats/chi2.py
+++ b/jax/scipy/stats/chi2.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.chi2 import (
cdf as cdf,
diff --git a/jax/scipy/stats/dirichlet.py b/jax/scipy/stats/dirichlet.py
index 9368defc8..22e9b3cc1 100644
--- a/jax/scipy/stats/dirichlet.py
+++ b/jax/scipy/stats/dirichlet.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.dirichlet import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/expon.py b/jax/scipy/stats/expon.py
index 1ec50ac3f..8f5c0a068 100644
--- a/jax/scipy/stats/expon.py
+++ b/jax/scipy/stats/expon.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.expon import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py
index 8efecafed..531a1e300 100644
--- a/jax/scipy/stats/gamma.py
+++ b/jax/scipy/stats/gamma.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.gamma import (
cdf as cdf,
diff --git a/jax/scipy/stats/gennorm.py b/jax/scipy/stats/gennorm.py
index c903ff606..c760575fa 100644
--- a/jax/scipy/stats/gennorm.py
+++ b/jax/scipy/stats/gennorm.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.gennorm import (
cdf as cdf,
diff --git a/jax/scipy/stats/geom.py b/jax/scipy/stats/geom.py
index 75f917fc2..eb12dbb5a 100644
--- a/jax/scipy/stats/geom.py
+++ b/jax/scipy/stats/geom.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.geom import (
logpmf as logpmf,
diff --git a/jax/scipy/stats/laplace.py b/jax/scipy/stats/laplace.py
index 3abe62020..8f182804d 100644
--- a/jax/scipy/stats/laplace.py
+++ b/jax/scipy/stats/laplace.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.laplace import (
cdf as cdf,
diff --git a/jax/scipy/stats/logistic.py b/jax/scipy/stats/logistic.py
index c25a06856..7cdb26fb1 100644
--- a/jax/scipy/stats/logistic.py
+++ b/jax/scipy/stats/logistic.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.logistic import (
cdf as cdf,
diff --git a/jax/scipy/stats/multinomial.py b/jax/scipy/stats/multinomial.py
index 723d1a645..392ca4055 100644
--- a/jax/scipy/stats/multinomial.py
+++ b/jax/scipy/stats/multinomial.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.multinomial import (
logpmf as logpmf,
diff --git a/jax/scipy/stats/multivariate_normal.py b/jax/scipy/stats/multivariate_normal.py
index 95ad355c7..94c4cc50a 100644
--- a/jax/scipy/stats/multivariate_normal.py
+++ b/jax/scipy/stats/multivariate_normal.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.multivariate_normal import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py
index f47765adf..563e40ce0 100644
--- a/jax/scipy/stats/norm.py
+++ b/jax/scipy/stats/norm.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.norm import (
cdf as cdf,
diff --git a/jax/scipy/stats/pareto.py b/jax/scipy/stats/pareto.py
index bf27ea205..5e46fd5d0 100644
--- a/jax/scipy/stats/pareto.py
+++ b/jax/scipy/stats/pareto.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.pareto import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py
index 2e857bc15..5fcde905f 100644
--- a/jax/scipy/stats/poisson.py
+++ b/jax/scipy/stats/poisson.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.poisson import (
logpmf as logpmf,
diff --git a/jax/scipy/stats/t.py b/jax/scipy/stats/t.py
index d92fcab97..694bcb0b0 100644
--- a/jax/scipy/stats/t.py
+++ b/jax/scipy/stats/t.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.t import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py
index 28d5533b0..cb8e8958d 100644
--- a/jax/scipy/stats/truncnorm.py
+++ b/jax/scipy/stats/truncnorm.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.truncnorm import (
cdf as cdf,
diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py
index d0a06c673..fa754125f 100644
--- a/jax/scipy/stats/uniform.py
+++ b/jax/scipy/stats/uniform.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.uniform import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/vonmises.py b/jax/scipy/stats/vonmises.py
index 8de7fba47..6572e43f6 100644
--- a/jax/scipy/stats/vonmises.py
+++ b/jax/scipy/stats/vonmises.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.vonmises import (
logpdf as logpdf,
diff --git a/jax/scipy/stats/wrapcauchy.py b/jax/scipy/stats/wrapcauchy.py
index 6e2420c5a..eb1768f0c 100644
--- a/jax/scipy/stats/wrapcauchy.py
+++ b/jax/scipy/stats/wrapcauchy.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.scipy.stats.wrapcauchy import (
logpdf as logpdf,
diff --git a/jax/sharding.py b/jax/sharding.py
index ea92e9d17..26c542292 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.sharding import Sharding as Sharding
from jax._src.sharding_impls import (
diff --git a/jax/stages.py b/jax/stages.py
index 6ffc3144c..3e7e461c3 100644
--- a/jax/stages.py
+++ b/jax/stages.py
@@ -22,7 +22,7 @@ For more, see the `AOT walkthrough
 # Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.stages import (
Compiled as Compiled,
diff --git a/jax/test_util.py b/jax/test_util.py
index 5d4f5ed0a..176f4521b 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.public_test_util import (
check_grads as check_grads,
diff --git a/jax/tree_util.py b/jax/tree_util.py
index b4854c7df..956d79b9b 100644
--- a/jax/tree_util.py
+++ b/jax/tree_util.py
@@ -36,7 +36,7 @@ for examples.
"""
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.tree_util import (
DictKey as DictKey,
diff --git a/jax/util.py b/jax/util.py
index c1259e9c5..8071f77df 100644
--- a/jax/util.py
+++ b/jax/util.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Note: import as is required for names to be exported.
-# See PEP 484 & https://github.com/google/jax/issues/7570
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.util import (
HashableFunction as HashableFunction,
diff --git a/jax/version.py b/jax/version.py
index c6e4b3ad1..6c64d75b9 100644
--- a/jax/version.py
+++ b/jax/version.py
@@ -115,7 +115,7 @@ def _get_cmdclass(pkg_source_path):
# missing or outdated. Because _write_version(...) modifies the copy of
# this file in the build tree, re-building from the same JAX directory
# would not automatically re-copy a clean version, and _write_version
- # would fail without this deletion. See google/jax#18252.
+ # would fail without this deletion. See jax-ml/jax#18252.
if os.path.isfile(this_file_in_build_dir):
os.unlink(this_file_in_build_dir)
super().run()
diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py
index 9ccf3bf44..a84a6b34e 100644
--- a/jax_plugins/rocm/plugin_setup.py
+++ b/jax_plugins/rocm/plugin_setup.py
@@ -51,7 +51,7 @@ setup(
packages=[package_name],
python_requires=">=3.9",
install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"],
- url="https://github.com/google/jax",
+ url="https://github.com/jax-ml/jax",
license="Apache-2.0",
classifiers=[
"Development Status :: 3 - Alpha",
diff --git a/jaxlib/README.md b/jaxlib/README.md
index 74e1e5b36..cee5f246d 100644
--- a/jaxlib/README.md
+++ b/jaxlib/README.md
@@ -4,4 +4,4 @@ jaxlib is the support library for JAX. While JAX itself is a pure Python package
jaxlib contains the binary (C/C++) parts of the library, including Python bindings,
the XLA compiler, the PJRT runtime, and a handful of handwritten kernels.
For more information, including installation and build instructions, refer to the main
-JAX README: https://github.com/google/jax/.
+JAX README: https://github.com/jax-ml/jax/.
diff --git a/jaxlib/setup.py b/jaxlib/setup.py
index 215313f9b..dea9503c7 100644
--- a/jaxlib/setup.py
+++ b/jaxlib/setup.py
@@ -66,7 +66,7 @@ setup(
'numpy>=1.24',
'ml_dtypes>=0.2.0',
],
- url='https://github.com/google/jax',
+ url='https://github.com/jax-ml/jax',
license='Apache-2.0',
classifiers=[
"Programming Language :: Python :: 3.10",
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index 6305b0c24..52a17c451 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack():
We don't entirely know why this happens, but in some build environments
we seem to target the wrong Mac OS version.
- https://github.com/google/jax/issues/3867
+ https://github.com/jax-ml/jax/issues/3867
This check makes sure we don't release wheels that have this dependency.
"""
diff --git a/setup.py b/setup.py
index 81eef74e0..e807ff3b0 100644
--- a/setup.py
+++ b/setup.py
@@ -104,7 +104,7 @@ setup(
f"jax-cuda12-plugin=={_current_jaxlib_version}",
],
},
- url='https://github.com/google/jax',
+ url='https://github.com/jax-ml/jax',
license='Apache-2.0',
classifiers=[
"Programming Language :: Python :: 3.10",
diff --git a/tests/api_test.py b/tests/api_test.py
index 1deaa4c08..adce61d65 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -467,7 +467,7 @@ class JitTest(jtu.BufferDonationTestCase):
("argnames", "donate_argnames", ('array',)),
)
def test_jnp_array_copy(self, argnum_type, argnum_val):
- # https://github.com/google/jax/issues/3412
+ # https://github.com/jax-ml/jax/issues/3412
@partial(jit, **{argnum_type: argnum_val})
def _test(array):
@@ -905,7 +905,7 @@ class JitTest(jtu.BufferDonationTestCase):
@jax.legacy_prng_key('allow')
def test_omnistaging(self):
- # See https://github.com/google/jax/issues/5206
+ # See https://github.com/jax-ml/jax/issues/5206
# TODO(frostig): remove `wrap` once we always enable_custom_prng
def wrap(arr):
@@ -1409,7 +1409,7 @@ class JitTest(jtu.BufferDonationTestCase):
f({E.A: 1.0, E.B: 2.0})
def test_jit_static_argnums_requires_type_equality(self):
- # See: https://github.com/google/jax/pull/9311
+ # See: https://github.com/jax-ml/jax/pull/9311
@partial(jit, static_argnums=(0,))
def f(k):
assert python_should_be_executing
@@ -1424,7 +1424,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(x, f(x))
def test_caches_depend_on_axis_env(self):
- # https://github.com/google/jax/issues/9187
+ # https://github.com/jax-ml/jax/issues/9187
f = lambda: lax.psum(1, "i")
g = jax.jit(f)
expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)()
@@ -1437,7 +1437,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(ans, expected)
def test_caches_dont_depend_on_unnamed_axis_env(self):
- # https://github.com/google/jax/issues/9187
+ # https://github.com/jax-ml/jax/issues/9187
f = jax.jit(lambda: jnp.sin(1))
expected = f()
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
@@ -1446,7 +1446,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertArraysAllClose(ans, expected, check_dtypes=True)
def test_cache_key_defaults(self):
- # https://github.com/google/jax/discussions/11875
+ # https://github.com/jax-ml/jax/discussions/11875
f = jit(lambda x: (x ** 2).sum())
self.assertEqual(f._cache_size(), 0)
x = jnp.arange(5.0)
@@ -1455,7 +1455,7 @@ class JitTest(jtu.BufferDonationTestCase):
self.assertEqual(f._cache_size(), 1)
def test_jit_nan_times_zero(self):
- # https://github.com/google/jax/issues/4780
+ # https://github.com/jax-ml/jax/issues/4780
def f(x):
return 1 + x * 0
self.assertAllClose(f(np.nan), np.nan)
@@ -2163,7 +2163,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(aux, [4.**2, 4.])
def test_grad_and_aux_no_tracers(self):
- # see https://github.com/google/jax/issues/1950
+ # see https://github.com/jax-ml/jax/issues/1950
def f(x):
aux = dict(identity=x, p1=x+1)
return x ** 2, aux
@@ -2322,7 +2322,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(actual, expected)
def test_linear_transpose_dce(self):
- # https://github.com/google/jax/issues/15660
+ # https://github.com/jax-ml/jax/issues/15660
f = jit(lambda x: (2 * x, x > 0))
g = lambda x: f(x)[0]
api.linear_transpose(g, 1.)(1.)
@@ -2389,7 +2389,7 @@ class APITest(jtu.JaxTestCase):
self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j))
def test_nonholomorphic_jacrev(self):
- # code based on https://github.com/google/jax/issues/603
+ # code based on https://github.com/jax-ml/jax/issues/603
zs = 0.5j * np.arange(5) + np.arange(5)
def f(z):
@@ -2401,8 +2401,8 @@ class APITest(jtu.JaxTestCase):
@jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion.
def test_heterogeneous_jacfwd(self):
- # See https://github.com/google/jax/issues/7157
- # See https://github.com/google/jax/issues/7780
+ # See https://github.com/jax-ml/jax/issues/7157
+ # See https://github.com/jax-ml/jax/issues/7780
x = np.array([2.0], dtype=np.float16)
y = np.array([3.0], dtype=np.float32)
a = (x, y)
@@ -2421,8 +2421,8 @@ class APITest(jtu.JaxTestCase):
@jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion.
def test_heterogeneous_jacrev(self):
- # See https://github.com/google/jax/issues/7157
- # See https://github.com/google/jax/issues/7780
+ # See https://github.com/jax-ml/jax/issues/7157
+ # See https://github.com/jax-ml/jax/issues/7780
x = np.array([2.0], dtype=np.float16)
y = np.array([3.0], dtype=np.float32)
a = (x, y)
@@ -2440,7 +2440,7 @@ class APITest(jtu.JaxTestCase):
jtu.check_eq(actual, desired)
def test_heterogeneous_grad(self):
- # See https://github.com/google/jax/issues/7157
+ # See https://github.com/jax-ml/jax/issues/7157
x = np.array(1.0+1j)
y = np.array(2.0)
a = (x, y)
@@ -2512,7 +2512,7 @@ class APITest(jtu.JaxTestCase):
self.assertIsNone(y())
def test_namedtuple_transparency(self):
- # See https://github.com/google/jax/issues/446
+ # See https://github.com/jax-ml/jax/issues/446
Point = collections.namedtuple("Point", ["x", "y"])
def f(pt):
@@ -2528,7 +2528,7 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
def test_namedtuple_subclass_transparency(self):
- # See https://github.com/google/jax/issues/806
+ # See https://github.com/jax-ml/jax/issues/806
Point = collections.namedtuple("Point", ["x", "y"])
class ZeroPoint(Point):
@@ -2705,7 +2705,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(out_shape.shape, (3, 5))
def test_eval_shape_duck_typing2(self):
- # https://github.com/google/jax/issues/5683
+ # https://github.com/jax-ml/jax/issues/5683
class EasyDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -2980,7 +2980,7 @@ class APITest(jtu.JaxTestCase):
]))
def test_vmap_in_axes_list(self):
- # https://github.com/google/jax/issues/2367
+ # https://github.com/jax-ml/jax/issues/2367
dictionary = {'a': 5., 'b': jnp.ones(2)}
x = jnp.zeros(3)
y = jnp.arange(3.)
@@ -2993,7 +2993,7 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(out1, out2)
def test_vmap_in_axes_non_tuple_error(self):
- # https://github.com/google/jax/issues/18548
+ # https://github.com/jax-ml/jax/issues/18548
with self.assertRaisesRegex(
TypeError,
re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding "
@@ -3001,7 +3001,7 @@ class APITest(jtu.JaxTestCase):
jax.vmap(lambda x: x['a'], in_axes={'a': 0})
def test_vmap_in_axes_wrong_length_tuple_error(self):
- # https://github.com/google/jax/issues/18548
+ # https://github.com/jax-ml/jax/issues/18548
with self.assertRaisesRegex(
ValueError,
re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the "
@@ -3009,7 +3009,7 @@ class APITest(jtu.JaxTestCase):
jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))})
def test_vmap_in_axes_tree_prefix_error(self):
- # https://github.com/google/jax/issues/795
+ # https://github.com/jax-ml/jax/issues/795
value_tree = jnp.ones(3)
self.assertRaisesRegex(
ValueError,
@@ -3030,14 +3030,14 @@ class APITest(jtu.JaxTestCase):
api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.]))
def test_vmap_unbatched_object_passthrough_issue_183(self):
- # https://github.com/google/jax/issues/183
+ # https://github.com/jax-ml/jax/issues/183
fun = lambda f, x: f(x)
vfun = api.vmap(fun, (None, 0))
ans = vfun(lambda x: x + 1, jnp.arange(3))
self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)
def test_vmap_mismatched_keyword(self):
- # https://github.com/google/jax/issues/10193
+ # https://github.com/jax-ml/jax/issues/10193
@jax.vmap
def f(x, y):
return x + y
@@ -3051,7 +3051,7 @@ class APITest(jtu.JaxTestCase):
f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32'))
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
- # https://github.com/google/jax/issues/705
+ # https://github.com/jax-ml/jax/issues/705
def h(a, b):
return jnp.sum(a) + jnp.sum(b)
@@ -3156,12 +3156,12 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
def test_vmap_in_axes_bool_error(self):
- # https://github.com/google/jax/issues/6372
+ # https://github.com/jax-ml/jax/issues/6372
with self.assertRaisesRegex(TypeError, "must be an int"):
api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3))
def test_pmap_in_axes_bool_error(self):
- # https://github.com/google/jax/issues/6372
+ # https://github.com/jax-ml/jax/issues/6372
with self.assertRaisesRegex(TypeError, "must be an int"):
api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))
@@ -3223,7 +3223,7 @@ class APITest(jtu.JaxTestCase):
hash(rep)
def test_grad_without_enough_args_error_message(self):
- # https://github.com/google/jax/issues/1696
+ # https://github.com/jax-ml/jax/issues/1696
def f(x, y): return x + y
df = api.grad(f, argnums=0)
self.assertRaisesRegex(
@@ -3301,7 +3301,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd
def test_grad_does_not_unflatten_tree_with_none(self):
- # https://github.com/google/jax/issues/7546
+ # https://github.com/jax-ml/jax/issues/7546
class CustomNode(list):
pass
@@ -3370,7 +3370,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count[0], 1)
def test_arange_jit(self):
- # see https://github.com/google/jax/issues/553
+ # see https://github.com/jax-ml/jax/issues/553
def fun(x):
r = jnp.arange(x.shape[0])[x]
return r
@@ -3496,7 +3496,7 @@ class APITest(jtu.JaxTestCase):
_ = self._saved_tracer+1
def test_pmap_static_kwarg_error_message(self):
- # https://github.com/google/jax/issues/3007
+ # https://github.com/jax-ml/jax/issues/3007
def f(a, b):
return a + b
@@ -3650,7 +3650,7 @@ class APITest(jtu.JaxTestCase):
g(1)
def test_join_concrete_arrays_with_omnistaging(self):
- # https://github.com/google/jax/issues/4622
+ # https://github.com/jax-ml/jax/issues/4622
x = jnp.array([1., 2., 3.])
y = jnp.array([1., 2., 4.])
@@ -3673,7 +3673,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(aux, True)
def test_linearize_aval_error(self):
- # https://github.com/google/jax/issues/4622
+ # https://github.com/jax-ml/jax/issues/4622
f = lambda x: x
# these should not error
@@ -3691,7 +3691,7 @@ class APITest(jtu.JaxTestCase):
f_jvp(np.ones(2, np.int32))
def test_grad_of_token_consuming_primitive(self):
- # https://github.com/google/jax/issues/5463
+ # https://github.com/jax-ml/jax/issues/5463
tokentest_p = core.Primitive("tokentest")
tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
tokentest_p.def_abstract_eval(lambda x, y: x)
@@ -3823,7 +3823,7 @@ class APITest(jtu.JaxTestCase):
f(3)
def test_leak_checker_avoids_false_positive_custom_jvp(self):
- # see https://github.com/google/jax/issues/5636
+ # see https://github.com/jax-ml/jax/issues/5636
with jax.checking_leaks():
@jax.custom_jvp
def t(y):
@@ -3906,7 +3906,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(jnp.ones(1).devices(), system_default_devices)
def test_dunder_jax_array(self):
- # https://github.com/google/jax/pull/4725
+ # https://github.com/jax-ml/jax/pull/4725
class AlexArray:
def __init__(self, jax_val):
@@ -3939,7 +3939,7 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
def test_eval_shape_weak_type(self):
- # https://github.com/google/jax/issues/23302
+ # https://github.com/jax-ml/jax/issues/23302
arr = jax.numpy.array(1)
with jtu.count_jit_tracing_cache_miss() as count:
@@ -3980,7 +3980,7 @@ class APITest(jtu.JaxTestCase):
f(a, a) # don't crash
def test_constant_handler_mro(self):
- # https://github.com/google/jax/issues/6129
+ # https://github.com/jax-ml/jax/issues/6129
class Foo(enum.IntEnum):
bar = 1
@@ -3997,7 +3997,7 @@ class APITest(jtu.JaxTestCase):
{"testcase_name": f"{dtype.__name__}", "dtype": dtype}
for dtype in jtu.dtypes.all])
def test_constant_handlers(self, dtype):
- # https://github.com/google/jax/issues/9380
+ # https://github.com/jax-ml/jax/issues/9380
@jax.jit
def f():
return jnp.exp(dtype(0))
@@ -4135,7 +4135,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(f)(3)
self.assertNotIn('pjit', str(jaxpr))
- # Repro for https://github.com/google/jax/issues/7229.
+ # Repro for https://github.com/jax-ml/jax/issues/7229.
def test_compute_with_large_transfer(self):
def f(x, delta):
return x + jnp.asarray(delta, x.dtype)
@@ -4193,7 +4193,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(actual, expected)
def test_leaked_tracer_issue_7613(self):
- # from https://github.com/google/jax/issues/7613
+ # from https://github.com/jax-ml/jax/issues/7613
import numpy.random as npr
def sigmoid(x):
@@ -4211,7 +4211,7 @@ class APITest(jtu.JaxTestCase):
_ = jax.grad(loss)(A, x) # doesn't crash
def test_vmap_caching(self):
- # https://github.com/google/jax/issues/7621
+ # https://github.com/jax-ml/jax/issues/7621
f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
@@ -4299,7 +4299,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(2 * i, g(2, i), msg=i)
def test_fastpath_cache_confusion(self):
- # https://github.com/google/jax/issues/12542
+ # https://github.com/jax-ml/jax/issues/12542
@jax.jit
def a(x):
return ()
@@ -4344,7 +4344,7 @@ class APITest(jtu.JaxTestCase):
b(8) # don't crash
def test_vjp_multiple_arguments_error_message(self):
- # https://github.com/google/jax/issues/13099
+ # https://github.com/jax-ml/jax/issues/13099
def foo(x):
return (x, x)
_, f_vjp = jax.vjp(foo, 1.0)
@@ -4376,7 +4376,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(jfoo.__module__, "jax")
def test_inner_jit_function_retracing(self):
- # https://github.com/google/jax/issues/7155
+ # https://github.com/jax-ml/jax/issues/7155
inner_count = outer_count = 0
@jax.jit
@@ -4403,7 +4403,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(outer_count, 1)
def test_grad_conj_symbolic_zeros(self):
- # https://github.com/google/jax/issues/15400
+ # https://github.com/jax-ml/jax/issues/15400
f = lambda x: jax.jit(lambda x, y: (x, y))(x, jax.lax.conj(x))[0]
out = jax.grad(f)(3.0) # doesn't crash
self.assertAllClose(out, 1., check_dtypes=False)
@@ -4555,7 +4555,7 @@ class APITest(jtu.JaxTestCase):
self._CompileAndCheck(f, args_maker)
def test_jvp_asarray_returns_array(self):
- # https://github.com/google/jax/issues/15676
+ # https://github.com/jax-ml/jax/issues/15676
p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,))
_check_instance(self, p)
_check_instance(self, t)
@@ -4716,7 +4716,7 @@ class APITest(jtu.JaxTestCase):
f()
def test_inline_return_twice(self):
- # https://github.com/google/jax/issues/22944
+ # https://github.com/jax-ml/jax/issues/22944
@jax.jit
def add_one(x: int) -> int:
return x + 1
@@ -5074,7 +5074,7 @@ class RematTest(jtu.JaxTestCase):
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_no_redundant_flops(self, remat):
- # see https://github.com/google/jax/pull/1749#issuecomment-558267584
+ # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
@api.jit
def g(x):
@@ -5124,7 +5124,7 @@ class RematTest(jtu.JaxTestCase):
('_new', new_checkpoint),
])
def test_remat_symbolic_zeros(self, remat):
- # code from https://github.com/google/jax/issues/1907
+ # code from https://github.com/jax-ml/jax/issues/1907
key = jax.random.key(0)
key, split = jax.random.split(key)
@@ -5177,7 +5177,7 @@ class RematTest(jtu.JaxTestCase):
('_new', new_checkpoint),
])
def test_remat_nontrivial_env(self, remat):
- # simplified from https://github.com/google/jax/issues/2030
+ # simplified from https://github.com/jax-ml/jax/issues/2030
@remat
def foo(state, dt=0.5, c=1):
@@ -5211,7 +5211,7 @@ class RematTest(jtu.JaxTestCase):
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_jit3(self, remat):
- # https://github.com/google/jax/issues/2180
+ # https://github.com/jax-ml/jax/issues/2180
def f(w, x):
a = jnp.dot(x, w)
b = jnp.einsum("btd,bTd->btT", a, a)
@@ -5244,7 +5244,7 @@ class RematTest(jtu.JaxTestCase):
('_new', new_checkpoint),
])
def test_remat_scan2(self, remat):
- # https://github.com/google/jax/issues/1963
+ # https://github.com/jax-ml/jax/issues/1963
def scan_bug(x0):
f = lambda x, _: (x + 1, None)
@@ -5256,7 +5256,7 @@ class RematTest(jtu.JaxTestCase):
jax.grad(scan_bug)(1.0) # doesn't crash
def test_remat_jit_static_argnum_omnistaging(self):
- # https://github.com/google/jax/issues/2833
+ # https://github.com/jax-ml/jax/issues/2833
# NOTE(mattjj): after #3370, this test doesn't actually call remat...
def named_call(f):
def named_f(*args):
@@ -5281,7 +5281,7 @@ class RematTest(jtu.JaxTestCase):
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_eval_counter(self, remat):
- # https://github.com/google/jax/issues/2737
+ # https://github.com/jax-ml/jax/issues/2737
add_one_p = core.Primitive('add_one')
add_one = add_one_p.bind
@@ -5665,7 +5665,7 @@ class RematTest(jtu.JaxTestCase):
# The old implementation of remat worked by data dependence, and so
# (potentially large) constants would not be rematerialized and could be
# wastefully instantiated. This test checks that the newer remat
- # implementation avoids that. See https://github.com/google/jax/pull/8191.
+ # implementation avoids that. See https://github.com/jax-ml/jax/pull/8191.
# no residuals from constants created inside jnp.einsum
@partial(new_checkpoint, policy=lambda *_, **__: False)
@@ -5790,7 +5790,7 @@ class RematTest(jtu.JaxTestCase):
_ = jax.grad(f)(3.) # doesn't crash
def test_linearize_caching(self):
- # https://github.com/google/jax/issues/9661
+ # https://github.com/jax-ml/jax/issues/9661
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_lin = jax.linearize(identity, 1.)
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
@@ -5799,7 +5799,7 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(count[0], 1) # cached after first execution
def test_vjp_caching(self):
- # https://github.com/google/jax/issues/9661
+ # https://github.com/jax-ml/jax/issues/9661
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_vjp = jax.vjp(identity, 1.)
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
@@ -6928,7 +6928,7 @@ class CustomJVPTest(jtu.JaxTestCase):
check_dtypes=False)
def test_kwargs(self):
- # from https://github.com/google/jax/issues/1938
+ # from https://github.com/jax-ml/jax/issues/1938
@jax.custom_jvp
def my_fun(x, y, c=1.):
return c * (x + y)
@@ -7209,7 +7209,7 @@ class CustomJVPTest(jtu.JaxTestCase):
def test_jvp_rule_doesnt_return_pair_error_message(self):
- # https://github.com/google/jax/issues/2516
+ # https://github.com/jax-ml/jax/issues/2516
@jax.custom_jvp
def f(x):
@@ -7374,7 +7374,7 @@ class CustomJVPTest(jtu.JaxTestCase):
api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3)))
def test_jaxpr_zeros(self):
- # from https://github.com/google/jax/issues/2657
+ # from https://github.com/jax-ml/jax/issues/2657
@jax.custom_jvp
def f(A, b):
return A @ b
@@ -7420,7 +7420,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_jvps_first_rule_is_none(self):
- # https://github.com/google/jax/issues/3389
+ # https://github.com/jax-ml/jax/issues/3389
@jax.custom_jvp
def f(x, y):
return x ** 2 * y
@@ -7431,7 +7431,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_concurrent_initial_style(self):
- # https://github.com/google/jax/issues/3843
+ # https://github.com/jax-ml/jax/issues/3843
def unroll(param, sequence):
def scan_f(prev_state, inputs):
return prev_state, jax.nn.sigmoid(param * inputs)
@@ -7453,7 +7453,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
def test_nondiff_argnums_vmap_tracer(self):
- # https://github.com/google/jax/issues/3964
+ # https://github.com/jax-ml/jax/issues/3964
@partial(jax.custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
return jax.random.uniform(key=seed, shape=shape, minval=param)
@@ -7495,7 +7495,7 @@ class CustomJVPTest(jtu.JaxTestCase):
api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))
def test_closure_with_vmap(self):
- # https://github.com/google/jax/issues/3822
+ # https://github.com/jax-ml/jax/issues/3822
alpha = np.float32(2.)
def sample(seed):
@@ -7515,7 +7515,7 @@ class CustomJVPTest(jtu.JaxTestCase):
api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash
def test_closure_with_vmap2(self):
- # https://github.com/google/jax/issues/8783
+ # https://github.com/jax-ml/jax/issues/8783
def h(z):
def f(x):
@jax.custom_jvp
@@ -7660,7 +7660,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_jvp_vmap_broadcasting_interaction(self):
- # https://github.com/google/jax/issues/6452
+ # https://github.com/jax-ml/jax/issues/6452
def f2(y, z):
v1 = z
v2 = jnp.sum(y) + z
@@ -7678,7 +7678,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertEqual(g.shape, ())
def test_custom_jvp_vmap_broadcasting_interaction_2(self):
- # https://github.com/google/jax/issues/5849
+ # https://github.com/jax-ml/jax/issues/5849
@jax.custom_jvp
def transform(box, R):
if jnp.isscalar(box) or box.size == 1:
@@ -7716,7 +7716,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertEqual(grad(energy_fn)(scalar_box).shape, ())
def test_custom_jvp_implicit_broadcasting(self):
- # https://github.com/google/jax/issues/6357
+ # https://github.com/jax-ml/jax/issues/6357
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")
@@ -7774,7 +7774,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3)
def test_vmap_inside_defjvp(self):
- # https://github.com/google/jax/issues/3201
+ # https://github.com/jax-ml/jax/issues/3201
seed = 47
key = jax.random.key(seed)
mat = jax.random.normal(key, (2, 3))
@@ -7823,7 +7823,7 @@ class CustomJVPTest(jtu.JaxTestCase):
jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash
def test_custom_jvp_unbroadcasting(self):
- # https://github.com/google/jax/issues/3056
+ # https://github.com/jax-ml/jax/issues/3056
a = jnp.array([1., 1.])
@jax.custom_jvp
@@ -7841,8 +7841,8 @@ class CustomJVPTest(jtu.JaxTestCase):
def test_maybe_perturbed_internal_helper_function(self):
# This is a unit test for an internal API. We include it so as not to
- # regress https://github.com/google/jax/issues/9567. For an explanation of
- # this helper function, see https://github.com/google/jax/issues/6415.
+ # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of
+ # this helper function, see https://github.com/jax-ml/jax/issues/6415.
def f(x):
def g(y, _):
z = y * x
@@ -7854,7 +7854,7 @@ class CustomJVPTest(jtu.JaxTestCase):
jax.jvp(f, (1.0,), (1.0,)) # assertions inside f
def test_maybe_perturbed_int_regression(self):
- # see https://github.com/google/jax/discussions/9951
+ # see https://github.com/jax-ml/jax/discussions/9951
@jax.jit
def f():
@@ -7864,7 +7864,7 @@ class CustomJVPTest(jtu.JaxTestCase):
f()
def test_sinc_constant_function_batching(self):
- # https://github.com/google/jax/pull/10756
+ # https://github.com/jax-ml/jax/pull/10756
batch_data = jnp.arange(15.).reshape(5, 3)
@jax.vmap
@@ -7981,7 +7981,7 @@ class CustomJVPTest(jtu.JaxTestCase):
_ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash!
def test_symbolic_zeros_under_jit(self):
- # https://github.com/google/jax/issues/14833
+ # https://github.com/jax-ml/jax/issues/14833
Zero = jax.custom_derivatives.SymbolicZero
@jax.custom_jvp
@@ -8015,7 +8015,7 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0)))
def test_run_rules_more_than_once(self):
- # https://github.com/google/jax/issues/16614
+ # https://github.com/jax-ml/jax/issues/16614
@jax.custom_jvp
def f(x, y):
@@ -8206,7 +8206,7 @@ class CustomVJPTest(jtu.JaxTestCase):
lambda: api.jvp(jit(f), (3.,), (1.,)))
def test_kwargs(self):
- # from https://github.com/google/jax/issues/1938
+ # from https://github.com/jax-ml/jax/issues/1938
@jax.custom_vjp
def my_fun(x, y, c=1.):
return c * (x + y)
@@ -8502,7 +8502,7 @@ class CustomVJPTest(jtu.JaxTestCase):
api.jit(foo)(arr) # doesn't crash
def test_lowering_out_of_traces(self):
- # https://github.com/google/jax/issues/2578
+ # https://github.com/jax-ml/jax/issues/2578
class F(collections.namedtuple("F", ["a"])):
def __call__(self, x):
@@ -8515,7 +8515,7 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash
def test_clip_gradient(self):
- # https://github.com/google/jax/issues/2784
+ # https://github.com/jax-ml/jax/issues/2784
@jax.custom_vjp
def _clip_gradient(lo, hi, x):
return x # identity function when not differentiating
@@ -8538,7 +8538,7 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(g, jnp.array(0.2))
def test_nestable_vjp(self):
- # Verify that https://github.com/google/jax/issues/3667 is resolved.
+ # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved.
def f(x):
return x ** 2
@@ -8571,7 +8571,7 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(y, jnp.array(6.0))
def test_initial_style_vmap_2(self):
- # https://github.com/google/jax/issues/4173
+ # https://github.com/jax-ml/jax/issues/4173
x = jnp.ones((10, 3))
# Create the custom function
@@ -8837,7 +8837,7 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_vjp_closure_4521(self):
- # https://github.com/google/jax/issues/4521
+ # https://github.com/jax-ml/jax/issues/4521
@jax.custom_vjp
def g(x, y):
return None
@@ -8954,7 +8954,7 @@ class CustomVJPTest(jtu.JaxTestCase):
def test_closure_convert_mixed_consts(self):
# Like test_closure_convert, but close over values that
# participate in AD as well as values that do not.
- # See https://github.com/google/jax/issues/6415
+ # See https://github.com/jax-ml/jax/issues/6415
def cos_after(fn, x):
converted_fn, aux_args = jax.closure_convert(fn, x)
@@ -8993,7 +8993,7 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
def test_closure_convert_pytree_mismatch(self):
- # See https://github.com/google/jax/issues/23588
+ # See https://github.com/jax-ml/jax/issues/23588
def f(x, z):
return z * x
@@ -9021,7 +9021,7 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash
def test_custom_vjp_scan_batching_edge_case(self):
- # https://github.com/google/jax/issues/5832
+ # https://github.com/jax-ml/jax/issues/5832
@jax.custom_vjp
def mul(x, coeff): return x * coeff
def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)
@@ -9052,7 +9052,7 @@ class CustomVJPTest(jtu.JaxTestCase):
modes=['rev'])
def test_closure_with_vmap2(self):
- # https://github.com/google/jax/issues/8783
+ # https://github.com/jax-ml/jax/issues/8783
def h(z):
def f(x):
@jax.custom_vjp
@@ -9094,7 +9094,7 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(f)(A([1.])) # doesn't crash
def test_vmap_vjp_called_twice(self):
- # https://github.com/google/jax/pull/14728
+ # https://github.com/jax-ml/jax/pull/14728
@jax.custom_vjp
def f(x):
return x
@@ -9390,7 +9390,7 @@ class CustomVJPTest(jtu.JaxTestCase):
_ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash!
def test_run_rules_more_than_once(self):
- # https://github.com/google/jax/issues/16614
+ # https://github.com/jax-ml/jax/issues/16614
@jax.custom_vjp
def f(x, y):
@@ -9420,7 +9420,7 @@ class CustomVJPTest(jtu.JaxTestCase):
g(1.) # doesn't crash
def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self):
- # https://github.com/google/jax/issues/8356
+ # https://github.com/jax-ml/jax/issues/8356
@jax.custom_vjp
def f(x):
return x[0]
@@ -9618,7 +9618,7 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(f)(x, y) # Doesn't error
def test_optimize_remat_custom_vmap(self):
- # See https://github.com/google/jax/pull/23000
+ # See https://github.com/jax-ml/jax/pull/23000
@jax.custom_vjp
def f(x, y):
return jnp.sin(x) * y
@@ -10908,7 +10908,7 @@ class AutodidaxTest(jtu.JaxTestCase):
class GarbageCollectionTest(jtu.JaxTestCase):
def test_xla_gc_callback(self):
- # https://github.com/google/jax/issues/14882
+ # https://github.com/jax-ml/jax/issues/14882
x_np = np.arange(10, dtype='int32')
x_jax = jax.device_put(x_np)
x_np_weakref = weakref.ref(x_np)
diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py
index c2cd4c0f9..5585f1bcc 100644
--- a/tests/array_interoperability_test.py
+++ b/tests/array_interoperability_test.py
@@ -185,7 +185,7 @@ class DLPackTest(jtu.JaxTestCase):
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testTensorFlowToJaxInt64(self):
- # See https://github.com/google/jax/issues/11895
+ # See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack(
tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64)))
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
diff --git a/tests/attrs_test.py b/tests/attrs_test.py
index 4378a3c75..4eb354a8d 100644
--- a/tests/attrs_test.py
+++ b/tests/attrs_test.py
@@ -323,7 +323,7 @@ class AttrsTest(jtu.JaxTestCase):
self.assertEqual(count, 1)
def test_tracer_lifetime_bug(self):
- # regression test for https://github.com/google/jax/issues/20082
+ # regression test for https://github.com/jax-ml/jax/issues/20082
class StatefulRNG:
key: jax.Array
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 6cd8c7bc2..2b0b0d63a 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -335,7 +335,7 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected_ans, check_dtypes=False)
def testJacobianIssue54(self):
- # test modeling the code in https://github.com/google/jax/issues/54
+ # test modeling the code in https://github.com/jax-ml/jax/issues/54
def func(xs):
return jnp.array(list(xs))
@@ -345,7 +345,7 @@ class BatchingTest(jtu.JaxTestCase):
jacfwd(func)(xs) # don't crash
def testAny(self):
- # test modeling the code in https://github.com/google/jax/issues/108
+ # test modeling the code in https://github.com/jax-ml/jax/issues/108
ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]]))
expected = jnp.array([True, False])
@@ -368,7 +368,7 @@ class BatchingTest(jtu.JaxTestCase):
def testDynamicSlice(self):
# test dynamic_slice via numpy indexing syntax
- # see https://github.com/google/jax/issues/1613 for an explanation of why we
+ # see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we
# need to use np rather than jnp to create x and idx
x = jnp.arange(30).reshape((10, 3))
@@ -933,7 +933,7 @@ class BatchingTest(jtu.JaxTestCase):
rtol=jtu.default_gradient_tolerance)
def testIssue387(self):
- # https://github.com/google/jax/issues/387
+ # https://github.com/jax-ml/jax/issues/387
R = self.rng().rand(100, 2)
def dist_sq(R):
@@ -951,7 +951,7 @@ class BatchingTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow')
def testIssue489(self):
- # https://github.com/google/jax/issues/489
+ # https://github.com/jax-ml/jax/issues/489
def f(key):
def body_fn(uk):
key = uk[1]
@@ -1131,7 +1131,7 @@ class BatchingTest(jtu.JaxTestCase):
x - np.arange(x.shape[0], dtype='int32'))
def testVmapKwargs(self):
- # https://github.com/google/jax/issues/912
+ # https://github.com/jax-ml/jax/issues/912
def f(a, b):
return (2*a, 3*b)
@@ -1242,7 +1242,7 @@ class BatchingTest(jtu.JaxTestCase):
self.assertEqual(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3))
def testPpermuteBatcherTrivial(self):
- # https://github.com/google/jax/issues/8688
+ # https://github.com/jax-ml/jax/issues/8688
def ppermute(input):
return jax.lax.ppermute(input, axis_name="i", perm=[[0, 1], [1, 0]])
@@ -1255,7 +1255,7 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(ans, jnp.ones(2), check_dtypes=False)
def testBatchingPreservesWeakType(self):
- # Regression test for https://github.com/google/jax/issues/10025
+ # Regression test for https://github.com/jax-ml/jax/issues/10025
x = jnp.ravel(1)
self.assertTrue(dtypes.is_weakly_typed(x))
@vmap
diff --git a/tests/core_test.py b/tests/core_test.py
index 0838702c4..94b701090 100644
--- a/tests/core_test.py
+++ b/tests/core_test.py
@@ -349,7 +349,7 @@ class CoreTest(jtu.JaxTestCase):
g_vmap(jnp.ones((1, )))
def test_concrete_array_string_representation(self):
- # https://github.com/google/jax/issues/5364
+ # https://github.com/jax-ml/jax/issues/5364
self.assertEqual(
str(core.ConcreteArray(np.dtype(np.int32),
np.array([1], dtype=np.int32))),
@@ -369,7 +369,7 @@ class CoreTest(jtu.JaxTestCase):
self.assertEqual(dropvar.aval, aval)
def test_input_residual_forwarding(self):
- # https://github.com/google/jax/pull/11151
+ # https://github.com/jax-ml/jax/pull/11151
x = jnp.arange(3 * 4.).reshape(3, 4)
y = jnp.arange(4 * 3.).reshape(4, 3)
diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py
index 830526826..857dc34d4 100644
--- a/tests/custom_linear_solve_test.py
+++ b/tests/custom_linear_solve_test.py
@@ -291,7 +291,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3)
- # regression test for https://github.com/google/jax/issues/1536
+ # regression test for https://github.com/jax-ml/jax/issues/1536
jtu.check_grads(jax.jit(linear_solve), (a, b), order=2,
rtol={np.float32: 2e-3})
@@ -396,7 +396,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
def test_custom_linear_solve_pytree_with_aux(self):
# Check that lax.custom_linear_solve handles
# pytree inputs + has_aux=True
- # https://github.com/google/jax/pull/13093
+ # https://github.com/jax-ml/jax/pull/13093
aux_orig = {'a': 1, 'b': 2}
b = {'c': jnp.ones(2), 'd': jnp.ones(3)}
diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py
index 19e2a5893..020c9f744 100644
--- a/tests/debug_nans_test.py
+++ b/tests/debug_nans_test.py
@@ -127,7 +127,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
ans.block_until_ready()
def testDebugNansJitWithDonation(self):
- # https://github.com/google/jax/issues/12514
+ # https://github.com/jax-ml/jax/issues/12514
a = jnp.array(0.)
with self.assertRaises(FloatingPointError):
ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a)
@@ -214,7 +214,7 @@ class DebugInfsTest(jtu.JaxTestCase):
f(1)
def testDebugNansDoesntCorruptCaches(self):
- # https://github.com/google/jax/issues/6614
+ # https://github.com/jax-ml/jax/issues/6614
@jax.jit
def f(x):
return jnp.divide(x, x)
diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py
index c6b12e2d8..e736e06da 100644
--- a/tests/dtypes_test.py
+++ b/tests/dtypes_test.py
@@ -667,7 +667,7 @@ class TestPromotionTables(jtu.JaxTestCase):
{"testcase_name": f"_{typ}", "typ": typ}
for typ in [bool, int, float, complex])
def testScalarWeakTypes(self, typ):
- # Regression test for https://github.com/google/jax/issues/11377
+ # Regression test for https://github.com/jax-ml/jax/issues/11377
val = typ(0)
result1 = jnp.array(val)
@@ -806,7 +806,7 @@ class TestPromotionTables(jtu.JaxTestCase):
for weak_type in [True, False]
)
def testUnaryPromotion(self, dtype, weak_type):
- # Regression test for https://github.com/google/jax/issues/6051
+ # Regression test for https://github.com/jax-ml/jax/issues/6051
if dtype in intn_dtypes:
self.skipTest("XLA support for int2 and int4 is incomplete.")
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
@@ -852,7 +852,7 @@ class TestPromotionTables(jtu.JaxTestCase):
self.skipTest("XLA support for float8 is incomplete.")
if dtype in intn_dtypes:
self.skipTest("XLA support for int2 and int4 is incomplete.")
- # Regression test for https://github.com/google/jax/issues/6051
+ # Regression test for https://github.com/jax-ml/jax/issues/6051
x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
with jax.numpy_dtype_promotion(promotion):
y = (x + x)
diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py
index 101dddccb..cc2419fb3 100644
--- a/tests/dynamic_api_test.py
+++ b/tests/dynamic_api_test.py
@@ -621,7 +621,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)
def test_shape_validation(self):
- # Regression test for https://github.com/google/jax/issues/18937
+ # Regression test for https://github.com/jax-ml/jax/issues/18937
msg = r"Shapes must be 1D sequences of integer scalars, got .+"
with self.assertRaisesRegex(TypeError, msg):
jax.make_jaxpr(jnp.ones)(5.0)
diff --git a/tests/export_test.py b/tests/export_test.py
index d5884b7e6..0d946d84d 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -1334,7 +1334,7 @@ class JaxExportTest(jtu.JaxTestCase):
def test_grad_sharding_different_mesh(self):
# Export and serialize with two similar meshes, the only difference being
# the order of the devices. grad and serialization should not fail.
- # https://github.com/google/jax/issues/21314
+ # https://github.com/jax-ml/jax/issues/21314
def f(x):
return jnp.sum(x * 2.)
diff --git a/tests/fft_test.py b/tests/fft_test.py
index 05fa96a93..a87b7b66e 100644
--- a/tests/fft_test.py
+++ b/tests/fft_test.py
@@ -175,7 +175,7 @@ class FftTest(jtu.JaxTestCase):
self.assertEqual(dtype, expected_dtype)
def testIrfftTranspose(self):
- # regression test for https://github.com/google/jax/issues/6223
+ # regression test for https://github.com/jax-ml/jax/issues/6223
def build_matrix(linear_func, size):
return jax.vmap(linear_func)(jnp.eye(size, size))
diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py
index 944b47dc8..837d205fb 100644
--- a/tests/host_callback_test.py
+++ b/tests/host_callback_test.py
@@ -1035,7 +1035,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 5.00 2 )""", testing_stream.output)
def test_tap_grad_float0_result(self):
- # https://github.com/google/jax/issues/7340
+ # https://github.com/jax-ml/jax/issues/7340
# x is a Tuple[f32[2], s32[3]]
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
@@ -1058,7 +1058,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
def test_tap_higher_order_grad_float0_result(self):
- # https://github.com/google/jax/issues/7340
+ # https://github.com/jax-ml/jax/issues/7340
# x is a Tuple[f32[2], s32[3]]
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
@@ -1935,7 +1935,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
hcb.id_tap(func, 1, y=2)
def test_tap_id_tap_random_key(self):
- # See https://github.com/google/jax/issues/13949
+ # See https://github.com/jax-ml/jax/issues/13949
with jax.enable_custom_prng():
@jax.jit
def f(x):
diff --git a/tests/image_test.py b/tests/image_test.py
index f3cd56ed7..0f6341086 100644
--- a/tests/image_test.py
+++ b/tests/image_test.py
@@ -180,7 +180,7 @@ class ImageTest(jtu.JaxTestCase):
antialias=[False, True],
)
def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias):
- # Regression test for https://github.com/google/jax/issues/7586
+ # Regression test for https://github.com/jax-ml/jax/issues/7586
image = np.ones(image_shape, dtype)
out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias)
self.assertArraysEqual(out, jnp.zeros(target_shape, dtype))
diff --git a/tests/jet_test.py b/tests/jet_test.py
index b1e2ef3f8..4e437c044 100644
--- a/tests/jet_test.py
+++ b/tests/jet_test.py
@@ -404,7 +404,7 @@ class JetTest(jtu.JaxTestCase):
self.assertArraysEqual(g_out_series, f_out_series)
def test_add_any(self):
- # https://github.com/google/jax/issues/5217
+ # https://github.com/jax-ml/jax/issues/5217
f = lambda x, eps: x * eps + eps + x
def g(eps):
x = jnp.array(1.)
@@ -412,7 +412,7 @@ class JetTest(jtu.JaxTestCase):
jet(g, (1.,), ([1.],)) # doesn't crash
def test_scatter_add(self):
- # very basic test from https://github.com/google/jax/issues/5365
+ # very basic test from https://github.com/jax-ml/jax/issues/5365
def f(x):
x0 = x[0]
x1 = x[1]
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index ab3a18317..78d90cb8a 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -424,7 +424,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
assert "Precision.HIGHEST" in s
def testDotPreferredElementType(self):
- # https://github.com/google/jax/issues/10818
+ # https://github.com/jax-ml/jax/issues/10818
x = jax.numpy.ones((), jax.numpy.float16)
def f(x):
return jax.lax.dot_general(x, x, (((), ()), ((), ())),
@@ -513,7 +513,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
rtol={np.float32: 3e-3})
def testPowSecondDerivative(self):
- # https://github.com/google/jax/issues/12033
+ # https://github.com/jax-ml/jax/issues/12033
x, y = 4.0, 0.0
expected = ((0.0, 1/x), (1/x, np.log(x) ** 2))
@@ -528,18 +528,18 @@ class LaxAutodiffTest(jtu.JaxTestCase):
with self.subTest("zero to the zero"):
result = jax.grad(lax.pow)(0.0, 0.0)
# TODO(jakevdp) special-case zero in a way that doesn't break other cases
- # See https://github.com/google/jax/pull/12041#issuecomment-1222766191
+ # See https://github.com/jax-ml/jax/pull/12041#issuecomment-1222766191
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)
def testPowIntPowerAtZero(self):
- # https://github.com/google/jax/issues/14397
+ # https://github.com/jax-ml/jax/issues/14397
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
self.assertAllClose(ans, 0., check_dtypes=False)
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testPowIntPowerAtZero2(self):
- # https://github.com/google/jax/issues/17995
+ # https://github.com/jax-ml/jax/issues/17995
a = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=int))
b = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=float))
c = lambda z: 1 + z
@@ -634,7 +634,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
def testDynamicSliceValueAndGrad(self):
- # Regression test for https://github.com/google/jax/issues/10984
+ # Regression test for https://github.com/jax-ml/jax/issues/10984
# Issue arose due to an out-of-range negative index.
rng = jtu.rand_default(self.rng())
shape = (5, 5)
@@ -649,7 +649,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
self.assertAllClose(result1, result2)
def testDynamicUpdateSliceValueAndGrad(self):
- # Regression test for https://github.com/google/jax/issues/10984
+ # Regression test for https://github.com/jax-ml/jax/issues/10984
# Issue arose due to an out-of-range negative index.
rng = jtu.rand_default(self.rng())
shape = (5, 5)
@@ -1004,7 +1004,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
def testScatterGradSymbolicZeroUpdate(self):
- # https://github.com/google/jax/issues/1901
+ # https://github.com/jax-ml/jax/issues/1901
def f(x):
n = x.shape[0]
y = np.arange(n, dtype=x.dtype)
@@ -1111,7 +1111,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
check_grads(lax.rem, (x, y), 2, ["fwd", "rev"])
def testHigherOrderGradientOfReciprocal(self):
- # Regression test for https://github.com/google/jax/issues/3136
+ # Regression test for https://github.com/jax-ml/jax/issues/3136
def inv(x):
# N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
return 1 / x
@@ -1150,7 +1150,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
jax.jacrev(f)(x)
def testPowShapeMismatch(self):
- # Regression test for https://github.com/google/jax/issues/17294
+ # Regression test for https://github.com/jax-ml/jax/issues/17294
x = lax.iota('float32', 4)
y = 2
actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py
index 37ad22063..7fb118d47 100644
--- a/tests/lax_control_flow_test.py
+++ b/tests/lax_control_flow_test.py
@@ -733,7 +733,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(4), (8, 16))
def testCondPredIsNone(self):
- # see https://github.com/google/jax/issues/11574
+ # see https://github.com/jax-ml/jax/issues/11574
def f(pred, x):
return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x)
@@ -743,7 +743,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lambda: jax.jit(f)(None, 1.))
def testCondTwoOperands(self):
- # see https://github.com/google/jax/issues/8469
+ # see https://github.com/jax-ml/jax/issues/8469
add, mul = lax.add, lax.mul
def fun(x):
@@ -775,7 +775,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(1), cfun(1))
def testCondCallableOperands(self):
- # see https://github.com/google/jax/issues/16413
+ # see https://github.com/jax-ml/jax/issues/16413
@tree_util.register_pytree_node_class
class Foo:
@@ -1560,7 +1560,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondVmapGrad(self, cond):
- # https://github.com/google/jax/issues/2264
+ # https://github.com/jax-ml/jax/issues/2264
def f_1(x): return x ** 2
def f_2(x): return x ** 3
@@ -1839,7 +1839,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def testIssue711(self, scan):
# Tests reverse-mode differentiation through a scan for which the scanned
# function also involves reverse-mode differentiation.
- # See https://github.com/google/jax/issues/711
+ # See https://github.com/jax-ml/jax/issues/711
def harmonic_bond(conf, params):
return jnp.sum(conf * params)
@@ -2078,7 +2078,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False)
def testIssue757(self):
- # code from https://github.com/google/jax/issues/757
+ # code from https://github.com/jax-ml/jax/issues/757
def fn(a):
return jnp.cos(a)
@@ -2107,7 +2107,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(actual, expected)
def testMapEmpty(self):
- # https://github.com/google/jax/issues/2412
+ # https://github.com/jax-ml/jax/issues/2412
ans = lax.map(lambda x: x * x, jnp.array([]))
expected = jnp.array([])
self.assertAllClose(ans, expected)
@@ -2164,7 +2164,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.while_loop(cond, body, 0)
def test_caches_depend_on_axis_env(self):
- # https://github.com/google/jax/issues/9187
+ # https://github.com/jax-ml/jax/issues/9187
scanned_f = lambda _, __: (lax.psum(1, 'i'), None)
f = lambda: lax.scan(scanned_f, 0, None, length=1)[0]
ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)()
@@ -2443,7 +2443,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(h, length)
def test_disable_jit_cond_with_vmap(self):
- # https://github.com/google/jax/issues/3093
+ # https://github.com/jax-ml/jax/issues/3093
def fn(t):
return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1)
fn = jax.vmap(fn)
@@ -2452,14 +2452,14 @@ class LaxControlFlowTest(jtu.JaxTestCase):
_ = fn(jnp.array([1])) # doesn't crash
def test_disable_jit_while_loop_with_vmap(self):
- # https://github.com/google/jax/issues/2823
+ # https://github.com/jax-ml/jax/issues/2823
def trivial_while(y):
return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y)
with jax.disable_jit():
jax.vmap(trivial_while)(jnp.array([3.0,4.0])) # doesn't crash
def test_vmaps_of_while_loop(self):
- # https://github.com/google/jax/issues/3164
+ # https://github.com/jax-ml/jax/issues/3164
def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x)
x, n = jnp.arange(3), jnp.arange(4)
jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
@@ -2567,7 +2567,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lambda: core.check_jaxpr(jaxpr))
def test_cond_transformation_rule_with_consts(self):
- # https://github.com/google/jax/pull/9731
+ # https://github.com/jax-ml/jax/pull/9731
@jax.custom_jvp
def f(x):
@@ -2584,7 +2584,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.jvp(g, (x,), (x,)) # doesn't crash
def test_cond_excessive_compilation(self):
- # Regression test for https://github.com/google/jax/issues/14058
+ # Regression test for https://github.com/jax-ml/jax/issues/14058
def f(x):
return x + 1
@@ -2632,7 +2632,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
('new_remat', new_checkpoint),
])
def test_scan_vjp_forwards_extensive_residuals(self, remat):
- # https://github.com/google/jax/issues/4510
+ # https://github.com/jax-ml/jax/issues/4510
def cumprod(x):
s = jnp.ones((2, 32), jnp.float32)
return lax.scan(lambda s, x: (x*s, s), s, x)
@@ -2671,7 +2671,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
(jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]])), check_dtypes=False)
def test_xla_cpu_gpu_loop_cond_bug(self):
- # https://github.com/google/jax/issues/5900
+ # https://github.com/jax-ml/jax/issues/5900
def deriv(f):
return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0)
@@ -2750,7 +2750,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.grad(f)(1.) # doesn't crash
def test_custom_jvp_tangent_cond_transpose(self):
- # https://github.com/google/jax/issues/14026
+ # https://github.com/jax-ml/jax/issues/14026
def mask_fun(arr, choice):
out = (1 - choice) * arr.sum() + choice * (1 - arr.sum())
return out
@@ -2997,7 +2997,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertIsInstance(y, jax.Array)
def test_cond_memory_leak(self):
- # https://github.com/google/jax/issues/12719
+ # https://github.com/jax-ml/jax/issues/12719
def leak():
data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py
index 02fecb7b3..d3dada0d7 100644
--- a/tests/lax_metal_test.py
+++ b/tests/lax_metal_test.py
@@ -914,7 +914,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=False)
def testRoundMethod(self):
- # https://github.com/google/jax/issues/15190
+ # https://github.com/jax-ml/jax/issues/15190
(jnp.arange(3.) / 5.).round() # doesn't crash
@jtu.sample_product(shape=[(5,), (5, 2)])
@@ -1425,7 +1425,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
y=[0, 32, 64, 128],
)
def testIntegerPowerOverflow(self, x, y):
- # Regression test for https://github.com/google/jax/issues/5987
+ # Regression test for https://github.com/jax-ml/jax/issues/5987
args_maker = lambda: [x, y]
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
self._CompileAndCheck(jnp.power, args_maker)
@@ -1536,7 +1536,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testConcatenateAxisNone(self):
- # https://github.com/google/jax/issues/3419
+ # https://github.com/jax-ml/jax/issues/3419
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)
@@ -2768,7 +2768,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testDiffPrepoendScalar(self):
- # Regression test for https://github.com/google/jax/issues/19362
+ # Regression test for https://github.com/jax-ml/jax/issues/19362
x = jnp.arange(10)
result_jax = jnp.diff(x, prepend=x[0], append=x[-1])
@@ -3359,7 +3359,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
_check([jnp.complex128(1)], np.complex128, False)
# Mixed inputs use JAX-style promotion.
- # (regression test for https://github.com/google/jax/issues/8945)
+ # (regression test for https://github.com/jax-ml/jax/issues/8945)
_check([0, np.int16(1)], np.int16, False)
_check([0.0, np.float16(1)], np.float16, False)
@@ -3932,17 +3932,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# TODO(mattjj): test other ndarray-like method overrides
def testNpMean(self):
- # from https://github.com/google/jax/issues/125
+ # from https://github.com/jax-ml/jax/issues/125
x = jnp.eye(3, dtype=float) + 0.
ans = np.mean(x)
self.assertAllClose(ans, np.array(1./3), check_dtypes=False)
def testArangeOnFloats(self):
np_arange = jtu.with_jax_dtype_defaults(np.arange)
- # from https://github.com/google/jax/issues/145
+ # from https://github.com/jax-ml/jax/issues/145
self.assertAllClose(np_arange(0.0, 1.0, 0.1),
jnp.arange(0.0, 1.0, 0.1))
- # from https://github.com/google/jax/issues/3450
+ # from https://github.com/jax-ml/jax/issues/3450
self.assertAllClose(np_arange(2.5),
jnp.arange(2.5))
self.assertAllClose(np_arange(0., 2.5),
@@ -4303,7 +4303,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_op, args_maker)
def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
- # https://github.com/google/jax/issues/5088
+ # https://github.com/jax-ml/jax/issues/5088
h = jtu.rand_default(self.rng())((256, 256, 100), np.float32)
g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-1)
@@ -4513,9 +4513,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because
# introducing the convention 0 * inf = 0 leads to silently wrong results in
# some cases. See this comment for details:
- # https://github.com/google/jax/issues/1052#issuecomment-514083352
+ # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352
# def testIssue347(self):
- # # https://github.com/google/jax/issues/347
+ # # https://github.com/jax-ml/jax/issues/347
# def test_fail(x):
# x = jnp.sqrt(jnp.sum(x ** 2, axis=1))
# ones = jnp.ones_like(x)
@@ -4526,7 +4526,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# assert not np.any(np.isnan(result))
def testIssue453(self):
- # https://github.com/google/jax/issues/453
+ # https://github.com/jax-ml/jax/issues/453
a = np.arange(6) + 1
ans = jnp.reshape(a, (3, 2), order='F')
expected = np.reshape(a, (3, 2), order='F')
@@ -4538,7 +4538,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
op=["atleast_1d", "atleast_2d", "atleast_3d"],
)
def testAtLeastNdLiterals(self, dtype, op):
- # Fixes: https://github.com/google/jax/issues/634
+ # Fixes: https://github.com/jax-ml/jax/issues/634
np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype])
jnp_fun = lambda arg: getattr(jnp, op)(arg)
args_maker = lambda: [dtype(2)]
@@ -5147,7 +5147,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.ones(2) + 3 # don't want to warn for scalars
def testStackArrayArgument(self):
- # tests https://github.com/google/jax/issues/1271
+ # tests https://github.com/jax-ml/jax/issues/1271
@jax.jit
def foo(x):
return jnp.stack(x)
@@ -5316,7 +5316,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testZerosShapeErrors(self):
- # see https://github.com/google/jax/issues/1822
+ # see https://github.com/jax-ml/jax/issues/1822
self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*",
@@ -5334,7 +5334,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x))
def testIntegerPowersArePrecise(self):
- # See https://github.com/google/jax/pull/3036
+ # See https://github.com/jax-ml/jax/pull/3036
# Checks if the squares of float32 integers have no numerical errors.
# It should be satisfied with all integers less than sqrt(2**24).
x = jnp.arange(-2**12, 2**12, dtype=jnp.int32)
@@ -5405,7 +5405,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testIssue2347(self):
- # https://github.com/google/jax/issues/2347
+ # https://github.com/jax-ml/jax/issues/2347
object_list = list[tuple[jnp.array, float, float, jnp.array, bool]]
self.assertRaises(TypeError, jnp.array, object_list)
@@ -5617,7 +5617,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
return False
- #https://github.com/google/jax/issues/16420
+ #https://github.com/jax-ml/jax/issues/16420
def test_broadcast_dim(self):
x = jnp.arange(2)
f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,))
@@ -5640,7 +5640,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
res = jnp.triu(x)
jtu.check_eq(res, np.triu(x))
- #https://github.com/google/jax/issues/16471
+ #https://github.com/jax-ml/jax/issues/16471
def test_matmul_1d(self):
x = np.array(np.random.rand(3, 3))
y = np.array(np.random.rand(3))
@@ -5650,7 +5650,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
res = jnp.dot(x, y)
self.assertArraysAllClose(res, np.dot(x,y))
- #https://github.com/google/jax/issues/17175
+ #https://github.com/jax-ml/jax/issues/17175
def test_indexing(self):
x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32)
@jax.vmap
@@ -5661,7 +5661,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
res = f(idx)
jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]]))
- #https://github.com/google/jax/issues/17344
+ #https://github.com/jax-ml/jax/issues/17344
def test_take_along_axis(self):
@jax.jit
def f():
@@ -5672,7 +5672,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
return jnp.take_along_axis(x, idx, axis=1)
jtu.check_eq(f(), self.dispatchOn([], f))
- #https://github.com/google/jax/issues/17590
+ #https://github.com/jax-ml/jax/issues/17590
def test_in1d(self):
a = np.array([123,2,4])
b = np.array([123,1])
@@ -5688,7 +5688,7 @@ class ReportedIssuesTests(jtu.JaxTestCase):
res = f(x)
jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]]))
- #https://github.com/google/jax/issues/16326
+ #https://github.com/jax-ml/jax/issues/16326
def test_indexing_update2(self):
@jax.jit
def f(x, r):
@@ -5722,7 +5722,7 @@ module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas
print(res)
jtu.check_eq(res, res_ref)
- #https://github.com/google/jax/issues/16366
+ #https://github.com/jax-ml/jax/issues/16366
def test_pad_interior_1(self):
if not ReportedIssuesTests.jax_metal_supported('0.0.6'):
raise unittest.SkipTest("jax-metal version doesn't support it.")
diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py
index 7397cf3e4..ea7bff1d0 100644
--- a/tests/lax_numpy_einsum_test.py
+++ b/tests/lax_numpy_einsum_test.py
@@ -89,7 +89,7 @@ class EinsumTest(jtu.JaxTestCase):
self._check(s, x, y)
def test_two_operands_6(self):
- # based on https://github.com/google/jax/issues/37#issuecomment-448572187
+ # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187
r = self.rng()
x = r.randn(2, 1)
y = r.randn(2, 3, 4)
diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py
index bf2785f62..d58a5c2c3 100644
--- a/tests/lax_numpy_indexing_test.py
+++ b/tests/lax_numpy_indexing_test.py
@@ -496,7 +496,7 @@ class IndexingTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp_op_idx, args_maker)
def testIndexApplyBatchingBug(self):
- # https://github.com/google/jax/issues/16655
+ # https://github.com/jax-ml/jax/issues/16655
arr = jnp.array([[1, 2, 3, 4, 5, 6]])
ind = jnp.array([3])
func = lambda a, i: a.at[i].apply(lambda x: x - 1)
@@ -505,7 +505,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertArraysEqual(out, expected)
def testIndexUpdateScalarBug(self):
- # https://github.com/google/jax/issues/14923
+ # https://github.com/jax-ml/jax/issues/14923
a = jnp.arange(10.)
out = a.at[0].apply(jnp.cos)
self.assertArraysEqual(out, a.at[0].set(1))
@@ -835,7 +835,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testBoolean1DIndexingWithEllipsis(self):
- # Regression test for https://github.com/google/jax/issues/8412
+ # Regression test for https://github.com/jax-ml/jax/issues/8412
x = np.arange(24).reshape(4, 3, 2)
idx = (..., np.array([True, False]))
ans = jnp.array(x)[idx]
@@ -843,7 +843,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testBoolean1DIndexingWithEllipsis2(self):
- # Regression test for https://github.com/google/jax/issues/9050
+ # Regression test for https://github.com/jax-ml/jax/issues/9050
x = np.arange(3)
idx = (..., np.array([True, False, True]))
ans = jnp.array(x)[idx]
@@ -936,7 +936,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p)
def testTrivialGatherIsntGenerated(self):
- # https://github.com/google/jax/issues/1621
+ # https://github.com/jax-ml/jax/issues/1621
jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4))
self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
self.assertNotIn('gather', str(jaxpr))
@@ -988,14 +988,14 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testBooleanIndexingShapeMismatch(self):
- # Regression test for https://github.com/google/jax/issues/7329
+ # Regression test for https://github.com/jax-ml/jax/issues/7329
x = jnp.arange(4)
idx = jnp.array([True, False])
with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"):
x[idx]
def testBooleanIndexingWithNone(self):
- # Regression test for https://github.com/google/jax/issues/18542
+ # Regression test for https://github.com/jax-ml/jax/issues/18542
x = jnp.arange(6).reshape(2, 3)
idx = (None, jnp.array([True, False]))
ans = x[idx]
@@ -1003,7 +1003,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
def testBooleanIndexingWithNoneAndEllipsis(self):
- # Regression test for https://github.com/google/jax/issues/18542
+ # Regression test for https://github.com/jax-ml/jax/issues/18542
x = jnp.arange(6).reshape(2, 3)
mask = jnp.array([True, False, False])
ans = x[None, ..., mask]
@@ -1011,7 +1011,7 @@ class IndexingTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
def testBooleanIndexingWithEllipsisAndNone(self):
- # Regression test for https://github.com/google/jax/issues/18542
+ # Regression test for https://github.com/jax-ml/jax/issues/18542
x = jnp.arange(6).reshape(2, 3)
mask = jnp.array([True, False, False])
ans = x[..., None, mask]
@@ -1038,7 +1038,7 @@ class IndexingTest(jtu.JaxTestCase):
[(3, 4, 5), (3, 0)],
)
def testEmptyBooleanIndexing(self, x_shape, m_shape):
- # Regression test for https://github.com/google/jax/issues/22886
+ # Regression test for https://github.com/jax-ml/jax/issues/22886
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)]
@@ -1120,7 +1120,7 @@ class IndexingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, msg):
jnp.zeros((2, 3))[:, 'abc']
- def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
+ def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245
x = jnp.arange(5, dtype=jnp.int32) + 1
self.assertAllClose(x, x[:10])
@@ -1613,7 +1613,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testIndexDtypeError(self):
- # https://github.com/google/jax/issues/2795
+ # https://github.com/jax-ml/jax/issues/2795
jnp.array(1) # get rid of startup warning
with self.assertNoWarnings():
jnp.zeros(5).at[::2].set(1)
@@ -1647,13 +1647,13 @@ class IndexedUpdateTest(jtu.JaxTestCase):
x.at[normalize(idx)].set(0)
def testIndexedUpdateAliasingBug(self):
- # https://github.com/google/jax/issues/7461
+ # https://github.com/jax-ml/jax/issues/7461
fn = lambda x: x.at[1:].set(1 + x[:-1])
y = jnp.zeros(8)
self.assertArraysEqual(fn(y), jax.jit(fn)(y))
def testScatterValuesCastToTargetDType(self):
- # https://github.com/google/jax/issues/15505
+ # https://github.com/jax-ml/jax/issues/15505
a = jnp.zeros(1, dtype=jnp.uint32)
val = 2**32 - 1 # too large for int32
diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py
index d9d6fa464..45a780c9f 100644
--- a/tests/lax_numpy_operators_test.py
+++ b/tests/lax_numpy_operators_test.py
@@ -697,7 +697,7 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
self.assertIsInstance(jax.jit(operator.mul)(b, a), MyArray)
def testI0Grad(self):
- # Regression test for https://github.com/google/jax/issues/11479
+ # Regression test for https://github.com/jax-ml/jax/issues/11479
dx = jax.grad(jax.numpy.i0)(0.0)
self.assertArraysEqual(dx, 0.0)
diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py
index 402e206ef..33830c541 100644
--- a/tests/lax_numpy_reducers_test.py
+++ b/tests/lax_numpy_reducers_test.py
@@ -425,7 +425,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
if (shape in [()] + scalar_shapes and
dtype in [jnp.int16, jnp.uint16] and
jnp_op in [jnp.min, jnp.max]):
- self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.")
+ self.skipTest("Known XLA failure; see https://github.com/jax-ml/jax/issues/4971.")
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
@@ -582,7 +582,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
size=[0, 1, 2]
)
def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size):
- # test for https://github.com/google/jax/issues/21330
+ # test for https://github.com/jax-ml/jax/issues/21330
x = jnp.arange(size)
self.assertTrue(np.isnan(jnp_fn(x, ddof=size)))
self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1)))
@@ -622,7 +622,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
atol=tol)
def testNanStdGrad(self):
- # Regression test for https://github.com/google/jax/issues/8128
+ # Regression test for https://github.com/jax-ml/jax/issues/8128
x = jnp.arange(5.0).at[0].set(jnp.nan)
y = jax.grad(jnp.nanvar)(x)
self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]), check_dtypes=False)
@@ -740,7 +740,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
@unittest.skipIf(not config.enable_x64.value, "test requires X64")
@jtu.run_on_devices("cpu") # test is for CPU float64 precision
def testPercentilePrecision(self):
- # Regression test for https://github.com/google/jax/issues/8513
+ # Regression test for https://github.com/jax-ml/jax/issues/8513
x = jnp.float64([1, 2, 3, 4, 7, 10])
self.assertEqual(jnp.percentile(x, 50), 3.5)
@@ -778,14 +778,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)
def testMeanLargeArray(self):
- # https://github.com/google/jax/issues/15068
+ # https://github.com/jax-ml/jax/issues/15068
raise unittest.SkipTest("test is slow, but it passes!")
x = jnp.ones((16, 32, 1280, 4096), dtype='int8')
self.assertEqual(1.0, jnp.mean(x))
self.assertEqual(1.0, jnp.mean(x, where=True))
def testStdLargeArray(self):
- # https://github.com/google/jax/issues/15068
+ # https://github.com/jax-ml/jax/issues/15068
raise unittest.SkipTest("test is slow, but it passes!")
x = jnp.ones((16, 32, 1280, 4096), dtype='int8')
self.assertEqual(0.0, jnp.std(x))
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index ddf42a28e..d3f9f2d61 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -1043,7 +1043,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=False)
def testRoundMethod(self):
- # https://github.com/google/jax/issues/15190
+ # https://github.com/jax-ml/jax/issues/15190
(jnp.arange(3.) / 5.).round() # doesn't crash
@jtu.sample_product(shape=[(5,), (5, 2)])
@@ -1571,7 +1571,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
y=[0, 32, 64, 128],
)
def testIntegerPowerOverflow(self, x, y):
- # Regression test for https://github.com/google/jax/issues/5987
+ # Regression test for https://github.com/jax-ml/jax/issues/5987
args_maker = lambda: [x, y]
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
self._CompileAndCheck(jnp.power, args_maker)
@@ -1713,7 +1713,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testConcatenateAxisNone(self):
- # https://github.com/google/jax/issues/3419
+ # https://github.com/jax-ml/jax/issues/3419
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)
@@ -2977,7 +2977,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testDiffPrepoendScalar(self):
- # Regression test for https://github.com/google/jax/issues/19362
+ # Regression test for https://github.com/jax-ml/jax/issues/19362
x = jnp.arange(10)
result_jax = jnp.diff(x, prepend=x[0], append=x[-1])
@@ -3611,7 +3611,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
_check([jnp.complex128(1)], np.complex128, False)
# Mixed inputs use JAX-style promotion.
- # (regression test for https://github.com/google/jax/issues/8945)
+ # (regression test for https://github.com/jax-ml/jax/issues/8945)
_check([0, np.int16(1)], np.int16, False)
_check([0.0, np.float16(1)], np.float16, False)
@@ -4229,17 +4229,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# TODO(mattjj): test other ndarray-like method overrides
def testNpMean(self):
- # from https://github.com/google/jax/issues/125
+ # from https://github.com/jax-ml/jax/issues/125
x = jnp.eye(3, dtype=float) + 0.
ans = np.mean(x)
self.assertAllClose(ans, np.array(1./3), check_dtypes=False)
def testArangeOnFloats(self):
np_arange = jtu.with_jax_dtype_defaults(np.arange)
- # from https://github.com/google/jax/issues/145
+ # from https://github.com/jax-ml/jax/issues/145
self.assertAllClose(np_arange(0.0, 1.0, 0.1),
jnp.arange(0.0, 1.0, 0.1))
- # from https://github.com/google/jax/issues/3450
+ # from https://github.com/jax-ml/jax/issues/3450
self.assertAllClose(np_arange(2.5),
jnp.arange(2.5))
self.assertAllClose(np_arange(0., 2.5),
@@ -4400,7 +4400,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
dtype=unsigned_dtypes,
)
def testPartitionUnsignedWithZeros(self, kth, dtype):
- # https://github.com/google/jax/issues/22137
+ # https://github.com/jax-ml/jax/issues/22137
max_val = np.iinfo(dtype).max
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype)
axis = -1
@@ -4441,7 +4441,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
dtype=unsigned_dtypes,
)
def testArgpartitionUnsignedWithZeros(self, kth, dtype):
- # https://github.com/google/jax/issues/22137
+ # https://github.com/jax-ml/jax/issues/22137
max_val = np.iinfo(dtype).max
arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype)
axis = -1
@@ -4616,7 +4616,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_op, args_maker)
def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self):
- # https://github.com/google/jax/issues/5088
+ # https://github.com/jax-ml/jax/issues/5088
h = jtu.rand_default(self.rng())((256, 256, 100), np.float32)
g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8)
q0 = jnp.take_along_axis(h, g, axis=-1)
@@ -4837,9 +4837,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# NOTE(mattjj): I disabled this test when removing lax._safe_mul because
# introducing the convention 0 * inf = 0 leads to silently wrong results in
# some cases. See this comment for details:
- # https://github.com/google/jax/issues/1052#issuecomment-514083352
+ # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352
# def testIssue347(self):
- # # https://github.com/google/jax/issues/347
+ # # https://github.com/jax-ml/jax/issues/347
# def test_fail(x):
# x = jnp.sqrt(jnp.sum(x ** 2, axis=1))
# ones = jnp.ones_like(x)
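A minimal sketch (not from the patched test file) of the behaviour the NOTE above is protecting: under IEEE semantics 0.0 * inf propagates as nan, which flags the undefined gradient instead of silently returning 0.
import jax
import jax.numpy as jnp
# grad of ||x|| at x == 0 chains d(sqrt)(0) == inf with d(sum(x**2)) == 0,
# so IEEE arithmetic yields nan; a 0 * inf == 0 rule would silently hide it.
g = jax.grad(lambda x: jnp.sqrt(jnp.sum(x ** 2)))(jnp.zeros(3))
print(g)  # [nan nan nan]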
@@ -4850,7 +4850,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# assert not np.any(np.isnan(result))
def testIssue453(self):
- # https://github.com/google/jax/issues/453
+ # https://github.com/jax-ml/jax/issues/453
a = np.arange(6) + 1
ans = jnp.reshape(a, (3, 2), order='F')
expected = np.reshape(a, (3, 2), order='F')
@@ -4861,7 +4861,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
op=["atleast_1d", "atleast_2d", "atleast_3d"],
)
def testAtLeastNdLiterals(self, dtype, op):
- # Fixes: https://github.com/google/jax/issues/634
+ # Fixes: https://github.com/jax-ml/jax/issues/634
np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype])
jnp_fun = lambda arg: getattr(jnp, op)(arg)
args_maker = lambda: [dtype(2)]
@@ -5489,7 +5489,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.ones(2) + 3 # don't want to warn for scalars
def testStackArrayArgument(self):
- # tests https://github.com/google/jax/issues/1271
+ # tests https://github.com/jax-ml/jax/issues/1271
@jax.jit
def foo(x):
return jnp.stack(x)
@@ -5536,7 +5536,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_op, args_maker)
def testBroadcastToInvalidShape(self):
- # Regression test for https://github.com/google/jax/issues/20533
+ # Regression test for https://github.com/jax-ml/jax/issues/20533
x = jnp.zeros((3, 4, 5))
with self.assertRaisesRegex(
ValueError, "Cannot broadcast to shape with fewer dimensions"):
@@ -5688,7 +5688,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp.gradient, args_maker)
def testZerosShapeErrors(self):
- # see https://github.com/google/jax/issues/1822
+ # see https://github.com/jax-ml/jax/issues/1822
self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type.*",
@@ -5706,7 +5706,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x))
def testIntegerPowersArePrecise(self):
- # See https://github.com/google/jax/pull/3036
+ # See https://github.com/jax-ml/jax/pull/3036
# Checks if the squares of float32 integers have no numerical errors.
# It should be satisfied with all integers less than sqrt(2**24).
x = jnp.arange(-2**12, 2**12, dtype=jnp.int32)
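A quick standalone check (plain NumPy, illustrative only) of the sqrt(2**24) bound cited in the comment above: float32 carries a 24-bit significand, so integer squares stay exact whenever |n| <= 2**12.
import numpy as np
n = np.arange(-2**12, 2**12, dtype=np.int64)
squares_f32 = n.astype(np.float32) * n.astype(np.float32)
# Every product is an integer <= 2**24, hence exactly representable in float32.
assert np.array_equal(squares_f32.astype(np.int64), n * n)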
@@ -5777,7 +5777,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun, args_maker)
def testIssue2347(self):
- # https://github.com/google/jax/issues/2347
+ # https://github.com/jax-ml/jax/issues/2347
object_list = list[tuple[jnp.array, float, float, jnp.array, bool]]
self.assertRaises(TypeError, jnp.array, object_list)
@@ -6096,7 +6096,7 @@ class NumpyGradTests(jtu.JaxTestCase):
jax.grad(lambda x: jnp.sinc(x).sum())(jnp.arange(10.)) # doesn't crash
def testTakeAlongAxisIssue1521(self):
- # https://github.com/google/jax/issues/1521
+ # https://github.com/jax-ml/jax/issues/1521
idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1))
def f(x):
@@ -6207,7 +6207,7 @@ class NumpySignaturesTest(jtu.JaxTestCase):
if name == "clip":
# JAX's support of the Array API spec for clip, and the way it handles
# backwards compatibility was introduced in
- # https://github.com/google/jax/pull/20550 with a different signature
+ # https://github.com/jax-ml/jax/pull/20550 with a different signature
# from the one in numpy, introduced in
# https://github.com/numpy/numpy/pull/26724
# TODO(dfm): After our deprecation period for the clip arguments ends
diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py
index c65df8aa8..630d89f53 100644
--- a/tests/lax_numpy_ufuncs_test.py
+++ b/tests/lax_numpy_ufuncs_test.py
@@ -379,7 +379,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp_fun_at, args_maker)
def test_frompyfunc_at_broadcasting(self):
- # Regression test for https://github.com/google/jax/issues/18004
+ # Regression test for https://github.com/jax-ml/jax/issues/18004
args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]),
np.arange(9.0).reshape(3, 3)]
def np_fun(x, idx, y):
diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py
index 56fd0f781..985dba484 100644
--- a/tests/lax_numpy_vectorize_test.py
+++ b/tests/lax_numpy_vectorize_test.py
@@ -258,7 +258,7 @@ class VectorizeTest(jtu.JaxTestCase):
f(*args)
def test_rank_promotion_error(self):
- # Regression test for https://github.com/google/jax/issues/22305
+ # Regression test for https://github.com/jax-ml/jax/issues/22305
f = jnp.vectorize(jnp.add, signature="(),()->()")
rank2 = jnp.zeros((10, 10))
rank1 = jnp.zeros(10)
diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py
index d2e64833b..303c67c58 100644
--- a/tests/lax_scipy_sparse_test.py
+++ b/tests/lax_scipy_sparse_test.py
@@ -469,7 +469,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertTrue(dtypes.is_weakly_typed(x))
def test_linear_solve_batching_via_jacrev(self):
- # See https://github.com/google/jax/issues/14249
+ # See https://github.com/jax-ml/jax/issues/14249
rng = np.random.RandomState(0)
M = rng.randn(5, 5)
A = np.dot(M, M.T)
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index df19750b1..8a8b1dd42 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -165,7 +165,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol)
def testLogSumExpZeros(self):
- # Regression test for https://github.com/google/jax/issues/5370
+ # Regression test for https://github.com/jax-ml/jax/issues/5370
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
@@ -173,14 +173,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker)
def testLogSumExpOnes(self):
- # Regression test for https://github.com/google/jax/issues/7390
+ # Regression test for https://github.com/jax-ml/jax/issues/7390
args_maker = lambda: [np.ones(4, dtype='float32')]
with jax.debug_infs(True):
self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
self._CompileAndCheck(lsp_special.logsumexp, args_maker)
def testLogSumExpNans(self):
- # Regression test for https://github.com/google/jax/issues/7634
+ # Regression test for https://github.com/jax-ml/jax/issues/7634
with jax.debug_nans(True):
with jax.disable_jit():
result = lsp_special.logsumexp(1.0)
@@ -246,7 +246,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)
def testGradOfXlogyAtZero(self):
- # https://github.com/google/jax/issues/15598
+ # https://github.com/jax-ml/jax/issues/15598
x0, y0 = 0.0, 3.0
d_xlog1py_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0)
self.assertAllClose(d_xlog1py_dx, lax.log(y0))
@@ -260,7 +260,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)
def testGradOfXlog1pyAtZero(self):
- # https://github.com/google/jax/issues/15598
+ # https://github.com/jax-ml/jax/issues/15598
x0, y0 = 0.0, 3.0
d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0)
self.assertAllClose(d_xlog1py_dx, lax.log1p(y0))
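The property these two regression tests pin down can be checked directly; a small illustrative sketch (values chosen arbitrarily):
import jax
import jax.numpy as jnp
from jax.scipy import special as lsp_special
# xlogy(x, y) == x * log(y), so its derivative in x at x == 0 is log(y);
# the tests above check this stays well-defined at x == 0 (see issue 15598).
dx = jax.grad(lsp_special.xlogy, argnums=0)(0.0, 3.0)
assert jnp.allclose(dx, jnp.log(3.0))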
@@ -284,7 +284,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
rtol=.1, eps=1e-3)
def testGradOfEntrAtZero(self):
- # https://github.com/google/jax/issues/15709
+ # https://github.com/jax-ml/jax/issues/15709
self.assertEqual(jax.jacfwd(lsp_special.entr)(0.0), jnp.inf)
self.assertEqual(jax.jacrev(lsp_special.entr)(0.0), jnp.inf)
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 3d31bcb7d..0ae5f77af 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -1002,7 +1002,7 @@ class LaxTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(fun_via_grad, fun, args_maker)
def testConvTransposePaddingList(self):
- # Regression test for https://github.com/google/jax/discussions/8695
+ # Regression test for https://github.com/jax-ml/jax/discussions/8695
a = jnp.ones((28,28))
b = jnp.ones((3,3))
c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
@@ -1280,7 +1280,7 @@ class LaxTest(jtu.JaxTestCase):
self._CompileAndCheck(op, args_maker)
def testBroadcastInDimOperandShapeTranspose(self):
- # Regression test for https://github.com/google/jax/issues/5276
+ # Regression test for https://github.com/jax-ml/jax/issues/5276
def f(x):
return lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(0, 1, 2)).sum()
def g(x):
@@ -1681,7 +1681,7 @@ class LaxTest(jtu.JaxTestCase):
lax.dynamic_update_slice, args_maker)
def testDynamicUpdateSliceBatched(self):
- # Regression test for https://github.com/google/jax/issues/9083
+ # Regression test for https://github.com/jax-ml/jax/issues/9083
x = jnp.arange(5)
y = jnp.arange(6, 9)
ind = jnp.arange(6)
@@ -2236,7 +2236,7 @@ class LaxTest(jtu.JaxTestCase):
self.assertEqual(shape, result.shape)
def testReduceWindowWithEmptyOutput(self):
- # https://github.com/google/jax/issues/10315
+ # https://github.com/jax-ml/jax/issues/10315
shape = (5, 3, 2)
operand, padding, strides = np.ones(shape), 'VALID', (1,) * len(shape)
out = jax.eval_shape(lambda x: lax.reduce_window(x, 0., lax.add, padding=padding,
@@ -2859,13 +2859,13 @@ class LaxTest(jtu.JaxTestCase):
op(2+3j, 4+5j)
def test_population_count_booleans_not_supported(self):
- # https://github.com/google/jax/issues/3886
+ # https://github.com/jax-ml/jax/issues/3886
msg = "population_count does not accept dtype bool"
with self.assertRaisesRegex(TypeError, msg):
lax.population_count(True)
def test_conv_general_dilated_different_input_ranks_error(self):
- # https://github.com/google/jax/issues/4316
+ # https://github.com/jax-ml/jax/issues/4316
msg = ("conv_general_dilated lhs and rhs must have the same number of "
"dimensions")
dimension_numbers = lax.ConvDimensionNumbers(lhs_spec=(0, 1, 2),
@@ -2885,7 +2885,7 @@ class LaxTest(jtu.JaxTestCase):
lax.conv_general_dilated(lhs, rhs, **kwargs)
def test_window_strides_dimension_shape_rule(self):
- # https://github.com/google/jax/issues/5087
+ # https://github.com/jax-ml/jax/issues/5087
msg = ("conv_general_dilated window and window_strides must have "
"the same number of dimensions")
lhs = jax.numpy.zeros((1, 1, 3, 3))
@@ -2894,7 +2894,7 @@ class LaxTest(jtu.JaxTestCase):
jax.lax.conv(lhs, rhs, [1], 'SAME')
def test_reduce_window_scalar_init_value_shape_rule(self):
- # https://github.com/google/jax/issues/4574
+ # https://github.com/jax-ml/jax/issues/4574
args = { "operand": np.ones((4, 4), dtype=np.int32)
, "init_value": np.zeros((1,), dtype=np.int32)
, "computation": lax.max
@@ -3045,7 +3045,7 @@ class LaxTest(jtu.JaxTestCase):
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
def test_dot_general_batching_python_builtin_arg(self):
- # https://github.com/google/jax/issues/16805
+ # https://github.com/jax-ml/jax/issues/16805
@jax.remat
def f(x):
return jax.lax.dot_general(x, x, (([], []), ([], [])))
@@ -3053,7 +3053,7 @@ class LaxTest(jtu.JaxTestCase):
jax.hessian(f)(1.0) # don't crash
def test_constant_folding_complex_to_real_scan_regression(self):
- # regression test for github.com/google/jax/issues/19059
+ # regression test for github.com/jax-ml/jax/issues/19059
def g(hiddens):
hiddens_aug = jnp.vstack((hiddens[0], hiddens))
new_hiddens = hiddens_aug.copy()
@@ -3088,7 +3088,7 @@ class LaxTest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(asarray_closure)()
self.assertLen(jaxpr.eqns, 0)
- # Regression test for https://github.com/google/jax/issues/19334
+ # Regression test for https://github.com/jax-ml/jax/issues/19334
# lax.asarray as a closure should not trigger transfer guard.
with jax.transfer_guard('disallow'):
jax.jit(asarray_closure)()
@@ -3254,7 +3254,7 @@ class LazyConstantTest(jtu.JaxTestCase):
def testUnaryWeakTypes(self, op_name, rec_dtypes):
"""Test that all lax unary ops propagate weak_type information appropriately."""
if op_name == "bitwise_not":
- raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
+ raise unittest.SkipTest("https://github.com/jax-ml/jax/issues/12066")
# Find a valid dtype for the function.
for dtype in [float, int, complex, bool]:
dtype = dtypes.canonicalize_dtype(dtype)
@@ -3648,7 +3648,7 @@ class CustomElementTypesTest(jtu.JaxTestCase):
self.assertEqual(ys.shape, (3, 2, 1))
def test_gather_batched_index_dtype(self):
- # Regression test for https://github.com/google/jax/issues/16557
+ # Regression test for https://github.com/jax-ml/jax/issues/16557
dtype = jnp.int8
size = jnp.iinfo(dtype).max + 10
indices = jnp.zeros(size, dtype=dtype)
diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py
index 37d51c04f..37a0011e7 100644
--- a/tests/lax_vmap_test.py
+++ b/tests/lax_vmap_test.py
@@ -693,8 +693,8 @@ class LaxVmapTest(jtu.JaxTestCase):
# TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
@jtu.skip_on_devices("gpu")
def test_variadic_reduce_window(self):
- # https://github.com/google/jax/discussions/9818 and
- # https://github.com/google/jax/issues/9837
+ # https://github.com/jax-ml/jax/discussions/9818 and
+ # https://github.com/jax-ml/jax/issues/9837
def normpool(x):
norms = jnp.linalg.norm(x, axis=-1)
idxs = jnp.arange(x.shape[0])
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 446e10abd..15963b10b 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -329,7 +329,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.run_on_devices("cpu")
def testEigvalsInf(self):
- # https://github.com/google/jax/issues/2661
+ # https://github.com/jax-ml/jax/issues/2661
x = jnp.array([[jnp.inf]])
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
@@ -1004,7 +1004,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
- # Regression test for https://github.com/google/jax/issues/10530
+ # Regression test for https://github.com/jax-ml/jax/issues/10530
rng = jtu.rand_default(self.rng())
arr = rng(shape, dtype)
if jtu.test_device_matches(['cpu']):
@@ -1422,7 +1422,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@parameterized.parameters(lax_linalg.lu, lax_linalg._lu_python)
def testLuOnZeroMatrix(self, lu):
- # Regression test for https://github.com/google/jax/issues/19076
+ # Regression test for https://github.com/jax-ml/jax/issues/19076
x = jnp.zeros((2, 2), dtype=np.float32)
x_lu, _, _ = lu(x)
self.assertArraysEqual(x_lu, x)
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 5daa0e0e5..a3b6b1efa 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -174,7 +174,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
def test_closed_over_values_device_placement(self):
- # see https://github.com/google/jax/issues/1431
+ # see https://github.com/jax-ml/jax/issues/1431
devices = self.get_devices()
def f(): return lax.add(3., 4.)
diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py
index 4f2e36c64..4697ba8b2 100644
--- a/tests/multibackend_test.py
+++ b/tests/multibackend_test.py
@@ -148,7 +148,7 @@ class MultiBackendTest(jtu.JaxTestCase):
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def test_closed_over_values_device_placement(self):
- # see https://github.com/google/jax/issues/1431
+ # see https://github.com/jax-ml/jax/issues/1431
def f(): return jnp.add(3., 4.)
self.assertNotEqual(jax.jit(f)().devices(),
{jax.devices('cpu')[0]})
@@ -186,7 +186,7 @@ class MultiBackendTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_indexing(self):
- # https://github.com/google/jax/issues/2905
+ # https://github.com/jax-ml/jax/issues/2905
cpus = jax.devices("cpu")
x = jax.device_put(np.ones(2), cpus[0])
@@ -195,7 +195,7 @@ class MultiBackendTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
def test_sum(self):
- # https://github.com/google/jax/issues/2905
+ # https://github.com/jax-ml/jax/issues/2905
cpus = jax.devices("cpu")
x = jax.device_put(np.ones(2), cpus[0])
diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py
index 40c5b6b2e..5c84f8c69 100644
--- a/tests/multiprocess_gpu_test.py
+++ b/tests/multiprocess_gpu_test.py
@@ -46,7 +46,7 @@ jax.config.parse_flags_with_absl()
@unittest.skipIf(not portpicker, "Test requires portpicker")
class DistributedTest(jtu.JaxTestCase):
- # TODO(phawkins): Enable after https://github.com/google/jax/issues/11222
+ # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222
# is fixed.
@unittest.SkipTest
def testInitializeAndShutdown(self):
diff --git a/tests/nn_test.py b/tests/nn_test.py
index be07de184..416beffce 100644
--- a/tests/nn_test.py
+++ b/tests/nn_test.py
@@ -325,12 +325,12 @@ class NNFunctionsTest(jtu.JaxTestCase):
self.assertEqual(out.dtype, dtype)
def testEluMemory(self):
- # see https://github.com/google/jax/pull/1640
+ # see https://github.com/jax-ml/jax/pull/1640
with jax.enable_checks(False): # With checks we materialize the array
jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom
def testHardTanhMemory(self):
- # see https://github.com/google/jax/pull/1640
+ # see https://github.com/jax-ml/jax/pull/1640
with jax.enable_checks(False): # With checks we materialize the array
jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
@@ -367,7 +367,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.parameters([nn.softmax, nn.log_softmax])
def testSoftmaxWhereGrad(self, fn):
- # regression test for https://github.com/google/jax/issues/19490
+ # regression test for https://github.com/jax-ml/jax/issues/19490
x = jnp.array([36., 10000.])
mask = x < 1000
@@ -443,7 +443,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
self.assertAllClose(actual, expected)
def testOneHotConcretizationError(self):
- # https://github.com/google/jax/issues/3654
+ # https://github.com/jax-ml/jax/issues/3654
msg = r"in jax.nn.one_hot argument `num_classes`"
with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
jax.jit(nn.one_hot)(3, 5)
@@ -463,7 +463,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
nn.tanh # doesn't crash
def testCustomJVPLeak(self):
- # https://github.com/google/jax/issues/8171
+ # https://github.com/jax-ml/jax/issues/8171
@jax.jit
def fwd():
a = jnp.array(1.)
@@ -479,7 +479,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
fwd() # doesn't crash
def testCustomJVPLeak2(self):
- # https://github.com/google/jax/issues/8171
+ # https://github.com/jax-ml/jax/issues/8171
# The above test uses jax.nn.sigmoid, as in the original #8171, but that
# function no longer actually has a custom_jvp! So we inline the old def.
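The comment above refers to inlining the old custom_jvp definition of sigmoid; a minimal sketch of what such a definition looks like (names here are illustrative, not the exact code from the test):
import jax
import jax.numpy as jnp

@jax.custom_jvp
def sigmoid(x):
  return 1. / (1. + jnp.exp(-x))

@sigmoid.defjvp
def _sigmoid_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  y = sigmoid(x)
  # Tangent rule written in terms of the primal output: s'(x) = s(x) * (1 - s(x)).
  return y, y * (1. - y) * t

print(jax.grad(sigmoid)(0.))  # 0.25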
diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb
index e8cd40a67..f5dcff837 100644
--- a/tests/notebooks/colab_cpu.ipynb
+++ b/tests/notebooks/colab_cpu.ipynb
@@ -20,7 +20,7 @@
"colab_type": "text"
},
"source": [
-        "[Open In Colab badge link pointing at the google/jax copy of this notebook]"
+        "[Open In Colab badge link pointing at the jax-ml/jax copy of this notebook]"
PLACEHOLDER_UNUSED
]
},
{
diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb
index 8352bdaf7..2335455e6 100644
--- a/tests/notebooks/colab_gpu.ipynb
+++ b/tests/notebooks/colab_gpu.ipynb
@@ -7,7 +7,7 @@
"id": "view-in-github"
},
"source": [
-        "[Open In Colab badge link pointing at the google/jax copy of this notebook]"
+        "[Open In Colab badge link pointing at the jax-ml/jax copy of this notebook]"
]
},
{
diff --git a/tests/ode_test.py b/tests/ode_test.py
index 834745e1c..acdfa1fc6 100644
--- a/tests/ode_test.py
+++ b/tests/ode_test.py
@@ -139,7 +139,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_odeint_vmap_grad(self):
- # https://github.com/google/jax/issues/2531
+ # https://github.com/jax-ml/jax/issues/2531
def dx_dt(x, *args):
return 0.1 * x
@@ -169,7 +169,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_disable_jit_odeint_with_vmap(self):
- # https://github.com/google/jax/issues/2598
+ # https://github.com/jax-ml/jax/issues/2598
with jax.disable_jit():
t = jnp.array([0.0, 1.0])
x0_eval = jnp.zeros((5, 2))
@@ -178,7 +178,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_grad_closure(self):
- # simplification of https://github.com/google/jax/issues/2718
+ # simplification of https://github.com/jax-ml/jax/issues/2718
def experiment(x):
def model(y, t):
return -x * y
@@ -188,7 +188,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_grad_closure_with_vmap(self):
- # https://github.com/google/jax/issues/2718
+ # https://github.com/jax-ml/jax/issues/2718
@jax.jit
def experiment(x):
def model(y, t):
@@ -209,7 +209,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_forward_mode_error(self):
- # https://github.com/google/jax/issues/3558
+ # https://github.com/jax-ml/jax/issues/3558
def f(k):
return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum()
@@ -219,7 +219,7 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_closure_nondiff(self):
- # https://github.com/google/jax/issues/3584
+ # https://github.com/jax-ml/jax/issues/3584
def dz_dt(z, t):
return jnp.stack([z[0], z[1]])
@@ -232,8 +232,8 @@ class ODETest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu", "gpu")
def test_complex_odeint(self):
- # https://github.com/google/jax/issues/3986
- # https://github.com/google/jax/issues/8757
+ # https://github.com/jax-ml/jax/issues/3986
+ # https://github.com/jax-ml/jax/issues/8757
def dy_dt(y, t, alpha):
return alpha * y * jnp.exp(-t).astype(y.dtype)
diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py
index b7710d9b9..c4eca0707 100644
--- a/tests/optimizers_test.py
+++ b/tests/optimizers_test.py
@@ -260,7 +260,7 @@ class OptimizerTests(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testIssue758(self):
- # code from https://github.com/google/jax/issues/758
+ # code from https://github.com/jax-ml/jax/issues/758
# this is more of a scan + jacfwd/jacrev test, but it lives here to use the
# optimizers.py code
diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py
index e2f0e2152..7692294cd 100644
--- a/tests/pallas/gpu_ops_test.py
+++ b/tests/pallas/gpu_ops_test.py
@@ -178,7 +178,7 @@ class FusedAttentionTest(PallasBaseTest):
(1, 384, 8, 64, True, True, True, {}),
(1, 384, 8, 64, True, True, False, {}),
(2, 384, 8, 64, True, True, True, {}),
- # regression test: https://github.com/google/jax/pull/17314
+ # regression test: https://github.com/jax-ml/jax/pull/17314
(1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}),
]
]
@@ -419,7 +419,7 @@ class SoftmaxTest(PallasBaseTest):
}[dtype]
# We upcast to float32 because NumPy <2.0 does not handle custom dtypes
- # properly. See https://github.com/google/jax/issues/11014.
+ # properly. See https://github.com/jax-ml/jax/issues/11014.
np.testing.assert_allclose(
softmax.softmax(x, axis=-1).astype(jnp.float32),
jax.nn.softmax(x, axis=-1).astype(jnp.float32),
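A hedged sketch of the upcast-before-compare pattern described in the comment above (the helper name is hypothetical, not part of the test suite):
import jax.numpy as jnp
import numpy as np

def assert_allclose_f32(actual, expected, **kwargs):
  # NumPy < 2.0 handles ml_dtypes types such as bfloat16 poorly, so cast both
  # sides to float32 before handing them to np.testing.
  np.testing.assert_allclose(jnp.asarray(actual).astype(jnp.float32),
                             jnp.asarray(expected).astype(jnp.float32), **kwargs)

vals = [-1., 0.42, 0.24, 1.]
x_bf16 = jnp.asarray(vals, dtype=jnp.bfloat16)
assert_allclose_f32(jnp.tanh(x_bf16), np.tanh(np.float32(vals)), rtol=5e-2)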
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 63c3148e8..d8f890c06 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -798,7 +798,7 @@ class OpsExtraTest(PallasBaseTest):
np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6)
def test_abs_weak_type(self):
- # see https://github.com/google/jax/issues/23191
+ # see https://github.com/jax-ml/jax/issues/23191
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
@@ -999,7 +999,7 @@ class OpsExtraTest(PallasBaseTest):
x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype)
# We upcast to float32 because NumPy <2.0 does not handle custom dtypes
- # properly. See https://github.com/google/jax/issues/11014.
+ # properly. See https://github.com/jax-ml/jax/issues/11014.
np.testing.assert_allclose(
kernel(x).astype(jnp.float32),
jnp.tanh(x).astype(jnp.float32),
@@ -1260,7 +1260,7 @@ class OpsExtraTest(PallasBaseTest):
np.testing.assert_array_equal(out, o_new)
def test_strided_load(self):
- # Reproducer from https://github.com/google/jax/issues/20895.
+ # Reproducer from https://github.com/jax-ml/jax/issues/20895.
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 5ee30ba33..aec7fd54c 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -484,7 +484,7 @@ class PallasCallTest(PallasBaseTest):
self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32))
def test_hoisted_consts(self):
- # See https://github.com/google/jax/issues/21557.
+ # See https://github.com/jax-ml/jax/issues/21557.
# to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))
diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py
index 4b3f47e6f..fefccfe7e 100644
--- a/tests/pallas/pallas_vmap_test.py
+++ b/tests/pallas/pallas_vmap_test.py
@@ -209,7 +209,7 @@ class PallasCallVmapTest(PallasBaseTest):
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.skip_on_devices("cpu") # Test is very slow on CPU
def test_small_large_vmap(self):
- # Catches https://github.com/google/jax/issues/18361
+ # Catches https://github.com/jax-ml/jax/issues/18361
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
grid=(2,))
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 57106948f..c20084c3c 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1510,7 +1510,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
def test_custom_partitioner_with_scan(self):
self.skip_if_custom_partitioning_not_supported()
- # This is a reproducer from https://github.com/google/jax/issues/20864.
+ # This is a reproducer from https://github.com/jax-ml/jax/issues/20864.
@custom_partitioning
def f(x):
@@ -1921,7 +1921,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(s.data, input_data)
def test_sds_full_like(self):
- # https://github.com/google/jax/issues/20390
+ # https://github.com/jax-ml/jax/issues/20390
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s)
@@ -4113,7 +4113,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_spmd_preserves_input_sharding_vmap_grad(self):
if config.use_shardy_partitioner.value:
self.skipTest("Shardy doesn't support PositionalSharding")
- # https://github.com/google/jax/issues/20710
+ # https://github.com/jax-ml/jax/issues/20710
n_devices = jax.device_count()
sharding = PositionalSharding(jax.devices())
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 8b121d91a..d7dcc7ba3 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -122,7 +122,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
- # Changed in https://github.com/google/jax/pull/10584 not to access
+ # Changed in https://github.com/jax-ml/jax/pull/10584 not to access
# sda.device_buffers, which isn't supported, and instead ensure fast slices
# of the arrays returned by pmap are set up correctly.
# buf = sda.device_buffers[-1]
@@ -336,7 +336,7 @@ class PythonPmapTest(jtu.JaxTestCase):
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
def test_pmap_replicated_copy(self):
- # https://github.com/google/jax/issues/17690
+ # https://github.com/jax-ml/jax/issues/17690
inp = jnp.arange(jax.device_count())
x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp)
out = jnp.copy(x)
@@ -605,7 +605,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(y, ref)
def testNestedPmapAxisSwap(self):
- # Regression test for https://github.com/google/jax/issues/5757
+ # Regression test for https://github.com/jax-ml/jax/issues/5757
if jax.device_count() < 8:
raise SkipTest("test requires at least 8 devices")
f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0,
@@ -1180,7 +1180,7 @@ class PythonPmapTest(jtu.JaxTestCase):
"`perm` does not represent a permutation: \\[1.*\\]", g)
def testPpermuteWithZipObject(self):
- # https://github.com/google/jax/issues/1703
+ # https://github.com/jax-ml/jax/issues/1703
num_devices = jax.device_count()
perm = [num_devices - 1] + list(range(num_devices - 1))
f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i")
@@ -1501,7 +1501,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(ans.shape, (13, N_DEVICES))
def testVmapOfPmap3(self):
- # https://github.com/google/jax/issues/3399
+ # https://github.com/jax-ml/jax/issues/3399
device_count = jax.device_count()
if device_count < 2:
raise SkipTest("test requires at least two devices")
@@ -1661,7 +1661,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def testIssue1065(self):
- # from https://github.com/google/jax/issues/1065
+ # from https://github.com/jax-ml/jax/issues/1065
device_count = jax.device_count()
def multi_step_pmap(state, count):
@@ -1697,7 +1697,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# replica.
@unittest.skip("need eager multi-replica support")
def testPostProcessMap(self):
- # test came from https://github.com/google/jax/issues/1369
+ # test came from https://github.com/jax-ml/jax/issues/1369
nrep = jax.device_count()
def pmvm(a, b):
@@ -1730,7 +1730,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jax.default_matmul_precision("float32")
def testPostProcessMap2(self):
- # code from https://github.com/google/jax/issues/2787
+ # code from https://github.com/jax-ml/jax/issues/2787
def vv(x, y):
"""Vector-vector multiply"""
return jnp.dot(x, y)
@@ -1758,7 +1758,7 @@ class PythonPmapTest(jtu.JaxTestCase):
('_new', new_checkpoint),
])
def testAxisIndexRemat(self, remat):
- # https://github.com/google/jax/issues/2716
+ # https://github.com/jax-ml/jax/issues/2716
n = len(jax.devices())
def f(key):
@@ -1769,7 +1769,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.pmap(remat(f), axis_name='i')(keys)
def testPmapMapVmapCombinations(self):
- # https://github.com/google/jax/issues/2822
+ # https://github.com/jax-ml/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
return jnp.dot(x, y)
@@ -1802,7 +1802,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)
def testPmapAxisNameError(self):
- # https://github.com/google/jax/issues/3120
+ # https://github.com/jax-ml/jax/issues/3120
a = np.arange(4)[np.newaxis,:]
def test(x):
return jax.lax.psum(x, axis_name='batch')
@@ -1811,7 +1811,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.pmap(test)(a)
def testPsumOnBooleanDtype(self):
- # https://github.com/google/jax/issues/3123
+ # https://github.com/jax-ml/jax/issues/3123
n = jax.device_count()
if n > 1:
x = jnp.array([True, False])
@@ -1889,7 +1889,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIn("mhlo.num_partitions = 1", hlo)
def testPsumZeroCotangents(self):
- # https://github.com/google/jax/issues/3651
+ # https://github.com/jax-ml/jax/issues/3651
def loss(params, meta_params):
(net, mpo) = params
return meta_params * mpo * net
@@ -1914,7 +1914,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def test_issue_1062(self):
- # code from https://github.com/google/jax/issues/1062 @shoyer
+ # code from https://github.com/jax-ml/jax/issues/1062 @shoyer
# this tests, among other things, whether ShardedDeviceTuple constants work
device_count = jax.device_count()
@@ -1938,7 +1938,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# TODO(skye): fix backend caching so we always have multiple CPUs available
if jax.device_count("cpu") < 4:
self.skipTest("test requires 4 CPU device")
- # https://github.com/google/jax/issues/4223
+ # https://github.com/jax-ml/jax/issues/4223
def fn(indices):
return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
mapped_fn = self.pmap(fn, axis_name='i', backend='cpu')
@@ -1982,7 +1982,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for dtype in [np.float32, np.int32]
)
def testPmapDtype(self, dtype):
- # Regression test for https://github.com/google/jax/issues/6022
+ # Regression test for https://github.com/jax-ml/jax/issues/6022
@partial(self.pmap, axis_name='i')
def func(_):
return jax.lax.psum(dtype(0), axis_name='i')
@@ -1991,7 +1991,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(out_dtype, dtype)
def test_num_replicas_with_switch(self):
- # https://github.com/google/jax/issues/7411
+ # https://github.com/jax-ml/jax/issues/7411
def identity(x):
return x
@@ -2154,7 +2154,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jtu.run_on_devices("cpu")
def test_pmap_stack_size(self):
- # Regression test for https://github.com/google/jax/issues/20428
+ # Regression test for https://github.com/jax-ml/jax/issues/20428
# pmap isn't particularly important here, but it guarantees that the CPU
# client runs the computation on a threadpool rather than inline.
if jax.device_count() < 2:
@@ -2164,7 +2164,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y.block_until_ready() # doesn't crash
def test_pmap_of_prng_key(self):
- # Regression test for https://github.com/google/jax/issues/20392
+ # Regression test for https://github.com/jax-ml/jax/issues/20392
keys = jax.random.split(jax.random.key(0), jax.device_count())
result1 = jax.pmap(jax.random.bits)(keys)
with jtu.ignore_warning(
diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py
index 76854beae..d9887cf7b 100644
--- a/tests/python_callback_test.py
+++ b/tests/python_callback_test.py
@@ -1245,7 +1245,7 @@ class IOCallbackTest(jtu.JaxTestCase):
np.testing.assert_array_equal(shard[0] + 1, shard[1])
def test_batching_with_side_effects(self):
- # https://github.com/google/jax/issues/20628#issuecomment-2050800195
+ # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
x_lst = []
def append_x(x):
nonlocal x_lst
@@ -1261,7 +1261,7 @@ class IOCallbackTest(jtu.JaxTestCase):
self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False)
def test_batching_with_side_effects_while_loop(self):
- # https://github.com/google/jax/issues/20628#issuecomment-2050921219
+ # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
x_lst = []
def append_x(x):
nonlocal x_lst
diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py
index 8d00b5eed..2e0fc3223 100644
--- a/tests/pytorch_interoperability_test.py
+++ b/tests/pytorch_interoperability_test.py
@@ -109,7 +109,7 @@ class DLPackTest(jtu.JaxTestCase):
self.assertAllClose(np, y.cpu().numpy())
def testTorchToJaxInt64(self):
- # See https://github.com/google/jax/issues/11895
+ # See https://github.com/jax-ml/jax/issues/11895
x = jax.dlpack.from_dlpack(
torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py
index f69687ddc..63510b729 100644
--- a/tests/random_lax_test.py
+++ b/tests/random_lax_test.py
@@ -178,7 +178,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testNormalBfloat16(self):
# Passing bfloat16 as dtype string.
- # https://github.com/google/jax/issues/6813
+ # https://github.com/jax-ml/jax/issues/6813
res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16')
res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16)
self.assertAllClose(res_bfloat16, res_bfloat16_str)
@@ -391,7 +391,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBetaSmallParameters(self, dtype=np.float32):
- # Regression test for beta version of https://github.com/google/jax/issues/9896
+ # Regression test for beta version of https://github.com/jax-ml/jax/issues/9896
key = self.make_key(0)
a, b = 0.0001, 0.0002
samples = random.beta(key, a, b, shape=(100,), dtype=dtype)
@@ -441,7 +441,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # lower accuracy leads to failures.
def testDirichletSmallAlpha(self, dtype=np.float32):
- # Regression test for https://github.com/google/jax/issues/9896
+ # Regression test for https://github.com/jax-ml/jax/issues/9896
key = self.make_key(0)
alpha = 0.00001 * jnp.ones(3)
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)
@@ -530,7 +530,7 @@ class LaxRandomTest(jtu.JaxTestCase):
rtol=rtol)
def testGammaGradType(self):
- # Regression test for https://github.com/google/jax/issues/2130
+ # Regression test for https://github.com/jax-ml/jax/issues/2130
key = self.make_key(0)
a = jnp.array(1., dtype=jnp.float32)
b = jnp.array(3., dtype=jnp.float32)
@@ -663,7 +663,7 @@ class LaxRandomTest(jtu.JaxTestCase):
)
def testGeneralizedNormalKS(self, p, shape, dtype):
self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws
- "sensitive to random key - https://github.com/google/jax/issues/18941")
+ "sensitive to random key - https://github.com/jax-ml/jax/issues/18941")
key = lambda: self.make_key(2)
rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype)
crand = jax.jit(rand)
@@ -700,7 +700,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testBallKS(self, d, p, shape, dtype):
self.skipTest(
- "sensitive to random key - https://github.com/google/jax/issues/18932")
+ "sensitive to random key - https://github.com/jax-ml/jax/issues/18932")
key = lambda: self.make_key(123)
rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype)
crand = jax.jit(rand)
@@ -800,7 +800,7 @@ class LaxRandomTest(jtu.JaxTestCase):
assert samples.shape == shape + (dim,)
def testMultivariateNormalCovariance(self):
- # test code based on https://github.com/google/jax/issues/1869
+ # test code based on https://github.com/jax-ml/jax/issues/1869
N = 100000
mean = jnp.zeros(4)
cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00],
@@ -827,7 +827,7 @@ class LaxRandomTest(jtu.JaxTestCase):
@jtu.sample_product(method=['cholesky', 'eigh', 'svd'])
@jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators.
def testMultivariateNormalSingularCovariance(self, method):
- # Singular covariance matrix https://github.com/google/jax/discussions/13293
+ # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293
mu = jnp.zeros((2,))
sigma = jnp.ones((2, 2))
key = self.make_key(0)
@@ -889,7 +889,7 @@ class LaxRandomTest(jtu.JaxTestCase):
def testRandomBroadcast(self):
"""Issue 4033"""
- # test for broadcast issue in https://github.com/google/jax/issues/4033
+ # test for broadcast issue in https://github.com/jax-ml/jax/issues/4033
key = lambda: self.make_key(0)
shape = (10, 2)
with jax.numpy_rank_promotion('allow'):
@@ -1071,7 +1071,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertGreater((r == 255).sum(), 0)
def test_large_prng(self):
- # https://github.com/google/jax/issues/11010
+ # https://github.com/jax-ml/jax/issues/11010
def f():
return random.uniform(
self.make_key(3), (308000000, 128), dtype=jnp.bfloat16)
@@ -1086,7 +1086,7 @@ class LaxRandomTest(jtu.JaxTestCase):
logits_shape_base=[(3, 4), (3, 1), (1, 4)],
axis=[-3, -2, -1, 0, 1, 2])
def test_categorical_shape_argument(self, shape, logits_shape_base, axis):
- # https://github.com/google/jax/issues/13124
+ # https://github.com/jax-ml/jax/issues/13124
logits_shape = list(logits_shape_base)
logits_shape.insert(axis % (len(logits_shape_base) + 1), 10)
assert logits_shape[axis] == 10
diff --git a/tests/random_test.py b/tests/random_test.py
index 941172f75..da182dbcc 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -436,7 +436,7 @@ class PrngTest(jtu.JaxTestCase):
@skipIf(not config.threefry_partitionable.value, 'enable after upgrade')
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_threefry_split_vmapped_fold_in_symmetry(self, make_key):
- # See https://github.com/google/jax/issues/7708
+ # See https://github.com/jax-ml/jax/issues/7708
with jax.default_prng_impl('threefry2x32'):
key = make_key(72)
f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')),
@@ -450,7 +450,7 @@ class PrngTest(jtu.JaxTestCase):
@skipIf(config.threefry_partitionable.value, 'changed random bit values')
def test_loggamma_nan_corner_case(self):
- # regression test for https://github.com/google/jax/issues/17922
+ # regression test for https://github.com/jax-ml/jax/issues/17922
# This particular key previously led to NaN output.
# If the underlying implementation ever changes, this test will no longer
# exercise this corner case, so we compare to a particular output value
@@ -545,7 +545,7 @@ class PrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_key_output_vjp(self, make_key):
- # See https://github.com/google/jax/issues/14856
+ # See https://github.com/jax-ml/jax/issues/14856
def f(seed): return make_key(seed)
jax.vjp(f, 1) # doesn't crash
@@ -578,7 +578,7 @@ class ThreefryPrngTest(jtu.JaxTestCase):
partial(random.PRNGKey, impl='threefry2x32'),
partial(random.key, impl='threefry2x32')]])
def test_seed_no_implicit_transfers(self, make_key):
- # See https://github.com/google/jax/issues/15613
+ # See https://github.com/jax-ml/jax/issues/15613
with jax.transfer_guard('disallow'):
make_key(jax.device_put(42)) # doesn't crash
@@ -922,14 +922,14 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(ys.shape, (3, 2))
def test_select_scalar_cond(self):
- # regression test for https://github.com/google/jax/issues/16422
+ # regression test for https://github.com/jax-ml/jax/issues/16422
ks = self.make_keys(3)
ys = lax.select(True, ks, ks)
self.assertIsInstance(ys, prng_internal.PRNGKeyArray)
self.assertEqual(ys.shape, (3,))
def test_vmap_of_cond(self):
- # See https://github.com/google/jax/issues/15869
+ # See https://github.com/jax-ml/jax/issues/15869
def f(x):
keys = self.make_keys(*x.shape)
return lax.select(x, keys, keys)
@@ -1126,7 +1126,7 @@ class KeyArrayTest(jtu.JaxTestCase):
self.assertEqual(repr(spec), f"PRNGSpec({name!r})")
def test_keyarray_custom_vjp(self):
- # Regression test for https://github.com/google/jax/issues/18442
+ # Regression test for https://github.com/jax-ml/jax/issues/18442
@jax.custom_vjp
def f(_, state):
return state
diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py
index 701b7c570..dd34a99a7 100644
--- a/tests/scipy_ndimage_test.py
+++ b/tests/scipy_ndimage_test.py
@@ -129,7 +129,7 @@ class NdimageTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)
def testContinuousGradients(self):
- # regression test for https://github.com/google/jax/issues/3024
+ # regression test for https://github.com/jax-ml/jax/issues/3024
def loss(delta):
x = np.arange(100.0)
diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py
index ad988bba6..983cb6bdc 100644
--- a/tests/scipy_stats_test.py
+++ b/tests/scipy_stats_test.py
@@ -314,7 +314,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
- # Regression test for https://github.com/google/jax/issues/7645
+ # Regression test for https://github.com/jax-ml/jax/issues/7645
a = b = 1.
x = np.array([0., 1.])
self.assertAllClose(
@@ -539,7 +539,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker)
def testGammaLogPdfZero(self):
- # Regression test for https://github.com/google/jax/issues/7256
+ # Regression test for https://github.com/jax-ml/jax/issues/7256
self.assertAllClose(
osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)
@@ -710,7 +710,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker)
def testLogisticLogpdfOverflow(self):
- # Regression test for https://github.com/google/jax/issues/10219
+ # Regression test for https://github.com/jax-ml/jax/issues/10219
self.assertAllClose(
np.array([-100, -100], np.float32),
lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
@@ -855,7 +855,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker)
def testNormSfNearZero(self):
- # Regression test for https://github.com/google/jax/issues/17199
+ # Regression test for https://github.com/jax-ml/jax/issues/17199
value = np.array(10, np.float32)
self.assertAllClose(osp_stats.norm.sf(value).astype('float32'),
lsp_stats.norm.sf(value),
@@ -1208,7 +1208,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def testBinomPmfOutOfRange(self):
- # Regression test for https://github.com/google/jax/issues/19150
+ # Regression test for https://github.com/jax-ml/jax/issues/19150
self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0)
def testBinomLogPmfZerokZeron(self):
diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py
index 27199c874..77e5273d1 100644
--- a/tests/shape_poly_test.py
+++ b/tests/shape_poly_test.py
@@ -2973,7 +2973,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
(2, x.shape[0]), (1, 1), "VALID"),
arg_descriptors=[RandArg((3, 8), _f32)],
polymorphic_shapes=["b, ..."]),
- # https://github.com/google/jax/issues/11804
+ # https://github.com/jax-ml/jax/issues/11804
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
# (See test "conv_general_dilated.1d_1" above for more details.)
PolyHarness("reduce_window", "add_monoid_strides_window_size=static",
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index ae22eeca0..20bc33475 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -724,7 +724,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
def test_nested_vmap_with_capture_spmd_axis_name(self):
- self.skipTest('https://github.com/google/jax/issues/23476')
+ self.skipTest('https://github.com/jax-ml/jax/issues/23476')
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
def to_map_with_capture(x, y):
@@ -902,7 +902,7 @@ class ShardMapTest(jtu.JaxTestCase):
@jax.legacy_prng_key('allow')
def test_prngkeyarray_eager(self):
- # https://github.com/google/jax/issues/15398
+ # https://github.com/jax-ml/jax/issues/15398
mesh = jtu.create_mesh((4,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
@@ -1069,7 +1069,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(out, 1.)
def test_jaxpr_shardings_with_no_outputs(self):
- # https://github.com/google/jax/issues/15385
+ # https://github.com/jax-ml/jax/issues/15385
mesh = jtu.create_mesh((4,), ('i',))
@jax.jit
@@ -1109,7 +1109,7 @@ class ShardMapTest(jtu.JaxTestCase):
@jtu.run_on_devices('cpu', 'gpu', 'tpu')
def test_key_array_with_replicated_last_tile_dim(self):
- # See https://github.com/google/jax/issues/16137
+ # See https://github.com/jax-ml/jax/issues/16137
mesh = jtu.create_mesh((2, 4), ('i', 'j'))
@@ -1690,7 +1690,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False)
def test_repeated_psum_allowed(self):
- # https://github.com/google/jax/issues/19175
+ # https://github.com/jax-ml/jax/issues/19175
mesh = jtu.create_mesh((4,), 'i')
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
@@ -1927,7 +1927,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
- # https://github.com/google/jax/pull/21032
+ # https://github.com/jax-ml/jax/pull/21032
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
@partial(
@@ -1944,7 +1944,7 @@ class ShardMapTest(jtu.JaxTestCase):
jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash
def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self):
- # https://github.com/google/jax/pull/21056
+ # https://github.com/jax-ml/jax/pull/21056
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
@partial(jax.remat, policy=lambda *_, **__: True)
@@ -1962,7 +1962,7 @@ class ShardMapTest(jtu.JaxTestCase):
jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash
def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
- # https://github.com/google/jax/issues/21236
+ # https://github.com/jax-ml/jax/issues/21236
mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
@partial(
@@ -2108,7 +2108,7 @@ class ShardMapTest(jtu.JaxTestCase):
)((object(), object()), x)
def test_custom_linear_solve_rep_rules(self):
- # https://github.com/google/jax/issues/20162
+ # https://github.com/jax-ml/jax/issues/20162
mesh = jtu.create_mesh((1,), ('i',))
a = jnp.array(1).reshape(1, 1)
b = jnp.array(1).reshape(1)
diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py
index 545d73bff..38fde72f0 100644
--- a/tests/sparse_bcoo_bcsr_test.py
+++ b/tests/sparse_bcoo_bcsr_test.py
@@ -310,7 +310,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]]))
def test_bcoo_extract_batching(self):
- # https://github.com/google/jax/issues/9431
+ # https://github.com/jax-ml/jax/issues/9431
indices = jnp.zeros((4, 1, 1), dtype=int)
mat = jnp.arange(4.).reshape((4, 1))
@@ -353,7 +353,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(hess.shape, data.shape + 2 * M.shape)
def test_bcoo_extract_zero_nse(self):
- # Regression test for https://github.com/google/jax/issues/13653
+ # Regression test for https://github.com/jax-ml/jax/issues/13653
# (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2
args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3))
@@ -974,7 +974,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(out.nse, expected_nse)
def test_bcoo_spdot_general_ad_bug(self):
- # Regression test for https://github.com/google/jax/issues/10163
+ # Regression test for https://github.com/jax-ml/jax/issues/10163
A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]])
A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0])
A_shape = (2, 3)
@@ -1287,7 +1287,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(y2.nse, x.nse)
def test_bcoo_sum_duplicates_padding(self):
- # Regression test for https://github.com/google/jax/issues/8163
+ # Regression test for https://github.com/jax-ml/jax/issues/8163
size = 3
data = jnp.array([1, 0, 0])
indices = jnp.array([1, size, size])[:, None]
@@ -1606,7 +1606,7 @@ class BCOOTest(sptu.SparseTestCase):
self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol)
def test_bcoo_mul_sparse_with_duplicates(self):
- # Regression test for https://github.com/google/jax/issues/8888
+ # Regression test for https://github.com/jax-ml/jax/issues/8888
indices = jnp.array([[0, 1, 0, 0, 1, 1],
[1, 0, 1, 2, 0, 2]]).T
data = jnp.array([1, 2, 3, 4, 5, 6])
@@ -1940,7 +1940,7 @@ class BCSRTest(sptu.SparseTestCase):
self._CheckGradsSparse(dense_func, sparse_func, args_maker)
def test_bcoo_spdot_abstract_eval_bug(self):
- # Regression test for https://github.com/google/jax/issues/21921
+ # Regression test for https://github.com/jax-ml/jax/issues/21921
lhs = sparse.BCOO(
(jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)),
shape=(10, 10))
diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 616396222..eb8d70be1 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -323,7 +323,7 @@ class cuSparseTest(sptu.SparseTestCase):
self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL)
def test_coo_matmat_layout(self):
- # Regression test for https://github.com/google/jax/issues/7533
+ # Regression test for https://github.com/jax-ml/jax/issues/7533
d = jnp.array([1.0, 2.0, 3.0, 4.0])
i = jnp.array([0, 0, 1, 2])
j = jnp.array([0, 2, 0, 0])
diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py
index 46086511d..46c2f5aaf 100644
--- a/tests/sparsify_test.py
+++ b/tests/sparsify_test.py
@@ -610,7 +610,7 @@ class SparsifyTest(jtu.JaxTestCase):
self.assertArraysEqual(jit(func)(Msp).todense(), expected)
def testWeakTypes(self):
- # Regression test for https://github.com/google/jax/issues/8267
+ # Regression test for https://github.com/jax-ml/jax/issues/8267
M = jnp.arange(12, dtype='int32').reshape(3, 4)
Msp = BCOO.fromdense(M)
self.assertArraysEqual(
diff --git a/tests/stax_test.py b/tests/stax_test.py
index 6850f36a0..e21300ddd 100644
--- a/tests/stax_test.py
+++ b/tests/stax_test.py
@@ -216,7 +216,7 @@ class StaxTest(jtu.JaxTestCase):
def testBatchNormShapeNCHW(self):
key = random.PRNGKey(0)
- # Regression test for https://github.com/google/jax/issues/461
+ # Regression test for https://github.com/jax-ml/jax/issues/461
init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
input_shape = (4, 5, 6, 7)
diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py
index bc741702c..f8792a263 100644
--- a/tests/tree_util_test.py
+++ b/tests/tree_util_test.py
@@ -343,7 +343,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(h.args, (3,))
def testPartialFuncAttributeHasStableHash(self):
- # https://github.com/google/jax/issues/9429
+ # https://github.com/jax-ml/jax/issues/9429
fun = functools.partial(print, 1)
p1 = tree_util.Partial(fun, 2)
p2 = tree_util.Partial(fun, 2)
@@ -359,7 +359,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual([c0, c1], tree.children())
def testTreedefTupleFromChildren(self):
- # https://github.com/google/jax/issues/7377
+ # https://github.com/jax-ml/jax/issues/7377
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
treedef2 = tree_util.treedef_tuple(treedef1.children())
@@ -368,7 +368,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(treedef1.num_nodes, treedef2.num_nodes)
def testTreedefTupleComparesEqual(self):
- # https://github.com/google/jax/issues/9066
+ # https://github.com/jax-ml/jax/issues/9066
self.assertEqual(tree_util.tree_structure((3,)),
tree_util.treedef_tuple((tree_util.tree_structure(3),)))
@@ -978,7 +978,7 @@ class RavelUtilTest(jtu.JaxTestCase):
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
def testDtypePolymorphicUnravel(self):
- # https://github.com/google/jax/issues/7809
+ # https://github.com/jax-ml/jax/issues/7809
x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = flatten_util.ravel_pytree(x)
y = x_flat < 5.3
@@ -987,7 +987,7 @@ class RavelUtilTest(jtu.JaxTestCase):
@jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion.
def testDtypeMonomorphicUnravel(self):
- # https://github.com/google/jax/issues/7809
+ # https://github.com/jax-ml/jax/issues/7809
x1 = jnp.arange(10, dtype=jnp.float32)
x2 = jnp.arange(10, dtype=jnp.int32)
x_flat, unravel = flatten_util.ravel_pytree((x1, x2))
diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py
index 58cf4a2ba..d4403b7e5 100644
--- a/tests/x64_context_test.py
+++ b/tests/x64_context_test.py
@@ -128,7 +128,7 @@ class X64ContextTests(jtu.JaxTestCase):
@unittest.skip("test fails, see #8552")
def test_convert_element_type(self):
- # Regression test for part of https://github.com/google/jax/issues/5982
+ # Regression test for part of https://github.com/jax-ml/jax/issues/5982
with enable_x64():
x = jnp.int64(1)
self.assertEqual(x.dtype, jnp.int64)