diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index c19832e63..628310519 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -20,11 +20,11 @@ body: * If you prefer a non-templated issue report, click [here][Raw report]. - [Discussions]: https://github.com/google/jax/discussions + [Discussions]: https://github.com/jax-ml/jax/discussions - [issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues + [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/google/jax/issues/new + [Raw report]: http://github.com/jax-ml/jax/issues/new - type: textarea attributes: label: Description diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index cabbed589..f078e8e94 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - name: Have questions or need support? - url: https://github.com/google/jax/discussions + url: https://github.com/jax-ml/jax/discussions about: Please ask questions on the Discussions tab diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 1e345954d..74cb45920 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -84,7 +84,7 @@ jobs: failure() && steps.status.outcome == 'failure' && github.event_name == 'schedule' - && github.repository == 'google/jax' + && github.repository == 'jax-ml/jax' uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: name: output-${{ matrix.python-version }}-log.jsonl diff --git a/CHANGELOG.md b/CHANGELOG.md index ee782d04a..43db6e197 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -279,7 +279,7 @@ See the 0.4.33 release notes for more details. which manifested as an incorrect output for cumulative reductions (#21403). * Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301). - * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). + * Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396). * Deprecations * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will @@ -401,7 +401,7 @@ See the 0.4.33 release notes for more details. branch consistent with that of NumPy 2.0. * The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has - changed](https://github.com/google/jax/issues/19085) so that + changed](https://github.com/jax-ml/jax/issues/19085) so that mapping over keys results in random generation only from the first key in the batch. * Docs now use `jax.random.key` for construction of PRNG key arrays @@ -433,7 +433,7 @@ See the 0.4.33 release notes for more details. * JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. - See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). This change could break clients that set a specific JAX serialization version lower than 9. @@ -506,7 +506,7 @@ See the 0.4.33 release notes for more details. * added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * with the addition of symbolic constraints ({jax-issue}`#19235`) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes @@ -516,7 +516,7 @@ See the 0.4.33 release notes for more details. The scope of a symbolic expression `e` can be read with `e.scope` and passed into the above functions to direct them to construct symbolic expressions in a given scope. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 ({jax-issue}`#19231`; note that this may result in user-visible behavior @@ -535,7 +535,7 @@ See the 0.4.33 release notes for more details. strings for polymorphic shapes specifications ({jax-issue}`#19284`). * JAX default native serialization version is now 9. This is relevant for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`. - See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). * Refactored the API for `jax.experimental.export`. Instead of `from jax.experimental.export import export` you should use now `from jax.experimental import export`. The old way of importing will @@ -781,19 +781,19 @@ See the 0.4.33 release notes for more details. * When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" that previously appeared.) This should produce much friendlier-looking tracebacks. See - [here](https://github.com/google/jax/pull/16949) for an example. + [here](https://github.com/jax-ml/jax/pull/16949) for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). * jax2tf default serialization version is now 7, which introduces new shape - [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). + [safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). * Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user created devices. `jax.devices()` are already hashable. * Breaking changes: * jax2tf now uses native serialization by default. See - the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for details and for mechanisms to override the default. * The option `--jax_coordination_service` has been removed. It is now always `True`. @@ -922,7 +922,7 @@ See the 0.4.33 release notes for more details. arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See - [#16413](https://github.com/google/jax/issues/16413). + [#16413](https://github.com/jax-ml/jax/issues/16413). * The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases. @@ -933,7 +933,7 @@ See the 0.4.33 release notes for more details. serialization version ({jax-issue}`#16746`). * jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. ## jaxlib 0.4.14 (July 27, 2023) @@ -1095,14 +1095,14 @@ See the 0.4.33 release notes for more details. {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the [JAX issue - tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs + tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs are insufficient for your use case. The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the [JAX issue - tracker](https://github.com/google/jax/issues). + tracker](https://github.com/jax-ml/jax/issues). * Changes * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7. @@ -1126,7 +1126,7 @@ See the 0.4.33 release notes for more details. StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. - See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`. * JAX now depends on `ml_dtypes`, which contains definitions of NumPy types @@ -1403,7 +1403,7 @@ Changes: ## jaxlib 0.3.22 (Oct 11, 2022) ## jax 0.3.21 (Sep 30, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21). * Changes * The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}`#12582`), so program execution can continue @@ -1417,18 +1417,18 @@ Changes: * Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`). ## jaxlib 0.3.20 (Sep 28, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). * Bug fixes * Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}`#12533`). ## jax 0.3.19 (Sep 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19). * Fixes required jaxlib version. ## jax 0.3.18 (Sep 26, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18). * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the @@ -1446,7 +1446,7 @@ Changes: would have been provided. ## jax 0.3.17 (Aug 31, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17). * Bugs * Fix corner case issue in gradient of `lax.pow` with an exponent of zero ({jax-issue}`12041`) @@ -1462,7 +1462,7 @@ Changes: * `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead. ## jax 0.3.16 -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -1486,7 +1486,7 @@ Changes: deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). * Changes * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 ({jax-issue}`#11248`). @@ -1507,10 +1507,10 @@ Changes: following a similar deprecation in {func}`scipy.linalg.solve`. ## jaxlib 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). ## jax 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14). * Breaking changes * {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input. @@ -1563,22 +1563,22 @@ Changes: coefficients have leading zeros ({jax-issue}`#11215`). ## jaxlib 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). * x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement. * The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks. * The Python flatbuffers package is no longer a dependency of jaxlib. ## jax 0.3.13 (May 16, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13). ## jax 0.3.12 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12). * Changes - * Fixes [#10717](https://github.com/google/jax/issues/10717). + * Fixes [#10717](https://github.com/jax-ml/jax/issues/10717). ## jax 0.3.11 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11). * Changes * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU. @@ -1592,22 +1592,22 @@ Changes: scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead. ## jax 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10). ## jaxlib 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). * Changes * [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a) fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs. ## jax 0.3.9 (May 2, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9). * Changes * Added support for fully asynchronous checkpointing for GlobalDeviceArray. ## jax 0.3.8 (April 29 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8). * Changes * {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver. * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input. @@ -1666,7 +1666,7 @@ Changes: ## jax 0.3.7 (April 15, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7). * Changes: * Fixed a performance problem if the indices passed to {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`). @@ -1684,17 +1684,17 @@ Changes: ## jax 0.3.6 (April 12, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6). * Changes: * Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU - pod. Fixes [#10218](https://github.com/google/jax/issues/10218). + pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218). * Deprecations: * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API. ## jax 0.3.5 (April 7, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5). * Changes: * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`). @@ -1717,17 +1717,17 @@ Changes: ## jax 0.3.4 (March 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4). ## jax 0.3.3 (March 17, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3). ## jax 0.3.2 (March 16, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2). * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use @@ -1751,7 +1751,7 @@ Changes: ## jax 0.3.1 (Feb 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1). * Changes: * `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. @@ -1774,7 +1774,7 @@ Changes: ## jax 0.3.0 (Feb 10, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) @@ -1788,7 +1788,7 @@ Changes: ## jax 0.2.28 (Feb 1, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28). * `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed. * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR @@ -1813,7 +1813,7 @@ Changes: * The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311). ## jax 0.2.27 (Jan 18 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27). * Breaking changes: * Support for NumPy 1.18 has been dropped, per the @@ -1858,7 +1858,7 @@ Changes: ## jax 0.2.26 (Dec 8, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26). * Bug fixes: * Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with @@ -1875,7 +1875,7 @@ Changes: ## jax 0.2.25 (Nov 10, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25). * New features: * (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend. @@ -1889,7 +1889,7 @@ Changes: ## jax 0.2.24 (Oct 19, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24). * New features: * `jax.random.choice` and `jax.random.permutation` now support @@ -1923,7 +1923,7 @@ Changes: ## jax 0.2.22 (Oct 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22). * Breaking Changes * Static arguments to `jax.pmap` must now be hashable. @@ -1958,13 +1958,13 @@ Changes: * Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+. * Bug fixes: - * Fixes https://github.com/google/jax/issues/7461, which caused wrong + * Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler. ## jax 0.2.21 (Sept 23, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21). * Breaking Changes * `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in @@ -1992,7 +1992,7 @@ Changes: ## jax 0.2.20 (Sept 2, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20). * Breaking Changes * `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`) * `jnp.unique` and other set-like operations now require array-like inputs @@ -2005,7 +2005,7 @@ Changes: ## jax 0.2.19 (Aug 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -2042,7 +2042,7 @@ Changes: called in sequence. ## jax 0.2.18 (July 21 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18). * Breaking changes: * Support for Python 3.6 has been dropped, per the @@ -2065,7 +2065,7 @@ Changes: * Fix bugs in TFRT CPU backend that results in incorrect results. ## jax 0.2.17 (July 9 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17). * Bug fixes: * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency @@ -2082,12 +2082,12 @@ Changes: ## jax 0.2.16 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16). ## jax 0.2.15 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15). * New features: - * [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend + * [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend with significant dispatch performance improvements on CPU. * The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`). @@ -2107,7 +2107,7 @@ Changes: CPU. ## jax 0.2.14 (June 10 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14). * New features: * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`. * A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters @@ -2165,7 +2165,7 @@ Changes: {func}`jit` transformed functions. ## jax 0.2.13 (May 3 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13). * New features: * When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify @@ -2209,7 +2209,7 @@ Changes: ## jaxlib 0.1.65 (April 7 2021) ## jax 0.2.12 (April 1 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12). * New features * New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace` @@ -2222,7 +2222,7 @@ Changes: * `TraceContext` --> {func}`~jax.profiler.TraceAnnotation` * `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation` * `trace_function` --> {func}`~jax.profiler.annotate_function` - * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) for more information. * Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`). @@ -2236,23 +2236,23 @@ Changes: ## jax 0.2.11 (March 23 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11). * New features: - * [#6112](https://github.com/google/jax/pull/6112) added context managers: + * [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers: `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`, `jax.debug_infs`, `jax.log_compiles`. - * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete` + * [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete` * Bug fixes: - * [#6136](https://github.com/google/jax/pull/6136) generalized + * [#6136](https://github.com/jax-ml/jax/pull/6136) generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes. - * [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling + * [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling some constants like `enum.IntEnums` - * [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with + * [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with incomplete beta functions - * [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during + * [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during tracing - * [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when + * [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when converting some large Python integers to floats * Breaking changes: * The minimum jaxlib version is now 0.1.62. @@ -2264,13 +2264,13 @@ Changes: ## jax 0.2.10 (March 5 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10). * New features: * {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods. * {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods. * Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627`) - and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). + and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). * Extended the batching rule for `lax.pad` to support batching of the padding values. * Bug fixes: * {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`) @@ -2314,7 +2314,7 @@ Changes: ## jax 0.2.9 (January 26 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9). * New features: * Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages. @@ -2330,7 +2330,7 @@ Changes: ## jax 0.2.8 (January 12 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8). * New features: * Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`) @@ -2362,7 +2362,7 @@ Changes: ## jax 0.2.7 (Dec 4 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7). * New features: * Add `jax.device_put_replicated` * Add multi-host support to `jax.experimental.sharded_jit` @@ -2382,14 +2382,14 @@ Changes: ## jax 0.2.6 (Nov 18 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6). * New Features: * Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. - See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). * Breaking change cleanup * Raise an error on non-hashable static arguments for jax.jit and - xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42). + xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42). * Improve consistency of type promotion behavior ({jax-issue}`#4744`): * Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously @@ -2441,15 +2441,15 @@ Changes: ## jax 0.2.5 (October 27 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5). * Improvements: * Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`. * Expanded the set of JAX primitives converted by jax2tf. - See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). + See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). ## jax 0.2.4 (October 19 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4). * Improvements: * Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`. * Deprecations @@ -2461,17 +2461,17 @@ Changes: ## jax 0.2.3 (October 14 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3). * The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation ## jax 0.2.2 (October 13 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2). ## jax 0.2.1 (October 6 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1). * Improvements: * As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/ @@ -2479,10 +2479,10 @@ Changes: ## jax (0.2.0) (September 23 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0). * Improvements: * Omnistaging on by default. See {jax-issue}`#3370` and - [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) ## jax (0.1.77) (September 15 2020) @@ -2496,11 +2496,11 @@ Changes: ## jax 0.1.76 (September 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76). ## jax 0.1.75 (July 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75). * Bug Fixes: * make jnp.abs() work for unsigned inputs (#3914) * Improvements: @@ -2508,7 +2508,7 @@ Changes: ## jax 0.1.74 (July 29, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74). * New Features: * BFGS (#3101) * TPU support for half-precision arithmetic (#3878) @@ -2525,7 +2525,7 @@ Changes: ## jax 0.1.73 (July 22, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73). * The minimum jaxlib version is now 0.1.51. * New Features: * jax.image.resize. (#3703) @@ -2563,14 +2563,14 @@ Changes: ## jax 0.1.72 (June 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72). * Bug fixes: * Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`. ## jax 0.1.71 (June 25, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71). * The minimum jaxlib version is now 0.1.48. * Bug fixes: * Allow `jax.experimental.ode.odeint` dynamics functions to close over @@ -2606,7 +2606,7 @@ Changes: ## jax 0.1.70 (June 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70). * New features: * `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` @@ -2615,11 +2615,11 @@ Changes: ## jax 0.1.69 (June 3, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69). ## jax 0.1.68 (May 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68). * New features: * {func}`lax.cond` supports a single-operand form, taken as the argument to both branches @@ -2630,7 +2630,7 @@ Changes: ## jax 0.1.67 (May 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67). * New features: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. @@ -2648,7 +2648,7 @@ Changes: ## jax 0.1.66 (May 5, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66). * New features: * Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`. @@ -2661,7 +2661,7 @@ Changes: ## jax 0.1.65 (April 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65). * New features: * Differentiation of determinants of singular matrices {jax-issue}`#2809`. @@ -2679,7 +2679,7 @@ Changes: ## jax 0.1.64 (April 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64). * New features: * Add syntactic sugar for functional indexed updates {jax-issue}`#2684`. @@ -2706,7 +2706,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). * Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. @@ -2727,7 +2727,7 @@ Changes: ## jax 0.1.62 (March 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62). * JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer. * Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when @@ -2745,13 +2745,13 @@ Changes: ## jax 0.1.61 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61). * Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5. ## jax 0.1.60 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60). * New features: * {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows the user to specify arguments that should be treated as compile-time @@ -2777,7 +2777,7 @@ Changes: ## jax 0.1.59 (February 11, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59). * Breaking changes * The minimum jaxlib version is now 0.1.38. @@ -2809,7 +2809,7 @@ Changes: ## jax 0.1.58 (January 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58). +* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58). * Breaking changes * JAX has dropped Python 2 support, because Python 2 reached its end of life on diff --git a/CITATION.bib b/CITATION.bib index 88049a146..777058b5a 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -1,7 +1,7 @@ @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/README.md b/README.md index 35307bee3..d67bdac82 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@
-logo +logo
# Transformable numerical computing at scale -![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) +![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) [**Quickstart**](#quickstart-colab-in-the-cloud) @@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect bugs and [sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting -bugs](https://github.com/google/jax/issues), and letting us know what you +bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! ```python @@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: - [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). +Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: - [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) - [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of -notebooks](https://github.com/google/jax/tree/main/docs/notebooks). +notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). ## Transformations @@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.))) # prints [0. 0.16666667 0.33333334 0.5 ] ``` -You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more sophisticated communication patterns. It all composes, so you're free to differentiate through parallel computations: @@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass. See the [SPMD -Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) and the [SPMD MNIST classifier from scratch -example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) for more. ## Current gotchas @@ -349,7 +349,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. @@ -437,7 +437,7 @@ To cite this repository: @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index cb5a42ced..edaa71b93 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -451,7 +451,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." ] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 4952cdbe9..d7ba5ed33 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -837,7 +837,7 @@ "id": "f-FBsWeo1AXE" }, "source": [ - "" + "" ] }, { @@ -847,7 +847,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 981f0a9e8..ea126ac4f 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -15,13 +15,13 @@ "id": "sk-3cPGIBTq8" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", "\n", "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n", "\n", "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n", "\n", - "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." + "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." ] }, { diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index 4a795f718..db3dc5f30 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab: -### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) A guide to getting started with `pmap`, a transform for easily distributing SPMD computations across devices. -### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) +### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) Contributed by Alex Alemi (alexalemi@) Solve and plot parallel ODE solutions with `pmap`. - + -### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) +### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) Contributed by Stephan Hoyer (shoyer@) Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU. -![](https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/wave_movie.gif) +![](https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/wave_movie.gif) -### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) +### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`. ## Performance notes @@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`. -\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you! +\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you! ## Running JAX on a Cloud TPU VM @@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU VM), please email , or if you are a [TRC](https://sites.research.google/trc/) member. You can also [file a -JAX issue](https://github.com/google/jax/issues) or [ask a discussion -question](https://github.com/google/jax/discussions) for any issues with these +JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion +question](https://github.com/jax-ml/jax/discussions) for any issues with these notebooks or using JAX in general. If you have any other questions or comments regarding JAX on Cloud TPUs, please diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/_tutorials/advanced-autodiff.md index 180f65f5d..287487ad4 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/_tutorials/advanced-autodiff.md @@ -571,7 +571,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: +Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: ```{code-cell} # Isolate the function from the weight matrix to the predictions diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index ed242ecc5..9a956670c 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -27,7 +27,7 @@ "metadata": {}, "source": [ "[![Open in\n", - "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)" + "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)" ] }, { @@ -1781,7 +1781,7 @@ "metadata": {}, "source": [ "This is precisely the issue that\n", - "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n", + "[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n", "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n", "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n", "`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 471dd7c63..937e1012a 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -33,7 +33,7 @@ limitations under the License. ``` [![Open in -Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) +++ @@ -1399,7 +1399,7 @@ print(jaxpr) ``` This is precisely the issue that -[omnistaging](https://github.com/google/jax/pull/3370) fixed. +[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always applied, regardless of whether any inputs to `bind` are boxed in corresponding `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/autodidax.py b/docs/autodidax.py index 6d295fc50..c10e6365e 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -27,7 +27,7 @@ # --- # [![Open in -# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) # # Autodidax: JAX core from scratch # @@ -1396,7 +1396,7 @@ print(jaxpr) # This is precisely the issue that -# [omnistaging](https://github.com/google/jax/pull/3370) fixed. +# [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. # We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always # applied, regardless of whether any inputs to `bind` are boxed in corresponding # `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst index 204659ec2..783d3b49a 100644 --- a/docs/beginner_guide.rst +++ b/docs/beginner_guide.rst @@ -52,4 +52,4 @@ questions answered are: .. _Flax: https://flax.readthedocs.io/ .. _Haiku: https://dm-haiku.readthedocs.io/ .. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax -.. _JAX GitHub discussions: https://github.com/google/jax/discussions \ No newline at end of file +.. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index ed6fcfd0d..e77916e26 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -168,7 +168,7 @@ html_theme = 'sphinx_book_theme' # documentation. html_theme_options = { 'show_toc_level': 2, - 'repository_url': 'https://github.com/google/jax', + 'repository_url': 'https://github.com/jax-ml/jax', 'use_repository_button': True, # add a "link to repository" button 'navigation_with_keys': False, } @@ -345,7 +345,7 @@ def linkcode_resolve(domain, info): return None filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" - return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}" + return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" # Generate redirects from deleted files to new sources rediraffe_redirects = { diff --git a/docs/contributing.md b/docs/contributing.md index d7fa6e9da..99d78453c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -5,22 +5,22 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are several ways to contribute, including: -- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions) +- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) - Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) -- Contributing to JAX's [code-base](http://github.com/google/jax/) -- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries) +- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) +- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Ways to contribute We welcome pull requests, in particular for those issues marked with -[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or -[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +[contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or +[good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). For other proposals, we ask that you first open a GitHub -[Issue](https://github.com/google/jax/issues/new/choose) or -[Discussion](https://github.com/google/jax/discussions) +[Issue](https://github.com/jax-ml/jax/issues/new/choose) or +[Discussion](https://github.com/jax-ml/jax/discussions) to seek feedback on your planned contribution. ## Contributing code using pull requests @@ -33,7 +33,7 @@ Follow these steps to contribute code: For more information, see the Pull Request Checklist below. 2. Fork the JAX repository by clicking the **Fork** button on the - [repository page](http://www.github.com/google/jax). This creates + [repository page](http://www.github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. 3. Install Python >= 3.10 locally in order to run tests. @@ -52,7 +52,7 @@ Follow these steps to contribute code: changes. ```bash - git remote add upstream https://www.github.com/google/jax + git remote add upstream https://www.github.com/jax-ml/jax ``` 6. Create a branch where you will develop from: diff --git a/docs/developer.md b/docs/developer.md index 40ad51e87..4f3361413 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -6,7 +6,7 @@ First, obtain the JAX source code: ``` -git clone https://github.com/google/jax +git clone https://github.com/jax-ml/jax cd jax ``` @@ -26,7 +26,7 @@ If you're only modifying Python portions of JAX, we recommend installing pip install jaxlib ``` -See the [JAX readme](https://github.com/google/jax#installation) for full +See the [JAX readme](https://github.com/jax-ml/jax#installation) for full guidance on pip installation (e.g., for GPU and TPU support). ### Building `jaxlib` from source @@ -621,7 +621,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py Keep in mind that there are several files that are marked to be skipped when the doctest command is run on the full package; you can see the details in -[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml) +[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml) ## Type checking @@ -712,7 +712,7 @@ jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` The jupytext version should match that specified in -[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml). +[.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml). To check that the markdown and ipynb files are properly synced, you may use the [pre-commit](https://pre-commit.com/) framework to perform the same check used @@ -740,12 +740,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked Some of the notebooks are built automatically as part of the pre-submit checks and as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, -or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). +or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else re-saves the notebook. We exclude some notebooks from the build, e.g., because they contain long computations. -See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py). +See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py). ### Documentation building on `readthedocs.io` @@ -772,7 +772,7 @@ I saw in the Readthedocs logs: mkvirtualenv jax-docs # A new virtualenv mkdir jax-docs # A new directory cd jax-docs -git clone --no-single-branch --depth 50 https://github.com/google/jax +git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax cd jax git checkout --force origin/test-docs git clean -d -f -f diff --git a/docs/export/export.md b/docs/export/export.md index 0ca1a6480..4e4d50556 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -153,7 +153,7 @@ JAX runtime system that are: an inference system that is already deployed when the exporting is done. (The particular compatibility window lengths are the same that JAX -[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +[promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). The terminology “backward compatibility” is from the perspective of the consumer, e.g., the inference system.) @@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -721,7 +721,7 @@ that live in jaxlib): 2. Day “D”, we add the new custom call target `T_NEW`. We should create a new custom call target, and clean up the old target roughly after 6 months, rather than updating `T` in place: - * See the example [PR #20997](https://github.com/google/jax/pull/20997) + * See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997) implementing the steps below. * We add the custom call target `T_NEW`. * We change the JAX lowering rules that were previous using `T`, diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md index 498a0418f..9c0ee90a0 100644 --- a/docs/export/jax2tf.md +++ b/docs/export/jax2tf.md @@ -2,4 +2,4 @@ ## Interoperation with TensorFlow -See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). diff --git a/docs/faq.rst b/docs/faq.rst index 3ac7d89fb..af14f382b 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -372,7 +372,7 @@ device. Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device. -(Before `PR #6002 `_ in March 2021 +(Before `PR #6002 `_ in March 2021 there was some laziness in creation of array constants, so that ``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually create the array of zeros on ``jax.devices()[1]``, instead of creating the @@ -385,7 +385,7 @@ and its use is not recommended.) For a worked-out example, we recommend reading through ``test_computation_follows_data`` in -`multi_device_test.py `_. +`multi_device_test.py `_. .. _faq-benchmark: @@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.:: Additional reading: - * `Issue: gradients through jnp.where when one of branches is nan `_. + * `Issue: gradients through jnp.where when one of branches is nan `_. * `How to avoid NaN gradients when using where `_. diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 7f7bcc07c..04ae80cbf 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -406,7 +406,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." ] }, { @@ -492,7 +492,7 @@ "source": [ "At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n", "One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n", - "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n", + "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n", "\n", "One other JAX feature that this example doesn't support is higher-order AD.\n", "It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n", diff --git a/docs/ffi.md b/docs/ffi.md index d96d9ff8c..03acf876b 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5): jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) ``` -If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +++ @@ -406,7 +406,7 @@ np.testing.assert_allclose( At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`. One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode. -JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice. +JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice. One other JAX feature that this example doesn't support is higher-order AD. It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here. diff --git a/docs/installation.md b/docs/installation.md index 93df4a240..acb802ea9 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -176,7 +176,7 @@ installation. JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. Make sure that it is present in your CUDA installation. -Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) +Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues) if you run into any errors or problems with the pre-built wheels. (docker-containers-nvidia-gpu)= @@ -216,7 +216,7 @@ refer to **Note:** There are several caveats with the Metal plugin: * The Metal plugin is new and experimental and has a number of - [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). + [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). Please report any issues on the JAX issue tracker. * The Metal plugin currently requires very specific versions of `jax` and `jaxlib`. This restriction will be relaxed over time as the plugin API diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 389cc0b5a..61d219d1b 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -9,7 +9,7 @@ Let's first make a JAX issue. But if you can pinpoint the commit that triggered the regression, it will really help us. This document explains how we identified the commit that caused a -[15% performance regression](https://github.com/google/jax/issues/17686). +[15% performance regression](https://github.com/jax-ml/jax/issues/17686). ## Steps @@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox). - test_runner.sh: will start the containers and the test. - test.sh: will install missing dependencies and run the test -Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686 +Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686 - test_runner.sh: ``` for m in 7 8 9; do diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md index da0adaf18..019188349 100644 --- a/docs/jep/11830-new-remat-checkpoint.md +++ b/docs/jep/11830-new-remat-checkpoint.md @@ -14,7 +14,7 @@ ## What’s going on? -As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) +As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) ## How can I disable the change, and go back to the old behavior for now? @@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer available. If you have an issue, please reach out on [the issue -tracker](https://github.com/google/jax/issues) so we can help fix it! +tracker](https://github.com/jax-ml/jax/issues) so we can help fix it! ## Why are we doing this? @@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi ### Significantly less Python overhead in some cases -The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! +The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! ### Enabling new JAX features by simplifying internals diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 9137e3e71..7a20958c5 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX. Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. -In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. +In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. This document details JAX's goals and recommendations for type annotations within the package. ## Why type annotations? @@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their ### Level 1: Annotations as documentation -When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: +When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: ```python Array = Any @@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development. -JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: +JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: ```python def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray: ... ``` -In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: +In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: ![VSCode Intellisense Screenshot](../_static/vscode-completion.png) @@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer ``` Again, there are a couple mechanisms that could be used for this: -- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). +- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). - define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer` - restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer` diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index bec060001..a5625abf8 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32) This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals -(see issue [#9263](https://github.com/google/jax/issues/9263)), +(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)), such as the current `jax._src.prng.random_wrap` and `random_unwrap`. diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md index 2371e11ee..eaebe8fb8 100644 --- a/docs/jep/18137-numpy-scipy-scope.md +++ b/docs/jep/18137-numpy-scipy-scope.md @@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`: as of the writing of this JEP, its current implementation is a non-straightforward iterative approximation that has -[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637), -and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further +[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637), +and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when accepting the contribution, we may have chosen not to accept this contribution to the package. diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index aa568adc0..ce149fa6f 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -35,9 +35,9 @@ behavior of their code. This customization Python control flow and workflows for NaN debugging. As **JAX developers** we want to write library functions, like -[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) +[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) and -[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), +[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or @@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like want to be confident we’re not going to preclude good solutions to that problem. That is, our primary goals are -1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and +1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and 2. allow Python in custom VJPs, e.g. to debug NaNs - ([#1275](https://github.com/google/jax/issues/1275)). + ([#1275](https://github.com/jax-ml/jax/issues/1275)). Secondary goals are 3. clean up and simplify user experience (symbolic zeros, kwargs, etc) @@ -60,18 +60,18 @@ Secondary goals are `odeint`, `root`, etc. Overall, we want to close -[#116](https://github.com/google/jax/issues/116), -[#1097](https://github.com/google/jax/issues/1097), -[#1249](https://github.com/google/jax/issues/1249), -[#1275](https://github.com/google/jax/issues/1275), -[#1366](https://github.com/google/jax/issues/1366), -[#1723](https://github.com/google/jax/issues/1723), -[#1670](https://github.com/google/jax/issues/1670), -[#1875](https://github.com/google/jax/issues/1875), -[#1938](https://github.com/google/jax/issues/1938), +[#116](https://github.com/jax-ml/jax/issues/116), +[#1097](https://github.com/jax-ml/jax/issues/1097), +[#1249](https://github.com/jax-ml/jax/issues/1249), +[#1275](https://github.com/jax-ml/jax/issues/1275), +[#1366](https://github.com/jax-ml/jax/issues/1366), +[#1723](https://github.com/jax-ml/jax/issues/1723), +[#1670](https://github.com/jax-ml/jax/issues/1670), +[#1875](https://github.com/jax-ml/jax/issues/1875), +[#1938](https://github.com/jax-ml/jax/issues/1938), and replace the custom_transforms machinery (from -[#636](https://github.com/google/jax/issues/636), -[#818](https://github.com/google/jax/issues/818), +[#636](https://github.com/jax-ml/jax/issues/636), +[#818](https://github.com/jax-ml/jax/issues/818), and others). ## Non-goals @@ -400,7 +400,7 @@ There are some other bells and whistles to the API: resolved to positions using the `inspect` module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. - (See also [#2069](https://github.com/google/jax/issues/2069).) + (See also [#2069](https://github.com/jax-ml/jax/issues/2069).) * Arguments can be marked non-differentiable using `nondiff_argnums`, and as with `jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal @@ -433,5 +433,5 @@ There are some other bells and whistles to the API: `custom_lin` to the tangent values; `custom_lin` carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule. - * This mechanism is described more in [#636](https://github.com/google/jax/issues/636). + * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636). * To prevent diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 65235dc64..1e2270e05 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -9,7 +9,7 @@ notebook. ## What to update -After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments +After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or containers of `Tracer`s), which basically means to allow for arbitrarily-transformable code `nondiff_argnums` shouldn't be used for @@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing `nondiff_argnums` that way was a mistake! -**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure +**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp` and `custom_vjp` functions and rules can close over `Tracer`s to our hearts' content. For all non-autodiff transformations, things will Just Work. For @@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees! Moreover, that complexity isn't necessary: if user code treats array-like non-differentiable arguments just like regular arguments and residuals, everything already works. (Before -[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about +[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about involving integer-valued inputs and outputs in autodiff, but after -[#4039](https://github.com/google/jax/pull/4039) those will just work!) +[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!) Unlike `custom_vjp`, it was easy to make `custom_jvp` work with `nondiff_argnums` arguments that were `Tracer`s. So these updates only need to diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index eb68ee5f0..f95c15f40 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc. ### What's going on? A change to JAX's tracing infrastructure called “omnistaging” -([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in +([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in jax==0.2.0. This change improves memory performance, trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but @@ -191,7 +191,7 @@ and potentially even fragmenting memory. (The `broadcast` that corresponds to the construction of the zeros array for `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple -expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After +expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.) The reason the creation of `mask` is not staged out is that, before omnistaging, diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 828b95e8c..d520f6f63 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the `jax._src.core.bint` object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies -the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). +the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). ### PRNG dtypes PRNG dtypes are defined as a particular case of extended dtypes. Specifically, diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 3e99daabe..a1ede3177 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -8,7 +8,7 @@ "source": [ "# Design of Type Promotion Semantics for JAX\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", "\n", "*Jake VanderPlas, December 2021*\n", "\n", diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index cdb1f7805..ff67a8c21 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -16,7 +16,7 @@ kernelspec: # Design of Type Promotion Semantics for JAX -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) *Jake VanderPlas, December 2021* diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 759a9be86..b964aa2af 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -58,11 +58,11 @@ These constraints imply the following rules for releases: * If a new `jaxlib` is released, a `jax` release must be made at the same time. These -[version constraints](https://github.com/google/jax/blob/main/jax/version.py) +[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py) are currently checked by `jax` at import time, instead of being expressed as Python package version constraints. `jax` checks the `jaxlib` version at runtime rather than using a `pip` package version constraint because we -[provide separate `jaxlib` wheels](https://github.com/google/jax#installation) +[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation) for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we do not know which is the right choice for any given user, we do not want `pip` to install a `jaxlib` package for us automatically. @@ -119,7 +119,7 @@ no released `jax` version uses that API. ## How is the source to `jaxlib` laid out? `jaxlib` is split across two main repositories, namely the -[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib) +[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib) and in the [XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla). The JAX-specific pieces inside XLA are primarily in the @@ -146,7 +146,7 @@ level. `jaxlib` is built using Bazel out of the `jax` repository. The pieces of `jaxlib` from the XLA repository are incorporated into the build -[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). +[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE). To update the version of XLA used during the build, one must update the pinned version in the Bazel `WORKSPACE`. This is done manually on an as-needed basis, but can be overridden on a build-by-build basis. diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 194eb0cb9..f9dda2657 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -32,7 +32,7 @@ should be linked to this issue. Then create a pull request that adds a file named `%d-{short-title}.md` - with the number being the issue number. -.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP +.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP .. toctree:: :maxdepth: 1 diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 0cffc22f1..71bd45276 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, { @@ -661,7 +661,7 @@ "source": [ "Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n", "\n", - "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." + "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." ] }, { @@ -1003,7 +1003,7 @@ "id": "COjzGBpO4tzL" }, "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", + "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", "The random state is described by a special array element that we call a __key__:" ] @@ -1349,7 +1349,7 @@ "\n", "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", + "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", "\n", "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 543d9ecb1..741fa3af0 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} @@ -312,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error. -Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +++ {"id": "LwB07Kx5sgHu"} @@ -460,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha +++ {"id": "COjzGBpO4tzL"} -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. +JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. The random state is described by a special array element that we call a __key__: @@ -623,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. +To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index dd7a36e57..5c09a0a4f 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 930887af1..8a9b93155 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) There are two ways to define differentiation rules in JAX: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8bc0e0a52..32d332d9a 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -17,7 +17,7 @@ "id": "pFtQjv4SzHRj" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "\n", "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." ] diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 97b07172b..2142db986 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -19,7 +19,7 @@ kernelspec: +++ {"id": "pFtQjv4SzHRj"} -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb index 0c20fc47d..e9924e18d 100644 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ b/docs/notebooks/How_JAX_primitives_work.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", "\n", "*necula@google.com*, October 2019.\n", "\n", diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md index b926c22ea..7c24ac11a 100644 --- a/docs/notebooks/How_JAX_primitives_work.md +++ b/docs/notebooks/How_JAX_primitives_work.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) *necula@google.com*, October 2019. diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index a4a4d7d16..a7ef2a017 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", "\n", @@ -32,9 +32,9 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", + "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 03b8415fc..cd98022e7 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** @@ -35,9 +35,9 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). +Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 2c231bf99..00ba9186e 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, { @@ -79,7 +79,7 @@ "id": "gA8V51wZdsjh" }, "source": [ - "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README." + "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README." ] }, { @@ -320,7 +320,7 @@ "source": [ "Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n", "\n", - "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." + "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." ] }, { @@ -333,7 +333,7 @@ "\n", "An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n", "\n", - "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)." + "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)." ] }, { diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 41d7a7e51..10c4e7cb6 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} @@ -57,7 +57,7 @@ fast_f = jit(f) +++ {"id": "gA8V51wZdsjh"} -When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README. +When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README. +++ {"id": "2Th1vYLVaFBz"} @@ -223,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)) Notice that `eval_jaxpr` will always return a flat list even if the original function does not. -Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +++ {"id": "0vb2ZoGrCMM4"} @@ -231,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry. -It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234). +It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234). ```{code-cell} ipython3 :id: gSMIT2z1vUpO diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 3f2f0fd56..5538b70da 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." ] @@ -255,7 +255,7 @@ "id": "cJ2NxiN58bfI" }, "source": [ - "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." + "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." ] }, { @@ -1015,7 +1015,7 @@ "source": [ "### Jacobian-Matrix and Matrix-Jacobian products\n", "\n", - "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." + "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." ] }, { diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index 496d676f7..db6fde805 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. @@ -151,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b})) +++ {"id": "cJ2NxiN58bfI"} -You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). +You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). +++ {"id": "PaCHzAtGruBz"} @@ -592,7 +592,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. +Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. ```{code-cell} ipython3 :id: asAWvxVaCmsx diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index f628625bd..9d91804b6 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", "\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 83ab2d9fd..b98099aa9 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 91f2ee571..c31a99746 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -40,11 +40,11 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", "\n", - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 480e7477b..53b7d4735 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -38,11 +38,11 @@ limitations under the License. -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index e4f9d888e..b5f8074c0 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." ] diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index dd0c73ec7..b3672b90e 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -17,7 +17,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 9aef1a8eb..dccc83168 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "\n", diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index 5989a87bc..3f836e680 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 93d7e5547..4a2d4daa6 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -21,7 +21,7 @@ software emulation, and can slow down the computation. If you see unexpected outputs, please compare them against a kernel run with ``interpret=True`` passed in to ``pallas_call``. If the results diverge, - please file a `bug report `_. + please file a `bug report `_. What is a TPU? -------------- diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 1d6a5d9b7..47a7587b6 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -29,7 +29,7 @@ f(x) ### Setting cache directory The compilation cache is enabled when the -[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) +[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206) is set. This should be done prior to the first compilation. Set the location as follows: @@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache" jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") ``` -(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) +(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) ```python from jax.experimental.compilation_cache import compilation_cache as cc diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py index 3a7855763..7cce8b882 100644 --- a/docs/sphinxext/jax_extensions.py +++ b/docs/sphinxext/jax_extensions.py @@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None, :jax-issue:`1234` This will output a hyperlink of the form - `#1234 `_. These links work even + `#1234 `_. These links work even for PR numbers. """ text = text.lstrip('#') if not text.isdigit(): raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") options = {} if options is None else options - url = f"https://github.com/google/jax/issues/{text}" + url = f"https://github.com/jax-ml/jax/issues/{text}" node = nodes.reference(rawtext, '#' + text, refuri=url, **options) return [node], [] diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 4eb6e7a66..2ff82e043 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b 2) Are we supposed to pipe all these things around manually? -The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples. +The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples. diff --git a/jax/__init__.py b/jax/__init__.py index e2e302adb..c6e073699 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -29,7 +29,7 @@ except Exception as exc: # Defensively swallow any exceptions to avoid making jax unimportable from warnings import warn as _warn _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report " - f"an issue at https://github.com/google/jax/issues") + f"an issue at https://github.com/jax-ml/jax/issues") del _warn del _cloud_tpu_init @@ -38,7 +38,7 @@ import jax.core as _core del _core # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.basearray import Array as Array from jax import tree as tree diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 8c7fe2f48..39df07359 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -546,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): # To avoid precision mismatches in fwd and bwd passes due to XLA excess # precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls - # on producers of any residuals. See https://github.com/google/jax/pull/22244. + # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244. jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res) # compute known outputs and residuals (hoisted out of remat primitive) diff --git a/jax/_src/api.py b/jax/_src/api.py index aae99a28b..bd8a95195 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -956,7 +956,7 @@ def vmap(fun: F, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. https://github.com/jax-ml/jax/issues/2367 in_axes = tuple(in_axes) if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}): @@ -2505,7 +2505,7 @@ class ShapeDtypeStruct: def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing - # https://github.com/google/jax/issues/8182 + # https://github.com/jax-ml/jax/issues/8182 return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) def _sds_aval_mapping(x): diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 453a4eba4..3a18dcdfa 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): device_assignment = axis_context.device_assignment if device_assignment is None: raise AssertionError( - "Please file a bug at https://github.com/google/jax/issues") + "Please file a bug at https://github.com/jax-ml/jax/issues") try: device_index = device_assignment.index(device) except IndexError as e: diff --git a/jax/_src/config.py b/jax/_src/config.py index fe56ec68f..b21d2f35f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1170,7 +1170,7 @@ softmax_custom_jvp = bool_state( upgrade=True, help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' - 'behavior. See https://github.com/google/jax/pull/15677'), + 'behavior. See https://github.com/jax-ml/jax/pull/15677'), update_global_hook=lambda val: _update_global_jit_state( softmax_custom_jvp=val), update_thread_local_hook=lambda val: update_thread_local_jit_state( diff --git a/jax/_src/core.py b/jax/_src/core.py index 057a79925..bff59625b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -935,7 +935,7 @@ aval_method = namedtuple("aval_method", ["fun"]) class EvalTrace(Trace): - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 def pure(self, x): return x lift = sublift = pure @@ -998,7 +998,7 @@ class MainTrace: return self.trace_type(self, cur_sublevel(), **self.payload) class TraceStack: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack: list[MainTrace] dynamic: MainTrace @@ -1167,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str: # parent->child jump. We do that by setting `parent` here to be a # grandparent (or great-grandparent) of `child`, and then handling that case # in _why_alive_container_info. See example: - # https://github.com/google/jax/pull/13022#discussion_r1008456599 + # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599 # To prevent this collapsing behavior, just comment out this code block. if (isinstance(parent, dict) and getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]): @@ -1213,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str: @contextmanager def new_main(trace_type: type[Trace], dynamic: bool = False, **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack level = stack.next_level() main = MainTrace(level, trace_type, **payload) @@ -1254,7 +1254,7 @@ def dynamic_level() -> int: @contextmanager def new_base_main(trace_type: type[Trace], **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack main = MainTrace(0, trace_type, **payload) prev_dynamic, stack.dynamic = stack.dynamic, main @@ -1319,7 +1319,7 @@ def ensure_compile_time_eval(): else: return jnp.cos(x) - Here's a real-world example from https://github.com/google/jax/issues/3974:: + Here's a real-world example from https://github.com/jax-ml/jax/issues/3974:: import jax import jax.numpy as jnp @@ -1680,7 +1680,7 @@ class UnshapedArray(AbstractValue): @property def shape(self): msg = ("UnshapedArray has no shape. Please open an issue at " - "https://github.com/google/jax/issues because it's unexpected for " + "https://github.com/jax-ml/jax/issues because it's unexpected for " "UnshapedArray instances to ever be produced.") raise TypeError(msg) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 07df2321b..35e7d3343 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -191,7 +191,7 @@ def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree): # TODO(frostig): assert these also equal: # treedef_tuple((in_tree, in_tree)) - # once https://github.com/google/jax/issues/9066 is fixed + # once https://github.com/jax-ml/jax/issues/9066 is fixed assert tree_ps_ts == tree_ps_ts2 del tree_ps_ts2 diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 88be655a0..f5ecdfcda 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1144,7 +1144,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]: def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. x = core.full_lower(x) if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 8f48746dd..984d55fe2 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') if axis_context.mesh_shape is not None: ma, ms = list(zip(*axis_context.mesh_shape)) mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 337349694..465dc90e2 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -389,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = axis_context.mesh._flat_devices_tuple else: diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index d76b80ad3..352a3e550 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -793,7 +793,7 @@ def check_user_dtype_supported(dtype, fun_name=None): "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "environment variable. " - "See https://github.com/google/jax#current-gotchas for more.") + "See https://github.com/jax-ml/jax#current-gotchas for more.") fun_name = f"requested in {fun_name}" if fun_name else "" truncated_dtype = canonicalize_dtype(np_dtype).name warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index e18ad1f6e..11a9dda66 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -61,7 +61,7 @@ def _ravel_list(lst): if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. - # See https://github.com/google/jax/issues/7809. + # See https://github.com/jax-ml/jax/issues/7809. del from_dtypes, to_dtype raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 4bf6d1ceb..2c9490756 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -30,7 +30,7 @@ Instead of writing this information as conditions inside one particular test, we write them as `Limitation` objects that can be reused in multiple tests and can also be used to generate documentation, e.g., the report of [unsupported and partially-implemented JAX -primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) +primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) The limitations are used to filter out from tests the harnesses that are known to fail. A Limitation is specific to a harness. @@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name, for old_dtype in jtu.dtypes.all: # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating # point numbers and new_dtype is an unsigned integer. See issue - # https://github.com/google/jax/issues/5082 for details. + # https://github.com/jax-ml/jax/issues/5082 for details. for new_dtype in (jtu.dtypes.all if not (dtypes.issubdtype(old_dtype, np.floating) or dtypes.issubdtype(old_dtype, np.complexfloating)) @@ -2336,7 +2336,7 @@ _make_select_and_scatter_add_harness("select_prim", select_prim=lax.le_p) # Validate padding for padding in [ # TODO(bchetioui): commented out the test based on - # https://github.com/google/jax/issues/4690 + # https://github.com/jax-ml/jax/issues/4690 # ((1, 2), (2, 3), (3, 4)) # non-zero padding ((1, 1), (1, 1), (1, 1)) # non-zero padding ]: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 6bc6539b9..6bc3cceb7 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1262,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented( name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool], inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: msg = (f'custom-policy remat rule not implemented for {name}, ' - 'open a feature request at https://github.com/google/jax/issues!') + 'open a feature request at https://github.com/jax-ml/jax/issues!') raise NotImplementedError(msg) @@ -2688,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): # TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/google/jax/pull/9498 +# See https://github.com/jax-ml/jax/pull/9498 @lu.transformation def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], pvals: Sequence[PartialVal]): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 944e20fa7..de668090e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -500,7 +500,7 @@ class MapTrace(core.Trace): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros # always base main, can drop jvp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -513,7 +513,7 @@ class MapTrace(core.Trace): out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -1869,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " - "See https://github.com/google/jax/issues/2926.") + "See https://github.com/jax-ml/jax/issues/2926.") if nreps > xb.device_count(backend): raise ValueError( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 4cb38d28c..d3065d0f9 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -389,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, branch_outs = [] for i, jaxpr in enumerate(branches_batched): # Perform a select on the inputs for safety of reverse-mode autodiff; see - # https://github.com/google/jax/issues/1052 + # https://github.com/jax-ml/jax/issues/1052 predicate = lax.eq(index, lax._const(index, i)) ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 41d809f8d..7a9596bf2 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -715,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: raise NotImplementedError if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 + pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963 consts, _, xs = split_list(args, [num_consts, num_carry]) ires, _ = split_list(consts, [num_ires]) @@ -1169,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, if discharged_consts: raise NotImplementedError("Discharged jaxpr has consts. If you see this, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") def wrapped(*wrapped_args): val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, [n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs]) @@ -1838,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, if body_jaxpr_consts: raise NotImplementedError("Body jaxpr has consts. If you see this error, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") # body_jaxpr has the signature (*body_consts, *carry) -> carry. # Some of these body_consts are actually `Ref`s so when we discharge # them, they also turn into outputs, effectively turning those consts into diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 0cbee6d2b..36553e512 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths): out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) # Use JAX's convention for complex gradients - # https://github.com/google/jax/issues/6223#issuecomment-807740707 + # https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707 return lax.conj(out) def _fft_transpose_rule(t, operand, fft_type, fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 48af9c64f..394a54c35 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1077,7 +1077,7 @@ def _reduction_jaxpr(computation, aval): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return jaxpr, tuple(consts) @cache() @@ -1090,7 +1090,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return core.ClosedJaxpr(jaxpr, consts), out_tree() def _get_monoid_reducer(monoid_op: Callable, @@ -4911,7 +4911,7 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs): return tree_util.tree_unflatten(p.out_tree(), out_flat) -# TODO(https://github.com/google/jax/issues/13552): Look into making this a +# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a # method on jax.Array so that we can bypass the XLA compilation here. def _copy_impl(prim, *args, **kwargs): a, = args diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index ec0a075da..453e79a5c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -781,7 +781,7 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' 'eigenvalues. See ' - 'https://github.com/google/jax/issues/2748 for discussion.') + 'https://github.com/jax-ml/jax/issues/2748 for discussion.') # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # https://arxiv.org/abs/1701.00392 a, = primals diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index b2bcc53a5..e8fcb4334 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -27,7 +27,7 @@ try: except ModuleNotFoundError as err: raise ModuleNotFoundError( 'jax requires jaxlib to be installed. See ' - 'https://github.com/google/jax#installation for installation instructions.' + 'https://github.com/jax-ml/jax#installation for installation instructions.' ) from err import jax.version @@ -92,7 +92,7 @@ pytree = xla_client._xla.pytree jax_jit = xla_client._xla.jax_jit pmap_lib = xla_client._xla.pmap_lib -# XLA garbage collection: see https://github.com/google/jax/issues/14882 +# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): xla_client._xla.collect_garbage() gc.callbacks.append(_xla_gc_callback) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 20234b678..08c8bfcb3 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -333,7 +333,7 @@ class AbstractMesh: should use this as an input to the sharding passed to with_sharding_constraint and mesh passed to shard_map to avoid tracing and lowering cache misses when your mesh shape and names stay the same but the devices change. - See the description of https://github.com/google/jax/pull/23022 for more + See the description of https://github.com/jax-ml/jax/pull/23022 for more details. """ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a1601e920..5b9362685 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3953,7 +3953,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], arrays_out = [asarray(arr, dtype=dtype) for arr in arrays] # lax.concatenate can be slow to compile for wide concatenations, so form a # tree of concatenations as a workaround especially for op-by-op mode. - # (https://github.com/google/jax/issues/653). + # (https://github.com/jax-ml/jax/issues/653). k = 16 while len(arrays_out) > 1: arrays_out = [lax.concatenate(arrays_out[i:i+k], axis) @@ -4645,7 +4645,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in - # https://github.com/google/jax/pull/6047. More correct would be to call + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call # coerce_to_array on each leaf, but this may have performance implications. out = np.asarray(object, dtype=dtype) elif isinstance(object, Array): @@ -10150,11 +10150,11 @@ def _eliminate_deprecated_list_indexing(idx): if any(_should_unpack_list_index(i) for i in idx): msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[tuple(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") else: msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[array(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") raise TypeError(msg) else: idx = (idx,) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index b6aea9e19..043c976ef 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -968,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy msg = ("jax.numpy.var does not yet support real dtype parameters when " "computing the variance of an array of complex values. The " "semantics of numpy.var seem unclear in this case. Please comment " - "on https://github.com/google/jax/issues/2283 if this behavior is " + "on https://github.com/jax-ml/jax/issues/2283 if this behavior is " "important to you.") raise ValueError(msg) computation_dtype = dtype diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 7e8acb090..6491a7617 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -134,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. In order to ensure statically-known output shapes, you can pass a static ``size`` argument: @@ -217,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. In order to ensure statically-known output shapes, you can pass a static ``size`` argument: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index b45b3370f..00b5311b8 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2176,7 +2176,7 @@ def sinc(x: ArrayLike, /) -> Array: def _sinc_maclaurin(k, x): # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we # compute the monomial term in the jvp rule) - # TODO(mattjj): see https://github.com/google/jax/issues/10750 + # TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750 if k % 2: return x * 0 else: diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index 5340eb3fa..f8231f5b2 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -35,7 +35,7 @@ FRAME_PATTERN = re.compile( ) MLIR_ERR_PREFIX = ( 'Pallas encountered an internal verification error.' - 'Please file a bug at https://github.com/google/jax/issues. ' + 'Please file a bug at https://github.com/jax-ml/jax/issues. ' 'Error details: ' ) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f120bbabf..d4dc534d0 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -833,7 +833,7 @@ def jaxpr_subcomp( raise NotImplementedError( "Unimplemented primitive in Pallas TPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") if eqn.primitive.multiple_results: map(write_env, eqn.outvars, ans) else: diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5eaf6e523..4b8199b36 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -549,7 +549,7 @@ def lower_jaxpr_to_mosaic_gpu( raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues." + "Please file an issue on https://github.com/jax-ml/jax/issues." ) rule = mosaic_lowering_rules[eqn.primitive] rule_ctx = LoweringRuleContext( diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index ac28bd21a..6a1615627 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -381,7 +381,7 @@ def lower_jaxpr_to_triton_ir( raise NotImplementedError( "Unimplemented primitive in Pallas GPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") rule = triton_lowering_rules[eqn.primitive] avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0abaa3fd0..ac1318ed7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -459,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. https://github.com/jax-ml/jax/issues/2367 in_shardings = tuple(in_shardings) in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) @@ -1276,7 +1276,7 @@ def explain_tracing_cache_miss( return done() # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/google/jax") + p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") return done() @partial(lu.cache, explain=explain_tracing_cache_miss) @@ -1701,7 +1701,7 @@ def _pjit_call_impl_python( "`jit` decorator, at the cost of losing optimizations. " "\n\n" "If you see this error, consider opening a bug report at " - "https://github.com/google/jax.") + "https://github.com/jax-ml/jax.") raise FloatingPointError(msg) diff --git a/jax/_src/random.py b/jax/_src/random.py index d889713f6..203f72d40 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -508,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: span = lax.convert_element_type(maxval - minval, unsigned_dtype) # Ensure that span=1 when maxval <= minval, so minval is always returned; - # https://github.com/google/jax/issues/222 + # https://github.com/jax-ml/jax/issues/222 span = lax.select(maxval <= minval, lax.full_like(span, 1), span) # When maxval is out of range, the span has to be one larger. @@ -2540,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: _btrs(key, count_btrs, q_btrs, shape, dtype, max_iters), ) # ensure nan q always leads to nan output and nan or neg count leads to nan - # as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709 + # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709 invalid = (q_l_0 | q_is_nan | count_nan_or_neg) samples = lax.select( invalid, diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index d81008308..ee144eaf9 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -176,7 +176,7 @@ def map_coordinates( Note: Interpolation near boundaries differs from the scipy function, because JAX - fixed an outstanding bug; see https://github.com/google/jax/issues/11097. + fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097. This function interprets the ``mode`` argument as documented by SciPy, but not as implemented by SciPy. """ diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 2361eaf64..574d725c4 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -44,7 +44,7 @@ def shard_alike(x, y): raise ValueError( 'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:' f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. File an issue at' - ' https://github.com/google/jax/issues if you want this feature.') + ' https://github.com/jax-ml/jax/issues if you want this feature.') outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)] x_out_flat, y_out_flat = zip(*outs) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 5afcd5e3a..81737f275 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1208,7 +1208,7 @@ class JaxTestCase(parameterized.TestCase): y = np.asarray(y) if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): - # See https://github.com/google/jax/issues/17867 + # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " "If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. " diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 0caa6e7c6..010841b45 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -21,7 +21,7 @@ exported at `jax.typing`. Until then, the contents here should be considered uns and may change without notice. To see the proposal that led to the development of these tools, see -https://github.com/google/jax/pull/11859/. +https://github.com/jax-ml/jax/pull/11859/. """ from __future__ import annotations diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 1d3c50403..796093b62 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1232,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See" - " https://github.com/google/jax#installation") + " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) xla_client.profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") diff --git a/jax/core.py b/jax/core.py index 9857fcf88..cdf8d7655 100644 --- a/jax/core.py +++ b/jax/core.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.core import ( AbstractToken as AbstractToken, diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 8e517f5d4..ea1ef4f02 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_derivatives import ( _initial_style_jaxpr, diff --git a/jax/dtypes.py b/jax/dtypes.py index f2071fd4f..a6f1b7645 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.dtypes import ( bfloat16 as bfloat16, diff --git a/jax/errors.py b/jax/errors.py index 15a6654fa..2a811661d 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.errors import ( JAXTypeError as JAXTypeError, diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index caf27ec7a..1b22c2c2a 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.x64_context import ( enable_x64 as enable_x64, diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 69b25d0b6..3ac1d4246 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 import sys as _sys import warnings as _warnings diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 0b6b51f71..8e11d4173 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.checkify import ( Error as Error, diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 3c7bfac40..6da3ad7c5 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_partitioning import ( custom_partitioning as custom_partitioning, diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 63c3299c5..49162809a 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -17,7 +17,7 @@ The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -363,11 +363,11 @@ using the :func:`jax.custom_vjp` mechanism. This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the ``call_tf_full_ad`` -function in `host_callback_to_tf_test.py `_. +function in `host_callback_to_tf_test.py `_. This example supports arbitrary higher-order differentiation as well. Note that if you just want to call TensorFlow functions from JAX, you can also -use the `jax2tf.call_tf function `_. +use the `jax2tf.call_tf function `_. Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support ------------------------------------------------------------------------------------------------ @@ -378,7 +378,7 @@ the host, and then to the outside device on which the JAX host computation will run, and then the results are sent back to the original accelerator. The code for how this can be done is shown in the ``call_jax_other_device function`` -in `host_callback_test.py `_. +in `host_callback_test.py `_. Low-level details and debugging ------------------------------- @@ -572,7 +572,7 @@ _HOST_CALLBACK_LEGACY = config.bool_flag( help=( 'Use old implementation of host_callback, documented in the module docstring.' 'If False, use the jax.experimental.io_callback implementation. ' - 'See https://github.com/google/jax/issues/20385.' + 'See https://github.com/jax-ml/jax/issues/20385.' ) ) @@ -592,7 +592,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): "See https://jax.readthedocs.io/en/latest/debugging/index.html and " "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" " for alternatives. Please file a feature request at " - "https://github.com/google/jax/issues if none of the alternatives are " + "https://github.com/jax-ml/jax/issues if none of the alternatives are " "sufficient.") @@ -608,7 +608,7 @@ DType = Any class CallbackFlavor(enum.Enum): """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. """ IO_CALLBACK = 1 # uses jax.experimental.io_callback PURE = 2 # uses jax.pure_callback @@ -629,7 +629,7 @@ def _deprecated_id_tap(tap_func, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -655,7 +655,7 @@ def _deprecated_id_tap(tap_func, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: ``arg``, or ``result`` if given. @@ -712,7 +712,7 @@ def _deprecated_id_print(arg, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -730,7 +730,7 @@ def _deprecated_id_print(arg, * ``threshold`` is passed to ``numpy.array2string``. * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ @@ -757,7 +757,7 @@ def _deprecated_call(callback_func: Callable, arg, *, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Args: callback_func: The Python function to invoke on the host as @@ -787,7 +787,7 @@ def _deprecated_call(callback_func: Callable, arg, *, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: the result of the ``callback_func`` invocation. @@ -800,7 +800,7 @@ def _deprecated_call(callback_func: Callable, arg, *, raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " "flavor of callback only when the `result_shape` is None. " - "See https://github.com/google/jax/issues/20385." + "See https://github.com/jax-ml/jax/issues/20385." ) return _call(callback_func, arg, result_shape=result_shape, call_with_device=call_with_device, identity=False, @@ -819,7 +819,7 @@ class _CallbackWrapper: raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" " do not support `tap_with_device` and `call_with_device`. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") def __hash__(self): return hash((self.callback_func, self.identity, self.call_with_device)) @@ -2121,7 +2121,7 @@ def _deprecated_stop_outfeed_receiver(): _deprecation_msg = ( "The host_callback APIs are deprecated as of March 20, 2024. The functionality " "is subsumed by the new JAX external callbacks. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") _deprecations = { # Added March 20, 2024 diff --git a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb index 4f23d88e0..3613dba0e 100644 --- a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb +++ b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb @@ -26,7 +26,7 @@ "Link: go/jax2tf-colab\n", "\n", "The JAX2TF colab has been deprecated, and the example code has\n", - "been moved to [jax2tf/examples](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples). \n" + "been moved to [jax2tf/examples](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples). \n" ] } ] diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index dbdc4f563..b77474c03 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -103,10 +103,10 @@ For more involved examples, please see examples involving: * SavedModel for archival ([examples below](#usage-saved-model)), including saving [batch-polymorphic functions](#shape-polymorphic-conversion), - * TensorFlow Lite ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), - * TensorFlow.js ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), + * TensorFlow Lite ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), + * TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), * TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)), - * TensorFlow Hub and Keras ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). + * TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). [TOC] @@ -249,7 +249,7 @@ graph (they will be saved in a `variables` area of the model, which is not subject to the 2GB limitation). For examples of how to save a Flax model as a SavedModel see the -[examples directory](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md). +[examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md). ### Saved model and differentiation @@ -619,7 +619,7 @@ Cannot solve for values of dimension variables {'a', 'b'}. " We can only solve linear uni-variate constraints. " Using the following polymorphic shapes specifications: args[0].shape = (a + b,). Unprocessed specifications: 'a + b' for dimension size args[0].shape[0]. " -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. ``` ### Shape assertion errors @@ -645,7 +645,7 @@ Input shapes do not match the polymorphic shapes specification. Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. ``` When using native serialization these are checked by the `tf.XlaCallModule` @@ -869,7 +869,7 @@ leads to errors for the following expressions `b == a or b == b` or `b in [a, b] even though the error is avoided if we change the order of the comparisons. We attempted to retain soundness and hashability by creating both hashable and unhashable -kinds of symbolic dimensions [PR #14200](https://github.com/google/jax/pull/14200), +kinds of symbolic dimensions [PR #14200](https://github.com/jax-ml/jax/pull/14200), but it turned out to be very hard to diagnose hashing failures in user programs because often hashing is implicit when using sets or memo tables. @@ -989,7 +989,7 @@ We list here a history of the serialization version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -1164,7 +1164,7 @@ self.assertAllClose(grad_jax.b, grad_tf[1]) Applies to both native and non-native serialization. When JAX differentiates functions with integer or boolean arguments, the gradients will -be zero-vectors with a special `float0` type (see PR 4039](https://github.com/google/jax/pull/4039)). +be zero-vectors with a special `float0` type (see PR 4039](https://github.com/jax-ml/jax/pull/4039)). This type is translated to `int32` when lowering to TF. For example, @@ -1441,7 +1441,7 @@ Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative -lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). +lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). jax2tf uses the TPU lowering (because it does not support backend-specific lowering) and hence it can be slow in some cases on CPU and GPU. @@ -1502,7 +1502,7 @@ before conversion. (This is a hypothesis, we have not yet verified it extensivel There is one know case when the performance of the lowered code will be different. JAX programs use a [stateless -deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) +deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. This primitive is at the moment lowered to a soup of tf.bitwise operations, which has a clear performance penalty. We plan to look into using the @@ -1589,7 +1589,7 @@ Applies to non-native serialization only. There are a number of cases when the TensorFlow ops that are used by the `jax2tf` are not supported by TensorFlow for the same data types as in JAX. There is an -[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[up-to-date list of unimplemented cases](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). If you try to lower and run in TensorFlow a program with partially supported primitives, you may see TensorFlow errors that @@ -1626,7 +1626,7 @@ the function to a SavedModel, knowing that upon restore the jax2tf-lowered code will be compiled. For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable` -in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). +in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). # Calling TensorFlow functions from JAX @@ -1704,7 +1704,7 @@ For a more elaborate example, including round-tripping from JAX to TensorFlow and back through a SavedModel, with support for custom gradients, see the test `test_round_trip_custom_grad_saved_model` -in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). +in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). All the metadata inserted by TF during tracing and compilation, e.g., source location information and op names, is carried through to the @@ -1901,7 +1901,7 @@ As of today, the tests are run using `tf_nightly==2.14.0.dev20230720`. To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support for CUDA. One must be mindful to install a version of CUDA that is compatible -with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and +with both [jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation) and [TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations). ## Updating the limitations documentation @@ -1913,9 +1913,9 @@ JAX primitive, data type, device type, and TensorFlow execution mode (`eager`, `graph`, or `compiled`). These limitations are also used to generate tables of limitations, e.g., - * [List of primitives not supported in JAX](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), + * [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), e.g., due to unimplemented cases in the XLA compiler, and - * [List of primitives not supported in jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), + * [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), e.g., due to unimplemented cases in TensorFlow. This list is incremental on top of the unsupported JAX primitives. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 037f8bbc2..baae52403 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -19,7 +19,7 @@ This module introduces the function :func:`call_tf` that allows JAX to call TensorFlow functions. For examples and details, see -https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. +https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. """ @@ -93,7 +93,7 @@ def call_tf( For an example and more details see the `README - `_. + `_. Args: callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow @@ -460,7 +460,7 @@ def _call_tf_abstract_eval( msg = ("call_tf cannot call functions whose output has dynamic shape. " f"Found output shapes: {concrete_function_flat_tf.output_shapes}. " "Consider using the `output_shape_dtype` argument to call_tf. " - "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" " for a discussion.") raise ValueError(msg) @@ -499,7 +499,7 @@ def _call_tf_lowering( msg = ( "call_tf works best with a TensorFlow function that does not capture " "variables or tensors from the context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " f"The following captures were found {concrete_function_flat_tf.captured_inputs}") logging.warning(msg) for inp in concrete_function_flat_tf.captured_inputs: @@ -544,7 +544,7 @@ def _call_tf_lowering( "\ncall_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes.\n" + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." + + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." + "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e @@ -557,7 +557,7 @@ def _call_tf_lowering( f"{res_shape}. call_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes. " + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") raise ValueError(msg) res_dtype = res_shape.numpy_dtype() diff --git a/jax/experimental/jax2tf/examples/README.md b/jax/experimental/jax2tf/examples/README.md index b049798e7..8869a226b 100644 --- a/jax/experimental/jax2tf/examples/README.md +++ b/jax/experimental/jax2tf/examples/README.md @@ -4,7 +4,7 @@ jax2tf Examples Link: go/jax2tf-examples. This directory contains a number of examples of using the -[jax2tf converter](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) to: +[jax2tf converter](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) to: * save SavedModel from trained MNIST models, using both Flax and pure JAX. * reuse the feature-extractor part of the trained MNIST model @@ -19,12 +19,12 @@ You can also find usage examples in other projects: The functions generated by `jax2tf.convert` are standard TensorFlow functions and you can save them in a SavedModel using standard TensorFlow code, as shown -in the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). +in the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). This decoupling of jax2tf from SavedModel is important, because it **allows the user to have full control over what metadata is saved in the SavedModel**. As an example, we provide the function `convert_and_save_model` -(see [saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) +(see [saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) For serious uses, you will probably want to copy and expand this function as needed. @@ -65,7 +65,7 @@ If you are using Flax, then the recipe to obtain this pair is as follows: ``` You can see in -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) how this can be done for two implementations of MNIST, one using pure JAX (`PureJaxMNIST`) and a CNN one using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly, @@ -91,7 +91,7 @@ embed all parameters in the graph: ``` (The MNIST Flax examples from -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) normally has a GraphDef of 150k and a variables section of 3Mb. If we embed the parameters as constants in the GraphDef as shown above, the variables section becomes empty and the GraphDef becomes 13Mb. This embedding may allow @@ -112,7 +112,7 @@ If you are using Haiku, then the recipe is along these lines: Once you have the model in this form, you can use the `saved_model_lib.save_model` function from -[saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) +[saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) to generate the SavedModel. There is very little in that function that is specific to jax2tf. The goal of jax2tf is to convert JAX functions into functions that behave as if they had been written with TensorFlow. @@ -120,7 +120,7 @@ Therefore, if you are familiar with how to generate SavedModel, you can most likely just use your own code for this. The file -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) is an executable that shows how to perform the following sequence of steps: @@ -147,9 +147,9 @@ batch sizes: 1, 16, 128. You can see this in the dumped SavedModel. The SavedModel produced by the example in `saved_model_main.py` already implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models). The executable -[keras_reuse_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) +[keras_reuse_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) extends -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) with code to include a jax2tf SavedModel into a larger TensorFlow Keras model. @@ -174,7 +174,7 @@ In particular, you can select the Flax MNIST model: `--model=mnist_flax`. It is also possible to use jax2tf-generated SavedModel with TensorFlow serving. At the moment, the open-source TensorFlow model server is missing XLA support, but the Google version can be used, as shown in the -[serving examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). +[serving examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). # Using jax2tf with TensorFlow Lite and TensorFlow JavaScript @@ -186,6 +186,6 @@ can pass the `enable_xla=False` parameter to `jax2tf.convert` to direct `jax2tf` to avoid problematic ops. This will increase the coverage, and in fact most, but not all, Flax examples can be converted this way. -Check out the [MNIST TensorFlow Lite](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) +Check out the [MNIST TensorFlow Lite](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) and the -[Quickdraw TensorFlow.js example](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). +[Quickdraw TensorFlow.js example](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). diff --git a/jax/experimental/jax2tf/examples/serving/README.md b/jax/experimental/jax2tf/examples/serving/README.md index 0d8f49e45..299923109 100644 --- a/jax/experimental/jax2tf/examples/serving/README.md +++ b/jax/experimental/jax2tf/examples/serving/README.md @@ -2,7 +2,7 @@ Using jax2tf with TensorFlow serving ==================================== This is a supplement to the -[examples/README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md) +[examples/README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md) with example code and instructions for using `jax2tf` with the open source TensorFlow model server. Specific instructions for Google-internal versions of model server are in the `internal` subdirectory. @@ -15,16 +15,16 @@ SavedModel**. The only difference in the SavedModel produced with jax2tf is that the function graphs may contain -[XLA TF ops](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) +[XLA TF ops](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) that require enabling CPU/GPU XLA for execution in the model server. This is achieved using a command-line flag. There are no other differences compared to using SavedModel produced by TensorFlow. This serving example uses -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) for saving the SavedModel and adds code specific to interacting with the model server: -[model_server_request.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). +[model_server_request.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). 0. *Set up JAX and TensorFlow serving*. @@ -36,7 +36,7 @@ We also need to install TensorFlow for the `jax2tf` feature and the rest of this We use the `tf_nightly` package to get an up-to-date version. ```shell - git clone https://github.com/google/jax + git clone https://github.com/jax-ml/jax JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples pip install -e jax pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/README.md b/jax/experimental/jax2tf/examples/tflite/mnist/README.md index 9c889e647..f39bd9c7e 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/README.md +++ b/jax/experimental/jax2tf/examples/tflite/mnist/README.md @@ -65,7 +65,7 @@ TensorFlow ops that are only available with the XLA compiler, and which are not understood (yet) by the TFLite converter to be used below. -Check out [more details about this limitation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), +Check out [more details about this limitation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), including to which JAX primitives it applies. ### Convert the trained model to the TF Lite format diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md b/jax/experimental/jax2tf/g3doc/convert_models_results.md index 545f1faee..24e2539a3 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md @@ -48,13 +48,13 @@ details on the different converters. ## `flax/actor_critic_[(_, 4*b, 4*b, _)]` ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -62,13 +62,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -78,13 +78,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -94,13 +94,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -122,13 +122,13 @@ RuntimeError('third_party/tensorflow/lite/kernels/concatenation.cc:159 t->dims-> ## `flax/bilstm_[(b, _), (_,)]` ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -141,7 +141,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -156,7 +156,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -171,7 +171,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -180,13 +180,13 @@ for more details. ## `flax/bilstm_[(_, _), (b,)]` ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -199,7 +199,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -214,7 +214,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -229,7 +229,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -238,13 +238,13 @@ for more details. ## `flax/cnn_[(_, b, b, _)]` ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -253,13 +253,13 @@ InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_ Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -267,7 +267,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -278,13 +278,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -292,7 +292,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -303,13 +303,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -317,7 +317,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -395,13 +395,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 8 but expected 1 for dime ## `flax/resnet50_[(_, 4*b, 4*b, _)]` ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -409,13 +409,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -425,13 +425,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -441,13 +441,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -613,13 +613,13 @@ IndexError('Cannot use NumPy slice indexing on an array dimension whose size is ## `flax/lm1b_[(b, _)]` ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -632,7 +632,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -647,7 +647,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -662,7 +662,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -684,13 +684,13 @@ ValueError('Cannot set tensor: Dimension mismatch. Got 2 but expected 1 for dime ## `flax/wmt_[(b, _), (b, _)]` ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -703,7 +703,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -718,7 +718,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -733,7 +733,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -798,14 +798,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template index b54c57503..54e1d2135 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template @@ -29,14 +29,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md index 457dc998a..24a1d62ee 100644 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md @@ -1,6 +1,6 @@ # jax2tf Limitations for `enable_xla=False` -*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)* +*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)* For most JAX primitives there is a natural TF op that fits the needed semantics (e.g., `jax.lax.abs` is equivalent to `tf.abs`). However, there are a number of diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md index dabbcca4d..b36b004a9 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template index bf5dc41d8..219802f53 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). {{tf_error_table}} diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 5ecde602c..310cbaab6 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -591,7 +591,7 @@ def _padding_reduce_window(operand, operand_shape, computation_name, padding_type = pads_to_padtype(operand_shape, window_dimensions, window_strides, padding) - # https://github.com/google/jax/issues/11874. + # https://github.com/jax-ml/jax/issues/11874. needs_manual_padding = ( padding_type == "SAME" and computation_name == "add" and window_dimensions != [1] * len(operand_shape)) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 24dee390f..8a90c491e 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -198,7 +198,7 @@ class _ThreadLocalState(threading.local): # A cache for the tf.convert_to_tensor for constants. We try to preserve # sharing for constants, to enable tf.Graph to take advantage of it. - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. self.constant_cache = None # None means that we don't use a cache. We # may be outside a conversion scope. @@ -249,7 +249,7 @@ def convert(fun_jax: Callable, """Allows calling a JAX function from a TensorFlow program. See - [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for more details about usage and common problems. Args: @@ -291,12 +291,12 @@ def convert(fun_jax: Callable, polymorphic_shapes are only supported for positional arguments; shape polymorphism is not supported for keyword arguments. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`. - See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode TensorFlow AD is supported for the output TensorFlow function, and the @@ -3536,7 +3536,7 @@ def _shard_value(val: TfVal, if tf_context.executing_eagerly(): raise ValueError( "A jit function with sharded arguments or results must be used under a `tf.function` context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( val, use_sharding_op=True) diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 2760efea8..e10c3fbfd 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -304,7 +304,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False) def test_with_var_different_shape(self): - # See https://github.com/google/jax/issues/6050 + # See https://github.com/jax-ml/jax/issues/6050 v = tf.Variable((4., 2.), dtype=tf.float32) def tf_func(x): @@ -428,7 +428,7 @@ class CallTfTest(tf_test_util.JaxToTfTestCase): self.assertAllClose(g_jax, g_tf) def test_grad_int_argument(self): - # Similar to https://github.com/google/jax/issues/6975 + # Similar to https://github.com/jax-ml/jax/issues/6975 # state is a pytree that contains an integer and a boolean. # The function returns an integer and a boolean. def f(param, state, x): diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 896d0436e..c3b9e96dc 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1066,9 +1066,9 @@ class Jax2TfLimitation(test_harnesses.Limitation): @classmethod def qr(cls, harness: test_harnesses.Harness): - # See https://github.com/google/jax/pull/3775#issuecomment-659407824; + # See https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824; # # jit_compile=True breaks for complex types. - # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. + # TODO: see https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824. # - for now, the performance of the HLO QR implementation called when # compiling with TF is expected to have worse performance than the # custom calls made in JAX. diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 64d461fe9..ef7a5ee2c 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -595,7 +595,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): @jtu.sample_product(with_function=[False, True]) def test_gradients_int_argument(self, with_function=False): - # https://github.com/google/jax/issues/6975 + # https://github.com/jax-ml/jax/issues/6975 # Also issue #6975. # An expanded version of test_gradients_unused_argument state = dict( @@ -969,7 +969,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) def test_bfloat16_constant(self): - # Re: https://github.com/google/jax/issues/3942 + # Re: https://github.com/jax-ml/jax/issues/3942 def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. @@ -990,7 +990,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): def test_shared_constants(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const = np.random.uniform(size=256).astype(np.float32) # A shared constant @@ -1002,7 +1002,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): def test_shared_constants_under_cond(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1018,7 +1018,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.assertLen(f2_consts, len(f1_consts)) def test_shared_constants_under_scan(self): - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1092,7 +1092,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): @jtu.sample_product(with_function=[False, True]) def test_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 def f_jax(*, x): return jnp.sum(x) f_tf = jax2tf.convert(f_jax) @@ -1104,7 +1104,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): @jtu.sample_product(with_function=[False, True]) def test_grad_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 x = (np.zeros(3, dtype=np.float32), np.zeros(4, dtype=np.float32)) def f_jax(*, x=(1., 2.)): diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 485fa6e58..78c24b7ea 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -30,7 +30,7 @@ in Tensorflow errors (for some devices and compilation modes). These limitations are captured as jax2tf_limitations.Jax2TfLimitation objects. From the limitations objects, we generate a -[report](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[report](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). The report has instructions for how to re-generate it. If a harness run fails with error, and a limitation that matches the device diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 8b71de7db..bc19915d1 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -175,7 +175,7 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase): def test_save_grad_integers(self): - # https://github.com/google/jax/issues/7123 + # https://github.com/jax-ml/jax/issues/7123 # In the end this is a test that does not involve JAX at all batch_size = 5 state = np.array([1], dtype=np.int32) # Works if float32 diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 2475a062f..a9ee17762 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -933,7 +933,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): kwargs=[dict(with_function=v) for v in [True, False]] ) def test_grad_int(self, with_function=False): - # https://github.com/google/jax/issues/7093 + # https://github.com/jax-ml/jax/issues/7093 # Also issue #6975. x_shape = (2, 3, 4) xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape) @@ -2172,7 +2172,7 @@ _POLY_SHAPE_TEST_HARNESSES = [ (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) PolyHarness("reduce_window", "add_monoid_strides_window_size=static", diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index d6349b487..9009c1586 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -437,7 +437,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): def test_grad_sharding_different_mesh(self): # Convert with two similar meshes, the only difference being # the order of the devices. grad should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 devices = jax.local_devices()[:2] if len(devices) < 2: raise unittest.SkipTest("Test requires 2 local devices") diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 1ed6183b1..ffe362974 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -45,11 +45,11 @@ r"""Jet is an experimental module for higher-order automatic differentiation and can thus be used for high-order automatic differentiation of :math:`f`. Details are explained in - `these notes `__. + `these notes `__. Note: Help improve :func:`jet` by contributing - `outstanding primitive rules `__. + `outstanding primitive rules `__. """ from collections.abc import Callable diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index fabd45ca0..f19401525 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -639,7 +639,7 @@ def _rule_missing(prim: core.Primitive, *_, **__): raise NotImplementedError( f"No replication rule for {prim}. As a workaround, pass the " "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/google/jax/issues") + "issue at https://github.com/jax-ml/jax/issues") # Lowering @@ -845,20 +845,20 @@ class ShardMapTrace(core.Trace): f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " "yet supported. Put a `jax.jit` around the `shard_map`-decorated " "function, and open a feature request at " - "https://github.com/google/jax/issues !") + "https://github.com/jax-ml/jax/issues !") def process_map(self, map_primitive, fun, tracers, params): raise NotImplementedError( "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + "a feature request at https://github.com/jax-ml/jax/issues !") def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. if symbolic_zeros: msg = ("custom_jvp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -876,7 +876,7 @@ class ShardMapTrace(core.Trace): if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -1042,7 +1042,7 @@ def _standard_check(prim, mesh, *in_rep, **__): if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return in_rep_[0] if in_rep_ else None @@ -1057,7 +1057,7 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): raise Exception(f"Collective {prim} must be applied to a device-varying " f"replication type, but got {x_rep} for collective acting " f"over axis name {axis_name}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return x_rep @@ -1114,7 +1114,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups): raise Exception("Collective psum must be applied to a device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r | set(axes) for r in in_rep] @@ -1129,7 +1129,7 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): "non-device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r - set(axes) for r in in_rep] @@ -1216,7 +1216,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): if not carry_rep_in == carry_rep_out: raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. Please open an " - "issue at https://github.com/google/jax/issues, and as a " + "issue at https://github.com/jax-ml/jax/issues, and as a " "temporary workaround pass the check_rep=False argument to " "shard_map") return out_rep @@ -1267,7 +1267,7 @@ def _custom_vjp_call_jaxpr_rewrite( mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and as" + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" " a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1303,7 +1303,7 @@ def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): assert in_rep if not in_rep_[:-1] == in_rep_[1:]: msg = ("shard_map check_rep rewrite failed. Please open an issue at " - "https://github.com/google/jax/issues and as a workaround pass the " + "https://github.com/jax-ml/jax/issues and as a workaround pass the " "check_rep=False argument to shard_map") raise Exception(msg) return [in_rep_[0]] * len(jaxprs.solve.out_avals) @@ -1878,7 +1878,7 @@ class RewriteTrace(core.Trace): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1899,7 +1899,7 @@ class RewriteTrace(core.Trace): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index 8ab8cd887..f388cd527 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -189,7 +189,7 @@ To fit the same model on sparse data, we can apply the :func:`sparsify` transfor """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.sparse.ad import ( jacfwd as jacfwd, diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index d200577c2..9f2f0f69b 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -105,7 +105,7 @@ def _bcoo_set_nse(mat: BCOO, nse: int) -> BCOO: unique_indices=mat.unique_indices) # TODO(jakevdp) this can be problematic when used with autodiff; see -# https://github.com/google/jax/issues/10163. Should this be a primitive? +# https://github.com/jax-ml/jax/issues/10163. Should this be a primitive? # Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument. def bcoo_eliminate_zeros(mat: BCOO, nse: int | None = None) -> BCOO: data, indices, shape = mat.data, mat.indices, mat.shape @@ -1140,7 +1140,7 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :]) out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1]) # Note: we do not eliminate zeros here, because it can cause issues with autodiff. - # See https://github.com/google/jax/issues/10163. + # See https://github.com/jax-ml/jax/issues/10163. return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse) @bcoo_spdot_general_p.def_impl @@ -1537,7 +1537,7 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype) data_dot_out = data_out # This check is because scatter-add on zero-sized arrays has poorly defined - # semantics; see https://github.com/google/jax/issues/13656. + # semantics; see https://github.com/jax-ml/jax/issues/13656. if data_out.size: permute = lambda x, i, y: x.at[i].add(y, mode='drop') else: diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 66fd149d7..b1e471133 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.api import ( clear_backends as clear_backends, diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py index 2732b1984..9f1632fb3 100644 --- a/jax/extend/core/__init__.py +++ b/jax/extend/core/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.abstract_arrays import ( array_types as array_types diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index e37287180..feb70b517 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ad_util import stop_gradient_p as stop_gradient_p diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index 3a26030c1..b2d480adc 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.ffi import ( ffi_call as ffi_call, diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py index d5fb9245a..715dfd435 100644 --- a/jax/extend/ifrt_programs.py +++ b/jax/extend/ifrt_programs.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.lib import xla_extension as _xe diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 1706f8c8c..74c52dddb 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/extend/random.py b/jax/extend/random.py index a055c7575..d6e0cfaab 100644 --- a/jax/extend/random.py +++ b/jax/extend/random.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.random import ( define_prng_impl as define_prng_impl, diff --git a/jax/extend/source_info_util.py b/jax/extend/source_info_util.py index f74df2cab..f031dabef 100644 --- a/jax/extend/source_info_util.py +++ b/jax/extend/source_info_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.source_info_util import ( NameStack as NameStack, diff --git a/jax/image/__init__.py b/jax/image/__init__.py index c7ee8ffa9..993395f50 100644 --- a/jax/image/__init__.py +++ b/jax/image/__init__.py @@ -21,7 +21,7 @@ JAX, such as `PIX`_. """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.image.scale import ( resize as resize, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6bfc3473f..28816afb0 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from __future__ import annotations diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 98fad903c..607fc6fa5 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.interpreters.batching import ( Array as Array, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index e2bcd5de9..293bd0244 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 230aacb76..496d03261 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -15,7 +15,7 @@ """Common functions for neural network libraries.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import tanh as tanh from jax.nn import initializers as initializers diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index 6c73356ce..019f3e179 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -18,7 +18,7 @@ used in Keras and Sonnet. """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.nn.initializers import ( constant as constant, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index da79f7859..20c37c559 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import fft as fft from jax.numpy import linalg as linalg diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 24a271487..c268c2d65 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.fft import ( ifft as ifft, diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index c342fde0a..05b5ff6db 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.linalg import ( cholesky as cholesky, diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py index c61a44fd1..5e1f3d682 100644 --- a/jax/ops/__init__.py +++ b/jax/ops/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ops.scatter import ( segment_sum as segment_sum, diff --git a/jax/profiler.py b/jax/profiler.py index 01ea6e222..77157dc02 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.profiler import ( StepTraceAnnotation as StepTraceAnnotation, diff --git a/jax/random.py b/jax/random.py index 5c2eaf81f..29a625389 100644 --- a/jax/random.py +++ b/jax/random.py @@ -103,7 +103,7 @@ Design and background **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ -See `docs/jep/263-prng.md `_ +See `docs/jep/263-prng.md `_ for more details. To summarize, among other requirements, the JAX PRNG aims to: @@ -201,7 +201,7 @@ https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.random import ( PRNGKey as PRNGKey, diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index c0746910d..cf44b6e17 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from typing import TYPE_CHECKING diff --git a/jax/scipy/cluster/__init__.py b/jax/scipy/cluster/__init__.py index 5a01ea0ee..ea35467f6 100644 --- a/jax/scipy/cluster/__init__.py +++ b/jax/scipy/cluster/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.cluster import vq as vq diff --git a/jax/scipy/cluster/vq.py b/jax/scipy/cluster/vq.py index 3a46ce52f..eeeabb722 100644 --- a/jax/scipy/cluster/vq.py +++ b/jax/scipy/cluster/vq.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.cluster.vq import vq as vq diff --git a/jax/scipy/fft.py b/jax/scipy/fft.py index b8005b72f..d3c2de099 100644 --- a/jax/scipy/fft.py +++ b/jax/scipy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.fft import ( dct as dct, diff --git a/jax/scipy/integrate.py b/jax/scipy/integrate.py index b19aa054c..3335f12fd 100644 --- a/jax/scipy/integrate.py +++ b/jax/scipy/integrate.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.integrate import ( trapezoid as trapezoid diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 059f927ec..64bc05440 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.linalg import ( block_diag as block_diag, diff --git a/jax/scipy/ndimage.py b/jax/scipy/ndimage.py index 2f63e2366..81d7e3ef2 100644 --- a/jax/scipy/ndimage.py +++ b/jax/scipy/ndimage.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.ndimage import ( map_coordinates as map_coordinates, diff --git a/jax/scipy/optimize/__init__.py b/jax/scipy/optimize/__init__.py index 8a2248733..f1c7167c3 100644 --- a/jax/scipy/optimize/__init__.py +++ b/jax/scipy/optimize/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.optimize.minimize import ( minimize as minimize, diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py index 7e39da3f9..c46b2fce3 100644 --- a/jax/scipy/signal.py +++ b/jax/scipy/signal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.signal import ( fftconvolve as fftconvolve, diff --git a/jax/scipy/sparse/__init__.py b/jax/scipy/sparse/__init__.py index f2e305e82..2968a26b4 100644 --- a/jax/scipy/sparse/__init__.py +++ b/jax/scipy/sparse/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.sparse import linalg as linalg diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py index d475ddff8..d22e5ec43 100644 --- a/jax/scipy/sparse/linalg.py +++ b/jax/scipy/sparse/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.sparse.linalg import ( cg as cg, diff --git a/jax/scipy/spatial/transform.py b/jax/scipy/spatial/transform.py index 4b532d5f3..63e8dd373 100644 --- a/jax/scipy/spatial/transform.py +++ b/jax/scipy/spatial/transform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.spatial.transform import ( Rotation as Rotation, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 5d72339ea..431617d36 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.special import ( bernoulli as bernoulli, diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 7aa73f7b5..7719945f2 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.stats import bernoulli as bernoulli from jax.scipy.stats import beta as beta diff --git a/jax/scipy/stats/bernoulli.py b/jax/scipy/stats/bernoulli.py index 46c1e4825..1623f7113 100644 --- a/jax/scipy/stats/bernoulli.py +++ b/jax/scipy/stats/bernoulli.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.bernoulli import ( logpmf as logpmf, diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 5c57dda6b..2a4e7f12f 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.beta import ( cdf as cdf, diff --git a/jax/scipy/stats/betabinom.py b/jax/scipy/stats/betabinom.py index 48f955d9e..f8adf68f4 100644 --- a/jax/scipy/stats/betabinom.py +++ b/jax/scipy/stats/betabinom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.betabinom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index 4ff79f5f9..34c9972d0 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.cauchy import ( cdf as cdf, diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index e17a2e331..47fcb76db 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.chi2 import ( cdf as cdf, diff --git a/jax/scipy/stats/dirichlet.py b/jax/scipy/stats/dirichlet.py index 9368defc8..22e9b3cc1 100644 --- a/jax/scipy/stats/dirichlet.py +++ b/jax/scipy/stats/dirichlet.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.dirichlet import ( logpdf as logpdf, diff --git a/jax/scipy/stats/expon.py b/jax/scipy/stats/expon.py index 1ec50ac3f..8f5c0a068 100644 --- a/jax/scipy/stats/expon.py +++ b/jax/scipy/stats/expon.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.expon import ( logpdf as logpdf, diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 8efecafed..531a1e300 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gamma import ( cdf as cdf, diff --git a/jax/scipy/stats/gennorm.py b/jax/scipy/stats/gennorm.py index c903ff606..c760575fa 100644 --- a/jax/scipy/stats/gennorm.py +++ b/jax/scipy/stats/gennorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gennorm import ( cdf as cdf, diff --git a/jax/scipy/stats/geom.py b/jax/scipy/stats/geom.py index 75f917fc2..eb12dbb5a 100644 --- a/jax/scipy/stats/geom.py +++ b/jax/scipy/stats/geom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.geom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/laplace.py b/jax/scipy/stats/laplace.py index 3abe62020..8f182804d 100644 --- a/jax/scipy/stats/laplace.py +++ b/jax/scipy/stats/laplace.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.laplace import ( cdf as cdf, diff --git a/jax/scipy/stats/logistic.py b/jax/scipy/stats/logistic.py index c25a06856..7cdb26fb1 100644 --- a/jax/scipy/stats/logistic.py +++ b/jax/scipy/stats/logistic.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.logistic import ( cdf as cdf, diff --git a/jax/scipy/stats/multinomial.py b/jax/scipy/stats/multinomial.py index 723d1a645..392ca4055 100644 --- a/jax/scipy/stats/multinomial.py +++ b/jax/scipy/stats/multinomial.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multinomial import ( logpmf as logpmf, diff --git a/jax/scipy/stats/multivariate_normal.py b/jax/scipy/stats/multivariate_normal.py index 95ad355c7..94c4cc50a 100644 --- a/jax/scipy/stats/multivariate_normal.py +++ b/jax/scipy/stats/multivariate_normal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multivariate_normal import ( logpdf as logpdf, diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index f47765adf..563e40ce0 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.norm import ( cdf as cdf, diff --git a/jax/scipy/stats/pareto.py b/jax/scipy/stats/pareto.py index bf27ea205..5e46fd5d0 100644 --- a/jax/scipy/stats/pareto.py +++ b/jax/scipy/stats/pareto.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.pareto import ( logpdf as logpdf, diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py index 2e857bc15..5fcde905f 100644 --- a/jax/scipy/stats/poisson.py +++ b/jax/scipy/stats/poisson.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.poisson import ( logpmf as logpmf, diff --git a/jax/scipy/stats/t.py b/jax/scipy/stats/t.py index d92fcab97..694bcb0b0 100644 --- a/jax/scipy/stats/t.py +++ b/jax/scipy/stats/t.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.t import ( logpdf as logpdf, diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py index 28d5533b0..cb8e8958d 100644 --- a/jax/scipy/stats/truncnorm.py +++ b/jax/scipy/stats/truncnorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.truncnorm import ( cdf as cdf, diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py index d0a06c673..fa754125f 100644 --- a/jax/scipy/stats/uniform.py +++ b/jax/scipy/stats/uniform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.uniform import ( logpdf as logpdf, diff --git a/jax/scipy/stats/vonmises.py b/jax/scipy/stats/vonmises.py index 8de7fba47..6572e43f6 100644 --- a/jax/scipy/stats/vonmises.py +++ b/jax/scipy/stats/vonmises.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.vonmises import ( logpdf as logpdf, diff --git a/jax/scipy/stats/wrapcauchy.py b/jax/scipy/stats/wrapcauchy.py index 6e2420c5a..eb1768f0c 100644 --- a/jax/scipy/stats/wrapcauchy.py +++ b/jax/scipy/stats/wrapcauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.wrapcauchy import ( logpdf as logpdf, diff --git a/jax/sharding.py b/jax/sharding.py index ea92e9d17..26c542292 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( diff --git a/jax/stages.py b/jax/stages.py index 6ffc3144c..3e7e461c3 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -22,7 +22,7 @@ For more, see the `AOT walkthrough as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.stages import ( Compiled as Compiled, diff --git a/jax/test_util.py b/jax/test_util.py index 5d4f5ed0a..176f4521b 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.public_test_util import ( check_grads as check_grads, diff --git a/jax/tree_util.py b/jax/tree_util.py index b4854c7df..956d79b9b 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -36,7 +36,7 @@ for examples. """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.tree_util import ( DictKey as DictKey, diff --git a/jax/util.py b/jax/util.py index c1259e9c5..8071f77df 100644 --- a/jax/util.py +++ b/jax/util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.util import ( HashableFunction as HashableFunction, diff --git a/jax/version.py b/jax/version.py index c6e4b3ad1..6c64d75b9 100644 --- a/jax/version.py +++ b/jax/version.py @@ -115,7 +115,7 @@ def _get_cmdclass(pkg_source_path): # missing or outdated. Because _write_version(...) modifies the copy of # this file in the build tree, re-building from the same JAX directory # would not automatically re-copy a clean version, and _write_version - # would fail without this deletion. See google/jax#18252. + # would fail without this deletion. See jax-ml/jax#18252. if os.path.isfile(this_file_in_build_dir): os.unlink(this_file_in_build_dir) super().run() diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index 9ccf3bf44..a84a6b34e 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -51,7 +51,7 @@ setup( packages=[package_name], python_requires=">=3.9", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jaxlib/README.md b/jaxlib/README.md index 74e1e5b36..cee5f246d 100644 --- a/jaxlib/README.md +++ b/jaxlib/README.md @@ -4,4 +4,4 @@ jaxlib is the support library for JAX. While JAX itself is a pure Python package jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to main -JAX README: https://github.com/google/jax/. +JAX README: https://github.com/jax-ml/jax/. diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 215313f9b..dea9503c7 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -66,7 +66,7 @@ setup( 'numpy>=1.24', 'ml_dtypes>=0.2.0', ], - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 6305b0c24..52a17c451 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. - https://github.com/google/jax/issues/3867 + https://github.com/jax-ml/jax/issues/3867 This check makes sure we don't release wheels that have this dependency. """ diff --git a/setup.py b/setup.py index 81eef74e0..e807ff3b0 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ setup( f"jax-cuda12-plugin=={_current_jaxlib_version}", ], }, - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/tests/api_test.py b/tests/api_test.py index 1deaa4c08..adce61d65 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -467,7 +467,7 @@ class JitTest(jtu.BufferDonationTestCase): ("argnames", "donate_argnames", ('array',)), ) def test_jnp_array_copy(self, argnum_type, argnum_val): - # https://github.com/google/jax/issues/3412 + # https://github.com/jax-ml/jax/issues/3412 @partial(jit, **{argnum_type: argnum_val}) def _test(array): @@ -905,7 +905,7 @@ class JitTest(jtu.BufferDonationTestCase): @jax.legacy_prng_key('allow') def test_omnistaging(self): - # See https://github.com/google/jax/issues/5206 + # See https://github.com/jax-ml/jax/issues/5206 # TODO(frostig): remove `wrap` once we always enable_custom_prng def wrap(arr): @@ -1409,7 +1409,7 @@ class JitTest(jtu.BufferDonationTestCase): f({E.A: 1.0, E.B: 2.0}) def test_jit_static_argnums_requires_type_equality(self): - # See: https://github.com/google/jax/pull/9311 + # See: https://github.com/jax-ml/jax/pull/9311 @partial(jit, static_argnums=(0,)) def f(k): assert python_should_be_executing @@ -1424,7 +1424,7 @@ class JitTest(jtu.BufferDonationTestCase): self.assertEqual(x, f(x)) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = lambda: lax.psum(1, "i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() @@ -1437,7 +1437,7 @@ class JitTest(jtu.BufferDonationTestCase): self.assertEqual(ans, expected) def test_caches_dont_depend_on_unnamed_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) expected = f() with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 @@ -1446,7 +1446,7 @@ class JitTest(jtu.BufferDonationTestCase): self.assertArraysAllClose(ans, expected, check_dtypes=True) def test_cache_key_defaults(self): - # https://github.com/google/jax/discussions/11875 + # https://github.com/jax-ml/jax/discussions/11875 f = jit(lambda x: (x ** 2).sum()) self.assertEqual(f._cache_size(), 0) x = jnp.arange(5.0) @@ -1455,7 +1455,7 @@ class JitTest(jtu.BufferDonationTestCase): self.assertEqual(f._cache_size(), 1) def test_jit_nan_times_zero(self): - # https://github.com/google/jax/issues/4780 + # https://github.com/jax-ml/jax/issues/4780 def f(x): return 1 + x * 0 self.assertAllClose(f(np.nan), np.nan) @@ -2163,7 +2163,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(aux, [4.**2, 4.]) def test_grad_and_aux_no_tracers(self): - # see https://github.com/google/jax/issues/1950 + # see https://github.com/jax-ml/jax/issues/1950 def f(x): aux = dict(identity=x, p1=x+1) return x ** 2, aux @@ -2322,7 +2322,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(actual, expected) def test_linear_transpose_dce(self): - # https://github.com/google/jax/issues/15660 + # https://github.com/jax-ml/jax/issues/15660 f = jit(lambda x: (2 * x, x > 0)) g = lambda x: f(x)[0] api.linear_transpose(g, 1.)(1.) @@ -2389,7 +2389,7 @@ class APITest(jtu.JaxTestCase): self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j)) def test_nonholomorphic_jacrev(self): - # code based on https://github.com/google/jax/issues/603 + # code based on https://github.com/jax-ml/jax/issues/603 zs = 0.5j * np.arange(5) + np.arange(5) def f(z): @@ -2401,8 +2401,8 @@ class APITest(jtu.JaxTestCase): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. def test_heterogeneous_jacfwd(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2421,8 +2421,8 @@ class APITest(jtu.JaxTestCase): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. def test_heterogeneous_jacrev(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2440,7 +2440,7 @@ class APITest(jtu.JaxTestCase): jtu.check_eq(actual, desired) def test_heterogeneous_grad(self): - # See https://github.com/google/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7157 x = np.array(1.0+1j) y = np.array(2.0) a = (x, y) @@ -2512,7 +2512,7 @@ class APITest(jtu.JaxTestCase): self.assertIsNone(y()) def test_namedtuple_transparency(self): - # See https://github.com/google/jax/issues/446 + # See https://github.com/jax-ml/jax/issues/446 Point = collections.namedtuple("Point", ["x", "y"]) def f(pt): @@ -2528,7 +2528,7 @@ class APITest(jtu.JaxTestCase): self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False) def test_namedtuple_subclass_transparency(self): - # See https://github.com/google/jax/issues/806 + # See https://github.com/jax-ml/jax/issues/806 Point = collections.namedtuple("Point", ["x", "y"]) class ZeroPoint(Point): @@ -2705,7 +2705,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(out_shape.shape, (3, 5)) def test_eval_shape_duck_typing2(self): - # https://github.com/google/jax/issues/5683 + # https://github.com/jax-ml/jax/issues/5683 class EasyDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2980,7 +2980,7 @@ class APITest(jtu.JaxTestCase): ])) def test_vmap_in_axes_list(self): - # https://github.com/google/jax/issues/2367 + # https://github.com/jax-ml/jax/issues/2367 dictionary = {'a': 5., 'b': jnp.ones(2)} x = jnp.zeros(3) y = jnp.arange(3.) @@ -2993,7 +2993,7 @@ class APITest(jtu.JaxTestCase): self.assertAllClose(out1, out2) def test_vmap_in_axes_non_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( TypeError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding " @@ -3001,7 +3001,7 @@ class APITest(jtu.JaxTestCase): jax.vmap(lambda x: x['a'], in_axes={'a': 0}) def test_vmap_in_axes_wrong_length_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( ValueError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the " @@ -3009,7 +3009,7 @@ class APITest(jtu.JaxTestCase): jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))}) def test_vmap_in_axes_tree_prefix_error(self): - # https://github.com/google/jax/issues/795 + # https://github.com/jax-ml/jax/issues/795 value_tree = jnp.ones(3) self.assertRaisesRegex( ValueError, @@ -3030,14 +3030,14 @@ class APITest(jtu.JaxTestCase): api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.])) def test_vmap_unbatched_object_passthrough_issue_183(self): - # https://github.com/google/jax/issues/183 + # https://github.com/jax-ml/jax/issues/183 fun = lambda f, x: f(x) vfun = api.vmap(fun, (None, 0)) ans = vfun(lambda x: x + 1, jnp.arange(3)) self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False) def test_vmap_mismatched_keyword(self): - # https://github.com/google/jax/issues/10193 + # https://github.com/jax-ml/jax/issues/10193 @jax.vmap def f(x, y): return x + y @@ -3051,7 +3051,7 @@ class APITest(jtu.JaxTestCase): f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32')) def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): - # https://github.com/google/jax/issues/705 + # https://github.com/jax-ml/jax/issues/705 def h(a, b): return jnp.sum(a) + jnp.sum(b) @@ -3156,12 +3156,12 @@ class APITest(jtu.JaxTestCase): self.assertEqual(vfoo(tree).shape, (6, 2, 5)) def test_vmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3)) def test_pmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1)) @@ -3223,7 +3223,7 @@ class APITest(jtu.JaxTestCase): hash(rep) def test_grad_without_enough_args_error_message(self): - # https://github.com/google/jax/issues/1696 + # https://github.com/jax-ml/jax/issues/1696 def f(x, y): return x + y df = api.grad(f, argnums=0) self.assertRaisesRegex( @@ -3301,7 +3301,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(count[0], 0) # cache hits on both fwd and bwd def test_grad_does_not_unflatten_tree_with_none(self): - # https://github.com/google/jax/issues/7546 + # https://github.com/jax-ml/jax/issues/7546 class CustomNode(list): pass @@ -3370,7 +3370,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(count[0], 1) def test_arange_jit(self): - # see https://github.com/google/jax/issues/553 + # see https://github.com/jax-ml/jax/issues/553 def fun(x): r = jnp.arange(x.shape[0])[x] return r @@ -3496,7 +3496,7 @@ class APITest(jtu.JaxTestCase): _ = self._saved_tracer+1 def test_pmap_static_kwarg_error_message(self): - # https://github.com/google/jax/issues/3007 + # https://github.com/jax-ml/jax/issues/3007 def f(a, b): return a + b @@ -3650,7 +3650,7 @@ class APITest(jtu.JaxTestCase): g(1) def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 x = jnp.array([1., 2., 3.]) y = jnp.array([1., 2., 4.]) @@ -3673,7 +3673,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(aux, True) def test_linearize_aval_error(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 f = lambda x: x # these should not error @@ -3691,7 +3691,7 @@ class APITest(jtu.JaxTestCase): f_jvp(np.ones(2, np.int32)) def test_grad_of_token_consuming_primitive(self): - # https://github.com/google/jax/issues/5463 + # https://github.com/jax-ml/jax/issues/5463 tokentest_p = core.Primitive("tokentest") tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p)) tokentest_p.def_abstract_eval(lambda x, y: x) @@ -3823,7 +3823,7 @@ class APITest(jtu.JaxTestCase): f(3) def test_leak_checker_avoids_false_positive_custom_jvp(self): - # see https://github.com/google/jax/issues/5636 + # see https://github.com/jax-ml/jax/issues/5636 with jax.checking_leaks(): @jax.custom_jvp def t(y): @@ -3906,7 +3906,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(jnp.ones(1).devices(), system_default_devices) def test_dunder_jax_array(self): - # https://github.com/google/jax/pull/4725 + # https://github.com/jax-ml/jax/pull/4725 class AlexArray: def __init__(self, jax_val): @@ -3939,7 +3939,7 @@ class APITest(jtu.JaxTestCase): self.assertAllClose(np.array(((1, 1), (1, 1))), a2) def test_eval_shape_weak_type(self): - # https://github.com/google/jax/issues/23302 + # https://github.com/jax-ml/jax/issues/23302 arr = jax.numpy.array(1) with jtu.count_jit_tracing_cache_miss() as count: @@ -3980,7 +3980,7 @@ class APITest(jtu.JaxTestCase): f(a, a) # don't crash def test_constant_handler_mro(self): - # https://github.com/google/jax/issues/6129 + # https://github.com/jax-ml/jax/issues/6129 class Foo(enum.IntEnum): bar = 1 @@ -3997,7 +3997,7 @@ class APITest(jtu.JaxTestCase): {"testcase_name": f"{dtype.__name__}", "dtype": dtype} for dtype in jtu.dtypes.all]) def test_constant_handlers(self, dtype): - # https://github.com/google/jax/issues/9380 + # https://github.com/jax-ml/jax/issues/9380 @jax.jit def f(): return jnp.exp(dtype(0)) @@ -4135,7 +4135,7 @@ class APITest(jtu.JaxTestCase): jaxpr = api.make_jaxpr(f)(3) self.assertNotIn('pjit', str(jaxpr)) - # Repro for https://github.com/google/jax/issues/7229. + # Repro for https://github.com/jax-ml/jax/issues/7229. def test_compute_with_large_transfer(self): def f(x, delta): return x + jnp.asarray(delta, x.dtype) @@ -4193,7 +4193,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(actual, expected) def test_leaked_tracer_issue_7613(self): - # from https://github.com/google/jax/issues/7613 + # from https://github.com/jax-ml/jax/issues/7613 import numpy.random as npr def sigmoid(x): @@ -4211,7 +4211,7 @@ class APITest(jtu.JaxTestCase): _ = jax.grad(loss)(A, x) # doesn't crash def test_vmap_caching(self): - # https://github.com/google/jax/issues/7621 + # https://github.com/jax-ml/jax/issues/7621 f = lambda x: jnp.square(x).mean() jf = jax.jit(f) @@ -4299,7 +4299,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): - # https://github.com/google/jax/issues/12542 + # https://github.com/jax-ml/jax/issues/12542 @jax.jit def a(x): return () @@ -4344,7 +4344,7 @@ class APITest(jtu.JaxTestCase): b(8) # don't crash def test_vjp_multiple_arguments_error_message(self): - # https://github.com/google/jax/issues/13099 + # https://github.com/jax-ml/jax/issues/13099 def foo(x): return (x, x) _, f_vjp = jax.vjp(foo, 1.0) @@ -4376,7 +4376,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(jfoo.__module__, "jax") def test_inner_jit_function_retracing(self): - # https://github.com/google/jax/issues/7155 + # https://github.com/jax-ml/jax/issues/7155 inner_count = outer_count = 0 @jax.jit @@ -4403,7 +4403,7 @@ class APITest(jtu.JaxTestCase): self.assertEqual(outer_count, 1) def test_grad_conj_symbolic_zeros(self): - # https://github.com/google/jax/issues/15400 + # https://github.com/jax-ml/jax/issues/15400 f = lambda x: jax.jit(lambda x, y: (x, y))(x, jax.lax.conj(x))[0] out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) @@ -4555,7 +4555,7 @@ class APITest(jtu.JaxTestCase): self._CompileAndCheck(f, args_maker) def test_jvp_asarray_returns_array(self): - # https://github.com/google/jax/issues/15676 + # https://github.com/jax-ml/jax/issues/15676 p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,)) _check_instance(self, p) _check_instance(self, t) @@ -4716,7 +4716,7 @@ class APITest(jtu.JaxTestCase): f() def test_inline_return_twice(self): - # https://github.com/google/jax/issues/22944 + # https://github.com/jax-ml/jax/issues/22944 @jax.jit def add_one(x: int) -> int: return x + 1 @@ -5074,7 +5074,7 @@ class RematTest(jtu.JaxTestCase): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_no_redundant_flops(self, remat): - # see https://github.com/google/jax/pull/1749#issuecomment-558267584 + # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 @api.jit def g(x): @@ -5124,7 +5124,7 @@ class RematTest(jtu.JaxTestCase): ('_new', new_checkpoint), ]) def test_remat_symbolic_zeros(self, remat): - # code from https://github.com/google/jax/issues/1907 + # code from https://github.com/jax-ml/jax/issues/1907 key = jax.random.key(0) key, split = jax.random.split(key) @@ -5177,7 +5177,7 @@ class RematTest(jtu.JaxTestCase): ('_new', new_checkpoint), ]) def test_remat_nontrivial_env(self, remat): - # simplified from https://github.com/google/jax/issues/2030 + # simplified from https://github.com/jax-ml/jax/issues/2030 @remat def foo(state, dt=0.5, c=1): @@ -5211,7 +5211,7 @@ class RematTest(jtu.JaxTestCase): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit3(self, remat): - # https://github.com/google/jax/issues/2180 + # https://github.com/jax-ml/jax/issues/2180 def f(w, x): a = jnp.dot(x, w) b = jnp.einsum("btd,bTd->btT", a, a) @@ -5244,7 +5244,7 @@ class RematTest(jtu.JaxTestCase): ('_new', new_checkpoint), ]) def test_remat_scan2(self, remat): - # https://github.com/google/jax/issues/1963 + # https://github.com/jax-ml/jax/issues/1963 def scan_bug(x0): f = lambda x, _: (x + 1, None) @@ -5256,7 +5256,7 @@ class RematTest(jtu.JaxTestCase): jax.grad(scan_bug)(1.0) # doesn't crash def test_remat_jit_static_argnum_omnistaging(self): - # https://github.com/google/jax/issues/2833 + # https://github.com/jax-ml/jax/issues/2833 # NOTE(mattjj): after #3370, this test doesn't actually call remat... def named_call(f): def named_f(*args): @@ -5281,7 +5281,7 @@ class RematTest(jtu.JaxTestCase): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_eval_counter(self, remat): - # https://github.com/google/jax/issues/2737 + # https://github.com/jax-ml/jax/issues/2737 add_one_p = core.Primitive('add_one') add_one = add_one_p.bind @@ -5665,7 +5665,7 @@ class RematTest(jtu.JaxTestCase): # The old implementation of remat worked by data dependence, and so # (potentially large) constants would not be rematerialized and could be # wastefully instantiated. This test checks that the newer remat - # implementation avoids that. See https://github.com/google/jax/pull/8191. + # implementation avoids that. See https://github.com/jax-ml/jax/pull/8191. # no residuals from constants created inside jnp.einsum @partial(new_checkpoint, policy=lambda *_, **__: False) @@ -5790,7 +5790,7 @@ class RematTest(jtu.JaxTestCase): _ = jax.grad(f)(3.) # doesn't crash def test_linearize_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_lin = jax.linearize(identity, 1.) with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 @@ -5799,7 +5799,7 @@ class RematTest(jtu.JaxTestCase): self.assertEqual(count[0], 1) # cached after first execution def test_vjp_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_vjp = jax.vjp(identity, 1.) with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841 @@ -6928,7 +6928,7 @@ class CustomJVPTest(jtu.JaxTestCase): check_dtypes=False) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_jvp def my_fun(x, y, c=1.): return c * (x + y) @@ -7209,7 +7209,7 @@ class CustomJVPTest(jtu.JaxTestCase): def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/google/jax/issues/2516 + # https://github.com/jax-ml/jax/issues/2516 @jax.custom_jvp def f(x): @@ -7374,7 +7374,7 @@ class CustomJVPTest(jtu.JaxTestCase): api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) def test_jaxpr_zeros(self): - # from https://github.com/google/jax/issues/2657 + # from https://github.com/jax-ml/jax/issues/2657 @jax.custom_jvp def f(A, b): return A @ b @@ -7420,7 +7420,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvps_first_rule_is_none(self): - # https://github.com/google/jax/issues/3389 + # https://github.com/jax-ml/jax/issues/3389 @jax.custom_jvp def f(x, y): return x ** 2 * y @@ -7431,7 +7431,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_concurrent_initial_style(self): - # https://github.com/google/jax/issues/3843 + # https://github.com/jax-ml/jax/issues/3843 def unroll(param, sequence): def scan_f(prev_state, inputs): return prev_state, jax.nn.sigmoid(param * inputs) @@ -7453,7 +7453,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/google/jax/issues/3964 + # https://github.com/jax-ml/jax/issues/3964 @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) def sample(shape, param, seed): return jax.random.uniform(key=seed, shape=shape, minval=param) @@ -7495,7 +7495,7 @@ class CustomJVPTest(jtu.JaxTestCase): api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) def test_closure_with_vmap(self): - # https://github.com/google/jax/issues/3822 + # https://github.com/jax-ml/jax/issues/3822 alpha = np.float32(2.) def sample(seed): @@ -7515,7 +7515,7 @@ class CustomJVPTest(jtu.JaxTestCase): api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_jvp @@ -7660,7 +7660,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvp_vmap_broadcasting_interaction(self): - # https://github.com/google/jax/issues/6452 + # https://github.com/jax-ml/jax/issues/6452 def f2(y, z): v1 = z v2 = jnp.sum(y) + z @@ -7678,7 +7678,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertEqual(g.shape, ()) def test_custom_jvp_vmap_broadcasting_interaction_2(self): - # https://github.com/google/jax/issues/5849 + # https://github.com/jax-ml/jax/issues/5849 @jax.custom_jvp def transform(box, R): if jnp.isscalar(box) or box.size == 1: @@ -7716,7 +7716,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) def test_custom_jvp_implicit_broadcasting(self): - # https://github.com/google/jax/issues/6357 + # https://github.com/jax-ml/jax/issues/6357 if config.enable_x64.value: raise unittest.SkipTest("test only applies when x64 is disabled") @@ -7774,7 +7774,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) def test_vmap_inside_defjvp(self): - # https://github.com/google/jax/issues/3201 + # https://github.com/jax-ml/jax/issues/3201 seed = 47 key = jax.random.key(seed) mat = jax.random.normal(key, (2, 3)) @@ -7823,7 +7823,7 @@ class CustomJVPTest(jtu.JaxTestCase): jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash def test_custom_jvp_unbroadcasting(self): - # https://github.com/google/jax/issues/3056 + # https://github.com/jax-ml/jax/issues/3056 a = jnp.array([1., 1.]) @jax.custom_jvp @@ -7841,8 +7841,8 @@ class CustomJVPTest(jtu.JaxTestCase): def test_maybe_perturbed_internal_helper_function(self): # This is a unit test for an internal API. We include it so as not to - # regress https://github.com/google/jax/issues/9567. For an explanation of - # this helper function, see https://github.com/google/jax/issues/6415. + # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of + # this helper function, see https://github.com/jax-ml/jax/issues/6415. def f(x): def g(y, _): z = y * x @@ -7854,7 +7854,7 @@ class CustomJVPTest(jtu.JaxTestCase): jax.jvp(f, (1.0,), (1.0,)) # assertions inside f def test_maybe_perturbed_int_regression(self): - # see https://github.com/google/jax/discussions/9951 + # see https://github.com/jax-ml/jax/discussions/9951 @jax.jit def f(): @@ -7864,7 +7864,7 @@ class CustomJVPTest(jtu.JaxTestCase): f() def test_sinc_constant_function_batching(self): - # https://github.com/google/jax/pull/10756 + # https://github.com/jax-ml/jax/pull/10756 batch_data = jnp.arange(15.).reshape(5, 3) @jax.vmap @@ -7981,7 +7981,7 @@ class CustomJVPTest(jtu.JaxTestCase): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_symbolic_zeros_under_jit(self): - # https://github.com/google/jax/issues/14833 + # https://github.com/jax-ml/jax/issues/14833 Zero = jax.custom_derivatives.SymbolicZero @jax.custom_jvp @@ -8015,7 +8015,7 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_jvp def f(x, y): @@ -8206,7 +8206,7 @@ class CustomVJPTest(jtu.JaxTestCase): lambda: api.jvp(jit(f), (3.,), (1.,))) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_vjp def my_fun(x, y, c=1.): return c * (x + y) @@ -8502,7 +8502,7 @@ class CustomVJPTest(jtu.JaxTestCase): api.jit(foo)(arr) # doesn't crash def test_lowering_out_of_traces(self): - # https://github.com/google/jax/issues/2578 + # https://github.com/jax-ml/jax/issues/2578 class F(collections.namedtuple("F", ["a"])): def __call__(self, x): @@ -8515,7 +8515,7 @@ class CustomVJPTest(jtu.JaxTestCase): jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash def test_clip_gradient(self): - # https://github.com/google/jax/issues/2784 + # https://github.com/jax-ml/jax/issues/2784 @jax.custom_vjp def _clip_gradient(lo, hi, x): return x # identity function when not differentiating @@ -8538,7 +8538,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(g, jnp.array(0.2)) def test_nestable_vjp(self): - # Verify that https://github.com/google/jax/issues/3667 is resolved. + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. def f(x): return x ** 2 @@ -8571,7 +8571,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(y, jnp.array(6.0)) def test_initial_style_vmap_2(self): - # https://github.com/google/jax/issues/4173 + # https://github.com/jax-ml/jax/issues/4173 x = jnp.ones((10, 3)) # Create the custom function @@ -8837,7 +8837,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_vjp_closure_4521(self): - # https://github.com/google/jax/issues/4521 + # https://github.com/jax-ml/jax/issues/4521 @jax.custom_vjp def g(x, y): return None @@ -8954,7 +8954,7 @@ class CustomVJPTest(jtu.JaxTestCase): def test_closure_convert_mixed_consts(self): # Like test_closure_convert, but close over values that # participate in AD as well as values that do not. - # See https://github.com/google/jax/issues/6415 + # See https://github.com/jax-ml/jax/issues/6415 def cos_after(fn, x): converted_fn, aux_args = jax.closure_convert(fn, x) @@ -8993,7 +8993,7 @@ class CustomVJPTest(jtu.JaxTestCase): self.assertAllClose(g_x, 17. * x, check_dtypes=False) def test_closure_convert_pytree_mismatch(self): - # See https://github.com/google/jax/issues/23588 + # See https://github.com/jax-ml/jax/issues/23588 def f(x, z): return z * x @@ -9021,7 +9021,7 @@ class CustomVJPTest(jtu.JaxTestCase): jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash def test_custom_vjp_scan_batching_edge_case(self): - # https://github.com/google/jax/issues/5832 + # https://github.com/jax-ml/jax/issues/5832 @jax.custom_vjp def mul(x, coeff): return x * coeff def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) @@ -9052,7 +9052,7 @@ class CustomVJPTest(jtu.JaxTestCase): modes=['rev']) def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_vjp @@ -9094,7 +9094,7 @@ class CustomVJPTest(jtu.JaxTestCase): jax.grad(f)(A([1.])) # doesn't crash def test_vmap_vjp_called_twice(self): - # https://github.com/google/jax/pull/14728 + # https://github.com/jax-ml/jax/pull/14728 @jax.custom_vjp def f(x): return x @@ -9390,7 +9390,7 @@ class CustomVJPTest(jtu.JaxTestCase): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_vjp def f(x, y): @@ -9420,7 +9420,7 @@ class CustomVJPTest(jtu.JaxTestCase): g(1.) # doesn't crash def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): - # https://github.com/google/jax/issues/8356 + # https://github.com/jax-ml/jax/issues/8356 @jax.custom_vjp def f(x): return x[0] @@ -9618,7 +9618,7 @@ class CustomVJPTest(jtu.JaxTestCase): jax.grad(f)(x, y) # Doesn't error def test_optimize_remat_custom_vmap(self): - # See https://github.com/google/jax/pull/23000 + # See https://github.com/jax-ml/jax/pull/23000 @jax.custom_vjp def f(x, y): return jnp.sin(x) * y @@ -10908,7 +10908,7 @@ class AutodidaxTest(jtu.JaxTestCase): class GarbageCollectionTest(jtu.JaxTestCase): def test_xla_gc_callback(self): - # https://github.com/google/jax/issues/14882 + # https://github.com/jax-ml/jax/issues/14882 x_np = np.arange(10, dtype='int32') x_jax = jax.device_put(x_np) x_np_weakref = weakref.ref(x_np) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index c2cd4c0f9..5585f1bcc 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -185,7 +185,7 @@ class DLPackTest(jtu.JaxTestCase): @unittest.skipIf(not tf, "Test requires TensorFlow") def testTensorFlowToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 4378a3c75..4eb354a8d 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -323,7 +323,7 @@ class AttrsTest(jtu.JaxTestCase): self.assertEqual(count, 1) def test_tracer_lifetime_bug(self): - # regression test for https://github.com/google/jax/issues/20082 + # regression test for https://github.com/jax-ml/jax/issues/20082 class StatefulRNG: key: jax.Array diff --git a/tests/batching_test.py b/tests/batching_test.py index 6cd8c7bc2..2b0b0d63a 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -335,7 +335,7 @@ class BatchingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): - # test modeling the code in https://github.com/google/jax/issues/54 + # test modeling the code in https://github.com/jax-ml/jax/issues/54 def func(xs): return jnp.array(list(xs)) @@ -345,7 +345,7 @@ class BatchingTest(jtu.JaxTestCase): jacfwd(func)(xs) # don't crash def testAny(self): - # test modeling the code in https://github.com/google/jax/issues/108 + # test modeling the code in https://github.com/jax-ml/jax/issues/108 ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) @@ -368,7 +368,7 @@ class BatchingTest(jtu.JaxTestCase): def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax - # see https://github.com/google/jax/issues/1613 for an explanation of why we + # see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we # need to use np rather than np to create x and idx x = jnp.arange(30).reshape((10, 3)) @@ -933,7 +933,7 @@ class BatchingTest(jtu.JaxTestCase): rtol=jtu.default_gradient_tolerance) def testIssue387(self): - # https://github.com/google/jax/issues/387 + # https://github.com/jax-ml/jax/issues/387 R = self.rng().rand(100, 2) def dist_sq(R): @@ -951,7 +951,7 @@ class BatchingTest(jtu.JaxTestCase): @jax.legacy_prng_key('allow') def testIssue489(self): - # https://github.com/google/jax/issues/489 + # https://github.com/jax-ml/jax/issues/489 def f(key): def body_fn(uk): key = uk[1] @@ -1131,7 +1131,7 @@ class BatchingTest(jtu.JaxTestCase): x - np.arange(x.shape[0], dtype='int32')) def testVmapKwargs(self): - # https://github.com/google/jax/issues/912 + # https://github.com/jax-ml/jax/issues/912 def f(a, b): return (2*a, 3*b) @@ -1242,7 +1242,7 @@ class BatchingTest(jtu.JaxTestCase): self.assertEqual(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3)) def testPpermuteBatcherTrivial(self): - # https://github.com/google/jax/issues/8688 + # https://github.com/jax-ml/jax/issues/8688 def ppermute(input): return jax.lax.ppermute(input, axis_name="i", perm=[[0, 1], [1, 0]]) @@ -1255,7 +1255,7 @@ class BatchingTest(jtu.JaxTestCase): self.assertAllClose(ans, jnp.ones(2), check_dtypes=False) def testBatchingPreservesWeakType(self): - # Regression test for https://github.com/google/jax/issues/10025 + # Regression test for https://github.com/jax-ml/jax/issues/10025 x = jnp.ravel(1) self.assertTrue(dtypes.is_weakly_typed(x)) @vmap diff --git a/tests/core_test.py b/tests/core_test.py index 0838702c4..94b701090 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -349,7 +349,7 @@ class CoreTest(jtu.JaxTestCase): g_vmap(jnp.ones((1, ))) def test_concrete_array_string_representation(self): - # https://github.com/google/jax/issues/5364 + # https://github.com/jax-ml/jax/issues/5364 self.assertEqual( str(core.ConcreteArray(np.dtype(np.int32), np.array([1], dtype=np.int32))), @@ -369,7 +369,7 @@ class CoreTest(jtu.JaxTestCase): self.assertEqual(dropvar.aval, aval) def test_input_residual_forwarding(self): - # https://github.com/google/jax/pull/11151 + # https://github.com/jax-ml/jax/pull/11151 x = jnp.arange(3 * 4.).reshape(3, 4) y = jnp.arange(4 * 3.).reshape(4, 3) diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 830526826..857dc34d4 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -291,7 +291,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase): jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3) - # regression test for https://github.com/google/jax/issues/1536 + # regression test for https://github.com/jax-ml/jax/issues/1536 jtu.check_grads(jax.jit(linear_solve), (a, b), order=2, rtol={np.float32: 2e-3}) @@ -396,7 +396,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase): def test_custom_linear_solve_pytree_with_aux(self): # Check that lax.custom_linear_solve handles # pytree inputs + has_aux=True - # https://github.com/google/jax/pull/13093 + # https://github.com/jax-ml/jax/pull/13093 aux_orig = {'a': 1, 'b': 2} b = {'c': jnp.ones(2), 'd': jnp.ones(3)} diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 19e2a5893..020c9f744 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -127,7 +127,7 @@ class DebugNaNsTest(jtu.JaxTestCase): ans.block_until_ready() def testDebugNansJitWithDonation(self): - # https://github.com/google/jax/issues/12514 + # https://github.com/jax-ml/jax/issues/12514 a = jnp.array(0.) with self.assertRaises(FloatingPointError): ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a) @@ -214,7 +214,7 @@ class DebugInfsTest(jtu.JaxTestCase): f(1) def testDebugNansDoesntCorruptCaches(self): - # https://github.com/google/jax/issues/6614 + # https://github.com/jax-ml/jax/issues/6614 @jax.jit def f(x): return jnp.divide(x, x) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index c6b12e2d8..e736e06da 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -667,7 +667,7 @@ class TestPromotionTables(jtu.JaxTestCase): {"testcase_name": f"_{typ}", "typ": typ} for typ in [bool, int, float, complex]) def testScalarWeakTypes(self, typ): - # Regression test for https://github.com/google/jax/issues/11377 + # Regression test for https://github.com/jax-ml/jax/issues/11377 val = typ(0) result1 = jnp.array(val) @@ -806,7 +806,7 @@ class TestPromotionTables(jtu.JaxTestCase): for weak_type in [True, False] ) def testUnaryPromotion(self, dtype, weak_type): - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) @@ -852,7 +852,7 @@ class TestPromotionTables(jtu.JaxTestCase): self.skipTest("XLA support for float8 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 101dddccb..cc2419fb3 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -621,7 +621,7 @@ class DynamicShapeStagingTest(jtu.JaxTestCase): self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) def test_shape_validation(self): - # Regression test for https://github.com/google/jax/issues/18937 + # Regression test for https://github.com/jax-ml/jax/issues/18937 msg = r"Shapes must be 1D sequences of integer scalars, got .+" with self.assertRaisesRegex(TypeError, msg): jax.make_jaxpr(jnp.ones)(5.0) diff --git a/tests/export_test.py b/tests/export_test.py index d5884b7e6..0d946d84d 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1334,7 +1334,7 @@ class JaxExportTest(jtu.JaxTestCase): def test_grad_sharding_different_mesh(self): # Export and serialize with two similar meshes, the only difference being # the order of the devices. grad and serialization should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 def f(x): return jnp.sum(x * 2.) diff --git a/tests/fft_test.py b/tests/fft_test.py index 05fa96a93..a87b7b66e 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -175,7 +175,7 @@ class FftTest(jtu.JaxTestCase): self.assertEqual(dtype, expected_dtype) def testIrfftTranspose(self): - # regression test for https://github.com/google/jax/issues/6223 + # regression test for https://github.com/jax-ml/jax/issues/6223 def build_matrix(linear_func, size): return jax.vmap(linear_func)(jnp.eye(size, size)) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 944b47dc8..837d205fb 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1035,7 +1035,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): ( 5.00 2 )""", testing_stream.output) def test_tap_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1058,7 +1058,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): ( [0.70 0.80] [11 12 13] )""", testing_stream.output) def test_tap_higher_order_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1935,7 +1935,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): hcb.id_tap(func, 1, y=2) def test_tap_id_tap_random_key(self): - # See https://github.com/google/jax/issues/13949 + # See https://github.com/jax-ml/jax/issues/13949 with jax.enable_custom_prng(): @jax.jit def f(x): diff --git a/tests/image_test.py b/tests/image_test.py index f3cd56ed7..0f6341086 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -180,7 +180,7 @@ class ImageTest(jtu.JaxTestCase): antialias=[False, True], ) def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias): - # Regression test for https://github.com/google/jax/issues/7586 + # Regression test for https://github.com/jax-ml/jax/issues/7586 image = np.ones(image_shape, dtype) out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias) self.assertArraysEqual(out, jnp.zeros(target_shape, dtype)) diff --git a/tests/jet_test.py b/tests/jet_test.py index b1e2ef3f8..4e437c044 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -404,7 +404,7 @@ class JetTest(jtu.JaxTestCase): self.assertArraysEqual(g_out_series, f_out_series) def test_add_any(self): - # https://github.com/google/jax/issues/5217 + # https://github.com/jax-ml/jax/issues/5217 f = lambda x, eps: x * eps + eps + x def g(eps): x = jnp.array(1.) @@ -412,7 +412,7 @@ class JetTest(jtu.JaxTestCase): jet(g, (1.,), ([1.],)) # doesn't crash def test_scatter_add(self): - # very basic test from https://github.com/google/jax/issues/5365 + # very basic test from https://github.com/jax-ml/jax/issues/5365 def f(x): x0 = x[0] x1 = x[1] diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index ab3a18317..78d90cb8a 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -424,7 +424,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): assert "Precision.HIGHEST" in s def testDotPreferredElementType(self): - # https://github.com/google/jax/issues/10818 + # https://github.com/jax-ml/jax/issues/10818 x = jax.numpy.ones((), jax.numpy.float16) def f(x): return jax.lax.dot_general(x, x, (((), ()), ((), ())), @@ -513,7 +513,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): rtol={np.float32: 3e-3}) def testPowSecondDerivative(self): - # https://github.com/google/jax/issues/12033 + # https://github.com/jax-ml/jax/issues/12033 x, y = 4.0, 0.0 expected = ((0.0, 1/x), (1/x, np.log(x) ** 2)) @@ -528,18 +528,18 @@ class LaxAutodiffTest(jtu.JaxTestCase): with self.subTest("zero to the zero"): result = jax.grad(lax.pow)(0.0, 0.0) # TODO(jakevdp) special-case zero in a way that doesn't break other cases - # See https://github.com/google/jax/pull/12041#issuecomment-1222766191 + # See https://github.com/jax-ml/jax/pull/12041#issuecomment-1222766191 # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) def testPowIntPowerAtZero(self): - # https://github.com/google/jax/issues/14397 + # https://github.com/jax-ml/jax/issues/14397 ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0) self.assertAllClose(ans, 0., check_dtypes=False) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testPowIntPowerAtZero2(self): - # https://github.com/google/jax/issues/17995 + # https://github.com/jax-ml/jax/issues/17995 a = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=int)) b = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=float)) c = lambda z: 1 + z @@ -634,7 +634,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) def testDynamicSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -649,7 +649,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): self.assertAllClose(result1, result2) def testDynamicUpdateSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -1004,7 +1004,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) def testScatterGradSymbolicZeroUpdate(self): - # https://github.com/google/jax/issues/1901 + # https://github.com/jax-ml/jax/issues/1901 def f(x): n = x.shape[0] y = np.arange(n, dtype=x.dtype) @@ -1111,7 +1111,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): check_grads(lax.rem, (x, y), 2, ["fwd", "rev"]) def testHigherOrderGradientOfReciprocal(self): - # Regression test for https://github.com/google/jax/issues/3136 + # Regression test for https://github.com/jax-ml/jax/issues/3136 def inv(x): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x @@ -1150,7 +1150,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): jax.jacrev(f)(x) def testPowShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/17294 + # Regression test for https://github.com/jax-ml/jax/issues/17294 x = lax.iota('float32', 4) y = 2 actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 37ad22063..7fb118d47 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -733,7 +733,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertEqual(fun(4), (8, 16)) def testCondPredIsNone(self): - # see https://github.com/google/jax/issues/11574 + # see https://github.com/jax-ml/jax/issues/11574 def f(pred, x): return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x) @@ -743,7 +743,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): lambda: jax.jit(f)(None, 1.)) def testCondTwoOperands(self): - # see https://github.com/google/jax/issues/8469 + # see https://github.com/jax-ml/jax/issues/8469 add, mul = lax.add, lax.mul def fun(x): @@ -775,7 +775,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertEqual(fun(1), cfun(1)) def testCondCallableOperands(self): - # see https://github.com/google/jax/issues/16413 + # see https://github.com/jax-ml/jax/issues/16413 @tree_util.register_pytree_node_class class Foo: @@ -1560,7 +1560,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) def testCondVmapGrad(self, cond): - # https://github.com/google/jax/issues/2264 + # https://github.com/jax-ml/jax/issues/2264 def f_1(x): return x ** 2 def f_2(x): return x ** 3 @@ -1839,7 +1839,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): def testIssue711(self, scan): # Tests reverse-mode differentiation through a scan for which the scanned # function also involves reverse-mode differentiation. - # See https://github.com/google/jax/issues/711 + # See https://github.com/jax-ml/jax/issues/711 def harmonic_bond(conf, params): return jnp.sum(conf * params) @@ -2078,7 +2078,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False) def testIssue757(self): - # code from https://github.com/google/jax/issues/757 + # code from https://github.com/jax-ml/jax/issues/757 def fn(a): return jnp.cos(a) @@ -2107,7 +2107,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertAllClose(actual, expected) def testMapEmpty(self): - # https://github.com/google/jax/issues/2412 + # https://github.com/jax-ml/jax/issues/2412 ans = lax.map(lambda x: x * x, jnp.array([])) expected = jnp.array([]) self.assertAllClose(ans, expected) @@ -2164,7 +2164,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): lax.while_loop(cond, body, 0) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 scanned_f = lambda _, __: (lax.psum(1, 'i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() @@ -2443,7 +2443,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertEqual(h, length) def test_disable_jit_cond_with_vmap(self): - # https://github.com/google/jax/issues/3093 + # https://github.com/jax-ml/jax/issues/3093 def fn(t): return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) fn = jax.vmap(fn) @@ -2452,14 +2452,14 @@ class LaxControlFlowTest(jtu.JaxTestCase): _ = fn(jnp.array([1])) # doesn't crash def test_disable_jit_while_loop_with_vmap(self): - # https://github.com/google/jax/issues/2823 + # https://github.com/jax-ml/jax/issues/2823 def trivial_while(y): return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y) with jax.disable_jit(): jax.vmap(trivial_while)(jnp.array([3.0,4.0])) # doesn't crash def test_vmaps_of_while_loop(self): - # https://github.com/google/jax/issues/3164 + # https://github.com/jax-ml/jax/issues/3164 def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x) x, n = jnp.arange(3), jnp.arange(4) jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash @@ -2567,7 +2567,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): lambda: core.check_jaxpr(jaxpr)) def test_cond_transformation_rule_with_consts(self): - # https://github.com/google/jax/pull/9731 + # https://github.com/jax-ml/jax/pull/9731 @jax.custom_jvp def f(x): @@ -2584,7 +2584,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): jax.jvp(g, (x,), (x,)) # doesn't crash def test_cond_excessive_compilation(self): - # Regression test for https://github.com/google/jax/issues/14058 + # Regression test for https://github.com/jax-ml/jax/issues/14058 def f(x): return x + 1 @@ -2632,7 +2632,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): ('new_remat', new_checkpoint), ]) def test_scan_vjp_forwards_extensive_residuals(self, remat): - # https://github.com/google/jax/issues/4510 + # https://github.com/jax-ml/jax/issues/4510 def cumprod(x): s = jnp.ones((2, 32), jnp.float32) return lax.scan(lambda s, x: (x*s, s), s, x) @@ -2671,7 +2671,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): (jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]])), check_dtypes=False) def test_xla_cpu_gpu_loop_cond_bug(self): - # https://github.com/google/jax/issues/5900 + # https://github.com/jax-ml/jax/issues/5900 def deriv(f): return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0) @@ -2750,7 +2750,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): jax.grad(f)(1.) # doesn't crash def test_custom_jvp_tangent_cond_transpose(self): - # https://github.com/google/jax/issues/14026 + # https://github.com/jax-ml/jax/issues/14026 def mask_fun(arr, choice): out = (1 - choice) * arr.sum() + choice * (1 - arr.sum()) return out @@ -2997,7 +2997,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertIsInstance(y, jax.Array) def test_cond_memory_leak(self): - # https://github.com/google/jax/issues/12719 + # https://github.com/jax-ml/jax/issues/12719 def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 02fecb7b3..d3dada0d7 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -914,7 +914,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1425,7 +1425,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1536,7 +1536,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2768,7 +2768,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3359,7 +3359,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -3932,17 +3932,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4303,7 +4303,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4513,9 +4513,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4526,7 +4526,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4538,7 +4538,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5147,7 +5147,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5316,7 +5316,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5334,7 +5334,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5405,7 +5405,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -5617,7 +5617,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): return False - #https://github.com/google/jax/issues/16420 + #https://github.com/jax-ml/jax/issues/16420 def test_broadcast_dim(self): x = jnp.arange(2) f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) @@ -5640,7 +5640,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): res = jnp.triu(x) jtu.check_eq(res, np.triu(x)) - #https://github.com/google/jax/issues/16471 + #https://github.com/jax-ml/jax/issues/16471 def test_matmul_1d(self): x = np.array(np.random.rand(3, 3)) y = np.array(np.random.rand(3)) @@ -5650,7 +5650,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): res = jnp.dot(x, y) self.assertArraysAllClose(res, np.dot(x,y)) - #https://github.com/google/jax/issues/17175 + #https://github.com/jax-ml/jax/issues/17175 def test_indexing(self): x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) @jax.vmap @@ -5661,7 +5661,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): res = f(idx) jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) - #https://github.com/google/jax/issues/17344 + #https://github.com/jax-ml/jax/issues/17344 def test_take_along_axis(self): @jax.jit def f(): @@ -5672,7 +5672,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): return jnp.take_along_axis(x, idx, axis=1) jtu.check_eq(f(), self.dispatchOn([], f)) - #https://github.com/google/jax/issues/17590 + #https://github.com/jax-ml/jax/issues/17590 def test_in1d(self): a = np.array([123,2,4]) b = np.array([123,1]) @@ -5688,7 +5688,7 @@ class ReportedIssuesTests(jtu.JaxTestCase): res = f(x) jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) - #https://github.com/google/jax/issues/16326 + #https://github.com/jax-ml/jax/issues/16326 def test_indexing_update2(self): @jax.jit def f(x, r): @@ -5722,7 +5722,7 @@ module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas print(res) jtu.check_eq(res, res_ref) - #https://github.com/google/jax/issues/16366 + #https://github.com/jax-ml/jax/issues/16366 def test_pad_interior_1(self): if not ReportedIssuesTests.jax_metal_supported('0.0.6'): raise unittest.SkipTest("jax-metal version doesn't support it.") diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 7397cf3e4..ea7bff1d0 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -89,7 +89,7 @@ class EinsumTest(jtu.JaxTestCase): self._check(s, x, y) def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 + # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187 r = self.rng() x = r.randn(2, 1) y = r.randn(2, 3, 4) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index bf2785f62..d58a5c2c3 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -496,7 +496,7 @@ class IndexingTest(jtu.JaxTestCase): self._CompileAndCheck(jnp_op_idx, args_maker) def testIndexApplyBatchingBug(self): - # https://github.com/google/jax/issues/16655 + # https://github.com/jax-ml/jax/issues/16655 arr = jnp.array([[1, 2, 3, 4, 5, 6]]) ind = jnp.array([3]) func = lambda a, i: a.at[i].apply(lambda x: x - 1) @@ -505,7 +505,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertArraysEqual(out, expected) def testIndexUpdateScalarBug(self): - # https://github.com/google/jax/issues/14923 + # https://github.com/jax-ml/jax/issues/14923 a = jnp.arange(10.) out = a.at[0].apply(jnp.cos) self.assertArraysEqual(out, a.at[0].set(1)) @@ -835,7 +835,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis(self): - # Regression test for https://github.com/google/jax/issues/8412 + # Regression test for https://github.com/jax-ml/jax/issues/8412 x = np.arange(24).reshape(4, 3, 2) idx = (..., np.array([True, False])) ans = jnp.array(x)[idx] @@ -843,7 +843,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis2(self): - # Regression test for https://github.com/google/jax/issues/9050 + # Regression test for https://github.com/jax-ml/jax/issues/9050 x = np.arange(3) idx = (..., np.array([True, False, True])) ans = jnp.array(x)[idx] @@ -936,7 +936,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) def testTrivialGatherIsntGenerated(self): - # https://github.com/google/jax/issues/1621 + # https://github.com/jax-ml/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) @@ -988,14 +988,14 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/7329 + # Regression test for https://github.com/jax-ml/jax/issues/7329 x = jnp.arange(4) idx = jnp.array([True, False]) with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): x[idx] def testBooleanIndexingWithNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) idx = (None, jnp.array([True, False])) ans = x[idx] @@ -1003,7 +1003,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) def testBooleanIndexingWithNoneAndEllipsis(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[None, ..., mask] @@ -1011,7 +1011,7 @@ class IndexingTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) def testBooleanIndexingWithEllipsisAndNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[..., None, mask] @@ -1038,7 +1038,7 @@ class IndexingTest(jtu.JaxTestCase): [(3, 4, 5), (3, 0)], ) def testEmptyBooleanIndexing(self, x_shape, m_shape): - # Regression test for https://github.com/google/jax/issues/22886 + # Regression test for https://github.com/jax-ml/jax/issues/22886 rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)] @@ -1120,7 +1120,7 @@ class IndexingTest(jtu.JaxTestCase): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1613,7 +1613,7 @@ class IndexedUpdateTest(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testIndexDtypeError(self): - # https://github.com/google/jax/issues/2795 + # https://github.com/jax-ml/jax/issues/2795 jnp.array(1) # get rid of startup warning with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) @@ -1647,13 +1647,13 @@ class IndexedUpdateTest(jtu.JaxTestCase): x.at[normalize(idx)].set(0) def testIndexedUpdateAliasingBug(self): - # https://github.com/google/jax/issues/7461 + # https://github.com/jax-ml/jax/issues/7461 fn = lambda x: x.at[1:].set(1 + x[:-1]) y = jnp.zeros(8) self.assertArraysEqual(fn(y), jax.jit(fn)(y)) def testScatterValuesCastToTargetDType(self): - # https://github.com/google/jax/issues/15505 + # https://github.com/jax-ml/jax/issues/15505 a = jnp.zeros(1, dtype=jnp.uint32) val = 2**32 - 1 # too large for int32 diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index d9d6fa464..45a780c9f 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -697,7 +697,7 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase): self.assertIsInstance(jax.jit(operator.mul)(b, a), MyArray) def testI0Grad(self): - # Regression test for https://github.com/google/jax/issues/11479 + # Regression test for https://github.com/jax-ml/jax/issues/11479 dx = jax.grad(jax.numpy.i0)(0.0) self.assertArraysEqual(dx, 0.0) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 402e206ef..33830c541 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -425,7 +425,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): if (shape in [()] + scalar_shapes and dtype in [jnp.int16, jnp.uint16] and jnp_op in [jnp.min, jnp.max]): - self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") + self.skipTest("Known XLA failure; see https://github.com/jax-ml/jax/issues/4971.") rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. @@ -582,7 +582,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): size=[0, 1, 2] ) def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size): - # test for https://github.com/google/jax/issues/21330 + # test for https://github.com/jax-ml/jax/issues/21330 x = jnp.arange(size) self.assertTrue(np.isnan(jnp_fn(x, ddof=size))) self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1))) @@ -622,7 +622,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): atol=tol) def testNanStdGrad(self): - # Regression test for https://github.com/google/jax/issues/8128 + # Regression test for https://github.com/jax-ml/jax/issues/8128 x = jnp.arange(5.0).at[0].set(jnp.nan) y = jax.grad(jnp.nanvar)(x) self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]), check_dtypes=False) @@ -740,7 +740,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): - # Regression test for https://github.com/google/jax/issues/8513 + # Regression test for https://github.com/jax-ml/jax/issues/8513 x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) @@ -778,14 +778,14 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) def testMeanLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(1.0, jnp.mean(x)) self.assertEqual(1.0, jnp.mean(x, where=True)) def testStdLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(0.0, jnp.std(x)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ddf42a28e..d3f9f2d61 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1043,7 +1043,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1571,7 +1571,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1713,7 +1713,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2977,7 +2977,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3611,7 +3611,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -4229,17 +4229,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4400,7 +4400,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): dtype=unsigned_dtypes, ) def testPartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype) axis = -1 @@ -4441,7 +4441,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): dtype=unsigned_dtypes, ) def testArgpartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype) axis = -1 @@ -4616,7 +4616,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4837,9 +4837,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4850,7 +4850,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4861,7 +4861,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5489,7 +5489,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5536,7 +5536,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_op, args_maker) def testBroadcastToInvalidShape(self): - # Regression test for https://github.com/google/jax/issues/20533 + # Regression test for https://github.com/jax-ml/jax/issues/20533 x = jnp.zeros((3, 4, 5)) with self.assertRaisesRegex( ValueError, "Cannot broadcast to shape with fewer dimensions"): @@ -5688,7 +5688,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp.gradient, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5706,7 +5706,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5777,7 +5777,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -6096,7 +6096,7 @@ class NumpyGradTests(jtu.JaxTestCase): jax.grad(lambda x: jnp.sinc(x).sum())(jnp.arange(10.)) # doesn't crash def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 + # https://github.com/jax-ml/jax/issues/1521 idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1)) def f(x): @@ -6207,7 +6207,7 @@ class NumpySignaturesTest(jtu.JaxTestCase): if name == "clip": # JAX's support of the Array API spec for clip, and the way it handles # backwards compatibility was introduced in - # https://github.com/google/jax/pull/20550 with a different signature + # https://github.com/jax-ml/jax/pull/20550 with a different signature # from the one in numpy, introduced in # https://github.com/numpy/numpy/pull/26724 # TODO(dfm): After our deprecation period for the clip arguments ends diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index c65df8aa8..630d89f53 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -379,7 +379,7 @@ class LaxNumpyUfuncTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun_at, args_maker) def test_frompyfunc_at_broadcasting(self): - # Regression test for https://github.com/google/jax/issues/18004 + # Regression test for https://github.com/jax-ml/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] def np_fun(x, idx, y): diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 56fd0f781..985dba484 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -258,7 +258,7 @@ class VectorizeTest(jtu.JaxTestCase): f(*args) def test_rank_promotion_error(self): - # Regression test for https://github.com/google/jax/issues/22305 + # Regression test for https://github.com/jax-ml/jax/issues/22305 f = jnp.vectorize(jnp.add, signature="(),()->()") rank2 = jnp.zeros((10, 10)) rank1 = jnp.zeros(10) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index d2e64833b..303c67c58 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -469,7 +469,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertTrue(dtypes.is_weakly_typed(x)) def test_linear_solve_batching_via_jacrev(self): - # See https://github.com/google/jax/issues/14249 + # See https://github.com/jax-ml/jax/issues/14249 rng = np.random.RandomState(0) M = rng.randn(5, 5) A = np.dot(M, M.T) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index df19750b1..8a8b1dd42 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -165,7 +165,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol) def testLogSumExpZeros(self): - # Regression test for https://github.com/google/jax/issues/5370 + # Regression test for https://github.com/jax-ml/jax/issues/5370 scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] @@ -173,14 +173,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) def testLogSumExpOnes(self): - # Regression test for https://github.com/google/jax/issues/7390 + # Regression test for https://github.com/jax-ml/jax/issues/7390 args_maker = lambda: [np.ones(4, dtype='float32')] with jax.debug_infs(True): self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) self._CompileAndCheck(lsp_special.logsumexp, args_maker) def testLogSumExpNans(self): - # Regression test for https://github.com/google/jax/issues/7634 + # Regression test for https://github.com/jax-ml/jax/issues/7634 with jax.debug_nans(True): with jax.disable_jit(): result = lsp_special.logsumexp(1.0) @@ -246,7 +246,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log(y0)) @@ -260,7 +260,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log1p(y0)) @@ -284,7 +284,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase): rtol=.1, eps=1e-3) def testGradOfEntrAtZero(self): - # https://github.com/google/jax/issues/15709 + # https://github.com/jax-ml/jax/issues/15709 self.assertEqual(jax.jacfwd(lsp_special.entr)(0.0), jnp.inf) self.assertEqual(jax.jacrev(lsp_special.entr)(0.0), jnp.inf) diff --git a/tests/lax_test.py b/tests/lax_test.py index 3d31bcb7d..0ae5f77af 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1002,7 +1002,7 @@ class LaxTest(jtu.JaxTestCase): self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) def testConvTransposePaddingList(self): - # Regression test for https://github.com/google/jax/discussions/8695 + # Regression test for https://github.com/jax-ml/jax/discussions/8695 a = jnp.ones((28,28)) b = jnp.ones((3,3)) c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1)) @@ -1280,7 +1280,7 @@ class LaxTest(jtu.JaxTestCase): self._CompileAndCheck(op, args_maker) def testBroadcastInDimOperandShapeTranspose(self): - # Regression test for https://github.com/google/jax/issues/5276 + # Regression test for https://github.com/jax-ml/jax/issues/5276 def f(x): return lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(0, 1, 2)).sum() def g(x): @@ -1681,7 +1681,7 @@ class LaxTest(jtu.JaxTestCase): lax.dynamic_update_slice, args_maker) def testDynamicUpdateSliceBatched(self): - # Regression test for https://github.com/google/jax/issues/9083 + # Regression test for https://github.com/jax-ml/jax/issues/9083 x = jnp.arange(5) y = jnp.arange(6, 9) ind = jnp.arange(6) @@ -2236,7 +2236,7 @@ class LaxTest(jtu.JaxTestCase): self.assertEqual(shape, result.shape) def testReduceWindowWithEmptyOutput(self): - # https://github.com/google/jax/issues/10315 + # https://github.com/jax-ml/jax/issues/10315 shape = (5, 3, 2) operand, padding, strides = np.ones(shape), 'VALID', (1,) * len(shape) out = jax.eval_shape(lambda x: lax.reduce_window(x, 0., lax.add, padding=padding, @@ -2859,13 +2859,13 @@ class LaxTest(jtu.JaxTestCase): op(2+3j, 4+5j) def test_population_count_booleans_not_supported(self): - # https://github.com/google/jax/issues/3886 + # https://github.com/jax-ml/jax/issues/3886 msg = "population_count does not accept dtype bool" with self.assertRaisesRegex(TypeError, msg): lax.population_count(True) def test_conv_general_dilated_different_input_ranks_error(self): - # https://github.com/google/jax/issues/4316 + # https://github.com/jax-ml/jax/issues/4316 msg = ("conv_general_dilated lhs and rhs must have the same number of " "dimensions") dimension_numbers = lax.ConvDimensionNumbers(lhs_spec=(0, 1, 2), @@ -2885,7 +2885,7 @@ class LaxTest(jtu.JaxTestCase): lax.conv_general_dilated(lhs, rhs, **kwargs) def test_window_strides_dimension_shape_rule(self): - # https://github.com/google/jax/issues/5087 + # https://github.com/jax-ml/jax/issues/5087 msg = ("conv_general_dilated window and window_strides must have " "the same number of dimensions") lhs = jax.numpy.zeros((1, 1, 3, 3)) @@ -2894,7 +2894,7 @@ class LaxTest(jtu.JaxTestCase): jax.lax.conv(lhs, rhs, [1], 'SAME') def test_reduce_window_scalar_init_value_shape_rule(self): - # https://github.com/google/jax/issues/4574 + # https://github.com/jax-ml/jax/issues/4574 args = { "operand": np.ones((4, 4), dtype=np.int32) , "init_value": np.zeros((1,), dtype=np.int32) , "computation": lax.max @@ -3045,7 +3045,7 @@ class LaxTest(jtu.JaxTestCase): np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128]) def test_dot_general_batching_python_builtin_arg(self): - # https://github.com/google/jax/issues/16805 + # https://github.com/jax-ml/jax/issues/16805 @jax.remat def f(x): return jax.lax.dot_general(x, x, (([], []), ([], []))) @@ -3053,7 +3053,7 @@ class LaxTest(jtu.JaxTestCase): jax.hessian(f)(1.0) # don't crash def test_constant_folding_complex_to_real_scan_regression(self): - # regression test for github.com/google/jax/issues/19059 + # regression test for github.com/jax-ml/jax/issues/19059 def g(hiddens): hiddens_aug = jnp.vstack((hiddens[0], hiddens)) new_hiddens = hiddens_aug.copy() @@ -3088,7 +3088,7 @@ class LaxTest(jtu.JaxTestCase): jaxpr = jax.make_jaxpr(asarray_closure)() self.assertLen(jaxpr.eqns, 0) - # Regression test for https://github.com/google/jax/issues/19334 + # Regression test for https://github.com/jax-ml/jax/issues/19334 # lax.asarray as a closure should not trigger transfer guard. with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() @@ -3254,7 +3254,7 @@ class LazyConstantTest(jtu.JaxTestCase): def testUnaryWeakTypes(self, op_name, rec_dtypes): """Test that all lax unary ops propagate weak_type information appropriately.""" if op_name == "bitwise_not": - raise unittest.SkipTest("https://github.com/google/jax/issues/12066") + raise unittest.SkipTest("https://github.com/jax-ml/jax/issues/12066") # Find a valid dtype for the function. for dtype in [float, int, complex, bool]: dtype = dtypes.canonicalize_dtype(dtype) @@ -3648,7 +3648,7 @@ class CustomElementTypesTest(jtu.JaxTestCase): self.assertEqual(ys.shape, (3, 2, 1)) def test_gather_batched_index_dtype(self): - # Regression test for https://github.com/google/jax/issues/16557 + # Regression test for https://github.com/jax-ml/jax/issues/16557 dtype = jnp.int8 size = jnp.iinfo(dtype).max + 10 indices = jnp.zeros(size, dtype=dtype) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37d51c04f..37a0011e7 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -693,8 +693,8 @@ class LaxVmapTest(jtu.JaxTestCase): # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU @jtu.skip_on_devices("gpu") def test_variadic_reduce_window(self): - # https://github.com/google/jax/discussions/9818 and - # https://github.com/google/jax/issues/9837 + # https://github.com/jax-ml/jax/discussions/9818 and + # https://github.com/jax-ml/jax/issues/9837 def normpool(x): norms = jnp.linalg.norm(x, axis=-1) idxs = jnp.arange(x.shape[0]) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 446e10abd..15963b10b 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -329,7 +329,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): @jtu.run_on_devices("cpu") def testEigvalsInf(self): - # https://github.com/google/jax/issues/2661 + # https://github.com/jax-ml/jax/issues/2661 x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -1004,7 +1004,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): - # Regression test for https://github.com/google/jax/issues/10530 + # Regression test for https://github.com/jax-ml/jax/issues/10530 rng = jtu.rand_default(self.rng()) arr = rng(shape, dtype) if jtu.test_device_matches(['cpu']): @@ -1422,7 +1422,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): @parameterized.parameters(lax_linalg.lu, lax_linalg._lu_python) def testLuOnZeroMatrix(self, lu): - # Regression test for https://github.com/google/jax/issues/19076 + # Regression test for https://github.com/jax-ml/jax/issues/19076 x = jnp.zeros((2, 2), dtype=np.float32) x_lu, _, _ = lu(x) self.assertArraysEqual(x_lu, x) diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 5daa0e0e5..a3b6b1efa 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -174,7 +174,7 @@ class MultiDeviceTest(jtu.JaxTestCase): def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 devices = self.get_devices() def f(): return lax.add(3., 4.) diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 4f2e36c64..4697ba8b2 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -148,7 +148,7 @@ class MultiBackendTest(jtu.JaxTestCase): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 def f(): return jnp.add(3., 4.) self.assertNotEqual(jax.jit(f)().devices(), {jax.devices('cpu')[0]}) @@ -186,7 +186,7 @@ class MultiBackendTest(jtu.JaxTestCase): @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_indexing(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) @@ -195,7 +195,7 @@ class MultiBackendTest(jtu.JaxTestCase): @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_sum(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 40c5b6b2e..5c84f8c69 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -46,7 +46,7 @@ jax.config.parse_flags_with_absl() @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): - # TODO(phawkins): Enable after https://github.com/google/jax/issues/11222 + # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222 # is fixed. @unittest.SkipTest def testInitializeAndShutdown(self): diff --git a/tests/nn_test.py b/tests/nn_test.py index be07de184..416beffce 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -325,12 +325,12 @@ class NNFunctionsTest(jtu.JaxTestCase): self.assertEqual(out.dtype, dtype) def testEluMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom def testHardTanhMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom @@ -367,7 +367,7 @@ class NNFunctionsTest(jtu.JaxTestCase): @parameterized.parameters([nn.softmax, nn.log_softmax]) def testSoftmaxWhereGrad(self, fn): - # regression test for https://github.com/google/jax/issues/19490 + # regression test for https://github.com/jax-ml/jax/issues/19490 x = jnp.array([36., 10000.]) mask = x < 1000 @@ -443,7 +443,7 @@ class NNFunctionsTest(jtu.JaxTestCase): self.assertAllClose(actual, expected) def testOneHotConcretizationError(self): - # https://github.com/google/jax/issues/3654 + # https://github.com/jax-ml/jax/issues/3654 msg = r"in jax.nn.one_hot argument `num_classes`" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): jax.jit(nn.one_hot)(3, 5) @@ -463,7 +463,7 @@ class NNFunctionsTest(jtu.JaxTestCase): nn.tanh # doesn't crash def testCustomJVPLeak(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 @jax.jit def fwd(): a = jnp.array(1.) @@ -479,7 +479,7 @@ class NNFunctionsTest(jtu.JaxTestCase): fwd() # doesn't crash def testCustomJVPLeak2(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 # The above test uses jax.nn.sigmoid, as in the original #8171, but that # function no longer actually has a custom_jvp! So we inline the old def. diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index e8cd40a67..f5dcff837 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -20,7 +20,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb index 8352bdaf7..2335455e6 100644 --- a/tests/notebooks/colab_gpu.ipynb +++ b/tests/notebooks/colab_gpu.ipynb @@ -7,7 +7,7 @@ "id": "view-in-github" }, "source": [ - "\"Open" + "\"Open" ] }, { diff --git a/tests/ode_test.py b/tests/ode_test.py index 834745e1c..acdfa1fc6 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -139,7 +139,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_odeint_vmap_grad(self): - # https://github.com/google/jax/issues/2531 + # https://github.com/jax-ml/jax/issues/2531 def dx_dt(x, *args): return 0.1 * x @@ -169,7 +169,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_disable_jit_odeint_with_vmap(self): - # https://github.com/google/jax/issues/2598 + # https://github.com/jax-ml/jax/issues/2598 with jax.disable_jit(): t = jnp.array([0.0, 1.0]) x0_eval = jnp.zeros((5, 2)) @@ -178,7 +178,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure(self): - # simplification of https://github.com/google/jax/issues/2718 + # simplification of https://github.com/jax-ml/jax/issues/2718 def experiment(x): def model(y, t): return -x * y @@ -188,7 +188,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure_with_vmap(self): - # https://github.com/google/jax/issues/2718 + # https://github.com/jax-ml/jax/issues/2718 @jax.jit def experiment(x): def model(y, t): @@ -209,7 +209,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_forward_mode_error(self): - # https://github.com/google/jax/issues/3558 + # https://github.com/jax-ml/jax/issues/3558 def f(k): return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum() @@ -219,7 +219,7 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_closure_nondiff(self): - # https://github.com/google/jax/issues/3584 + # https://github.com/jax-ml/jax/issues/3584 def dz_dt(z, t): return jnp.stack([z[0], z[1]]) @@ -232,8 +232,8 @@ class ODETest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu", "gpu") def test_complex_odeint(self): - # https://github.com/google/jax/issues/3986 - # https://github.com/google/jax/issues/8757 + # https://github.com/jax-ml/jax/issues/3986 + # https://github.com/jax-ml/jax/issues/8757 def dy_dt(y, t, alpha): return alpha * y * jnp.exp(-t).astype(y.dtype) diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index b7710d9b9..c4eca0707 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -260,7 +260,7 @@ class OptimizerTests(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def testIssue758(self): - # code from https://github.com/google/jax/issues/758 + # code from https://github.com/jax-ml/jax/issues/758 # this is more of a scan + jacfwd/jacrev test, but it lives here to use the # optimizers.py code diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index e2f0e2152..7692294cd 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -178,7 +178,7 @@ class FusedAttentionTest(PallasBaseTest): (1, 384, 8, 64, True, True, True, {}), (1, 384, 8, 64, True, True, False, {}), (2, 384, 8, 64, True, True, True, {}), - # regression test: https://github.com/google/jax/pull/17314 + # regression test: https://github.com/jax-ml/jax/pull/17314 (1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}), ] ] @@ -419,7 +419,7 @@ class SoftmaxTest(PallasBaseTest): }[dtype] # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. np.testing.assert_allclose( softmax.softmax(x, axis=-1).astype(jnp.float32), jax.nn.softmax(x, axis=-1).astype(jnp.float32), diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 63c3148e8..d8f890c06 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -798,7 +798,7 @@ class OpsExtraTest(PallasBaseTest): np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) def test_abs_weak_type(self): - # see https://github.com/google/jax/issues/23191 + # see https://github.com/jax-ml/jax/issues/23191 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32), ) @@ -999,7 +999,7 @@ class OpsExtraTest(PallasBaseTest): x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. np.testing.assert_allclose( kernel(x).astype(jnp.float32), jnp.tanh(x).astype(jnp.float32), @@ -1260,7 +1260,7 @@ class OpsExtraTest(PallasBaseTest): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): - # Reproducer from https://github.com/google/jax/issues/20895. + # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 5ee30ba33..aec7fd54c 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -484,7 +484,7 @@ class PallasCallTest(PallasBaseTest): self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32)) def test_hoisted_consts(self): - # See https://github.com/google/jax/issues/21557. + # See https://github.com/jax-ml/jax/issues/21557. # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 4b3f47e6f..fefccfe7e 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -209,7 +209,7 @@ class PallasCallVmapTest(PallasBaseTest): @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - # Catches https://github.com/google/jax/issues/18361 + # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), grid=(2,)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 57106948f..c20084c3c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1510,7 +1510,7 @@ class CustomPartitionerTest(jtu.JaxTestCase): def test_custom_partitioner_with_scan(self): self.skip_if_custom_partitioning_not_supported() - # This is a reproducer from https://github.com/google/jax/issues/20864. + # This is a reproducer from https://github.com/jax-ml/jax/issues/20864. @custom_partitioning def f(x): @@ -1921,7 +1921,7 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertArraysEqual(s.data, input_data) def test_sds_full_like(self): - # https://github.com/google/jax/issues/20390 + # https://github.com/jax-ml/jax/issues/20390 mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s) @@ -4113,7 +4113,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_spmd_preserves_input_sharding_vmap_grad(self): if config.use_shardy_partitioner.value: self.skipTest("Shardy doesn't support PositionalSharding") - # https://github.com/google/jax/issues/20710 + # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() sharding = PositionalSharding(jax.devices()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 8b121d91a..d7dcc7ba3 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -122,7 +122,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testDeviceBufferToArray(self): sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2))) - # Changed in https://github.com/google/jax/pull/10584 not to access + # Changed in https://github.com/jax-ml/jax/pull/10584 not to access # sda.device_buffers, which isn't supported, and instead ensure fast slices # of the arrays returned by pmap are set up correctly. # buf = sda.device_buffers[-1] @@ -336,7 +336,7 @@ class PythonPmapTest(jtu.JaxTestCase): compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) def test_pmap_replicated_copy(self): - # https://github.com/google/jax/issues/17690 + # https://github.com/jax-ml/jax/issues/17690 inp = jnp.arange(jax.device_count()) x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp) out = jnp.copy(x) @@ -605,7 +605,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertAllClose(y, ref) def testNestedPmapAxisSwap(self): - # Regression test for https://github.com/google/jax/issues/5757 + # Regression test for https://github.com/jax-ml/jax/issues/5757 if jax.device_count() < 8: raise SkipTest("test requires at least 8 devices") f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0, @@ -1180,7 +1180,7 @@ class PythonPmapTest(jtu.JaxTestCase): "`perm` does not represent a permutation: \\[1.*\\]", g) def testPpermuteWithZipObject(self): - # https://github.com/google/jax/issues/1703 + # https://github.com/jax-ml/jax/issues/1703 num_devices = jax.device_count() perm = [num_devices - 1] + list(range(num_devices - 1)) f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i") @@ -1501,7 +1501,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertEqual(ans.shape, (13, N_DEVICES)) def testVmapOfPmap3(self): - # https://github.com/google/jax/issues/3399 + # https://github.com/jax-ml/jax/issues/3399 device_count = jax.device_count() if device_count < 2: raise SkipTest("test requires at least two devices") @@ -1661,7 +1661,7 @@ class PythonPmapTest(jtu.JaxTestCase): @ignore_jit_of_pmap_warning() def testIssue1065(self): - # from https://github.com/google/jax/issues/1065 + # from https://github.com/jax-ml/jax/issues/1065 device_count = jax.device_count() def multi_step_pmap(state, count): @@ -1697,7 +1697,7 @@ class PythonPmapTest(jtu.JaxTestCase): # replica. @unittest.skip("need eager multi-replica support") def testPostProcessMap(self): - # test came from https://github.com/google/jax/issues/1369 + # test came from https://github.com/jax-ml/jax/issues/1369 nrep = jax.device_count() def pmvm(a, b): @@ -1730,7 +1730,7 @@ class PythonPmapTest(jtu.JaxTestCase): @jax.default_matmul_precision("float32") def testPostProcessMap2(self): - # code from https://github.com/google/jax/issues/2787 + # code from https://github.com/jax-ml/jax/issues/2787 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1758,7 +1758,7 @@ class PythonPmapTest(jtu.JaxTestCase): ('_new', new_checkpoint), ]) def testAxisIndexRemat(self, remat): - # https://github.com/google/jax/issues/2716 + # https://github.com/jax-ml/jax/issues/2716 n = len(jax.devices()) def f(key): @@ -1769,7 +1769,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.pmap(remat(f), axis_name='i')(keys) def testPmapMapVmapCombinations(self): - # https://github.com/google/jax/issues/2822 + # https://github.com/jax-ml/jax/issues/2822 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1802,7 +1802,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3) def testPmapAxisNameError(self): - # https://github.com/google/jax/issues/3120 + # https://github.com/jax-ml/jax/issues/3120 a = np.arange(4)[np.newaxis,:] def test(x): return jax.lax.psum(x, axis_name='batch') @@ -1811,7 +1811,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.pmap(test)(a) def testPsumOnBooleanDtype(self): - # https://github.com/google/jax/issues/3123 + # https://github.com/jax-ml/jax/issues/3123 n = jax.device_count() if n > 1: x = jnp.array([True, False]) @@ -1889,7 +1889,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertIn("mhlo.num_partitions = 1", hlo) def testPsumZeroCotangents(self): - # https://github.com/google/jax/issues/3651 + # https://github.com/jax-ml/jax/issues/3651 def loss(params, meta_params): (net, mpo) = params return meta_params * mpo * net @@ -1914,7 +1914,7 @@ class PythonPmapTest(jtu.JaxTestCase): @ignore_jit_of_pmap_warning() def test_issue_1062(self): - # code from https://github.com/google/jax/issues/1062 @shoyer + # code from https://github.com/jax-ml/jax/issues/1062 @shoyer # this tests, among other things, whether ShardedDeviceTuple constants work device_count = jax.device_count() @@ -1938,7 +1938,7 @@ class PythonPmapTest(jtu.JaxTestCase): # TODO(skye): fix backend caching so we always have multiple CPUs available if jax.device_count("cpu") < 4: self.skipTest("test requires 4 CPU device") - # https://github.com/google/jax/issues/4223 + # https://github.com/jax-ml/jax/issues/4223 def fn(indices): return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32) mapped_fn = self.pmap(fn, axis_name='i', backend='cpu') @@ -1982,7 +1982,7 @@ class PythonPmapTest(jtu.JaxTestCase): for dtype in [np.float32, np.int32] ) def testPmapDtype(self, dtype): - # Regression test for https://github.com/google/jax/issues/6022 + # Regression test for https://github.com/jax-ml/jax/issues/6022 @partial(self.pmap, axis_name='i') def func(_): return jax.lax.psum(dtype(0), axis_name='i') @@ -1991,7 +1991,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertEqual(out_dtype, dtype) def test_num_replicas_with_switch(self): - # https://github.com/google/jax/issues/7411 + # https://github.com/jax-ml/jax/issues/7411 def identity(x): return x @@ -2154,7 +2154,7 @@ class PythonPmapTest(jtu.JaxTestCase): @jtu.run_on_devices("cpu") def test_pmap_stack_size(self): - # Regression test for https://github.com/google/jax/issues/20428 + # Regression test for https://github.com/jax-ml/jax/issues/20428 # pmap isn't particularly important here, but it guarantees that the CPU # client runs the computation on a threadpool rather than inline. if jax.device_count() < 2: @@ -2164,7 +2164,7 @@ class PythonPmapTest(jtu.JaxTestCase): y.block_until_ready() # doesn't crash def test_pmap_of_prng_key(self): - # Regression test for https://github.com/google/jax/issues/20392 + # Regression test for https://github.com/jax-ml/jax/issues/20392 keys = jax.random.split(jax.random.key(0), jax.device_count()) result1 = jax.pmap(jax.random.bits)(keys) with jtu.ignore_warning( diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 76854beae..d9887cf7b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1245,7 +1245,7 @@ class IOCallbackTest(jtu.JaxTestCase): np.testing.assert_array_equal(shard[0] + 1, shard[1]) def test_batching_with_side_effects(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050800195 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195 x_lst = [] def append_x(x): nonlocal x_lst @@ -1261,7 +1261,7 @@ class IOCallbackTest(jtu.JaxTestCase): self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False) def test_batching_with_side_effects_while_loop(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050921219 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219 x_lst = [] def append_x(x): nonlocal x_lst diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 8d00b5eed..2e0fc3223 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -109,7 +109,7 @@ class DLPackTest(jtu.JaxTestCase): self.assertAllClose(np, y.cpu().numpy()) def testTorchToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index f69687ddc..63510b729 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -178,7 +178,7 @@ class LaxRandomTest(jtu.JaxTestCase): def testNormalBfloat16(self): # Passing bfloat16 as dtype string. - # https://github.com/google/jax/issues/6813 + # https://github.com/jax-ml/jax/issues/6813 res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16') res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16) self.assertAllClose(res_bfloat16, res_bfloat16_str) @@ -391,7 +391,7 @@ class LaxRandomTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBetaSmallParameters(self, dtype=np.float32): - # Regression test for beta version of https://github.com/google/jax/issues/9896 + # Regression test for beta version of https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) a, b = 0.0001, 0.0002 samples = random.beta(key, a, b, shape=(100,), dtype=dtype) @@ -441,7 +441,7 @@ class LaxRandomTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") # lower accuracy leads to failures. def testDirichletSmallAlpha(self, dtype=np.float32): - # Regression test for https://github.com/google/jax/issues/9896 + # Regression test for https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) alpha = 0.00001 * jnp.ones(3) samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype) @@ -530,7 +530,7 @@ class LaxRandomTest(jtu.JaxTestCase): rtol=rtol) def testGammaGradType(self): - # Regression test for https://github.com/google/jax/issues/2130 + # Regression test for https://github.com/jax-ml/jax/issues/2130 key = self.make_key(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) @@ -663,7 +663,7 @@ class LaxRandomTest(jtu.JaxTestCase): ) def testGeneralizedNormalKS(self, p, shape, dtype): self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws - "sensitive to random key - https://github.com/google/jax/issues/18941") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18941") key = lambda: self.make_key(2) rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype) crand = jax.jit(rand) @@ -700,7 +700,7 @@ class LaxRandomTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBallKS(self, d, p, shape, dtype): self.skipTest( - "sensitive to random key - https://github.com/google/jax/issues/18932") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18932") key = lambda: self.make_key(123) rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype) crand = jax.jit(rand) @@ -800,7 +800,7 @@ class LaxRandomTest(jtu.JaxTestCase): assert samples.shape == shape + (dim,) def testMultivariateNormalCovariance(self): - # test code based on https://github.com/google/jax/issues/1869 + # test code based on https://github.com/jax-ml/jax/issues/1869 N = 100000 mean = jnp.zeros(4) cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00], @@ -827,7 +827,7 @@ class LaxRandomTest(jtu.JaxTestCase): @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) @jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators. def testMultivariateNormalSingularCovariance(self, method): - # Singular covariance matrix https://github.com/google/jax/discussions/13293 + # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) sigma = jnp.ones((2, 2)) key = self.make_key(0) @@ -889,7 +889,7 @@ class LaxRandomTest(jtu.JaxTestCase): def testRandomBroadcast(self): """Issue 4033""" - # test for broadcast issue in https://github.com/google/jax/issues/4033 + # test for broadcast issue in https://github.com/jax-ml/jax/issues/4033 key = lambda: self.make_key(0) shape = (10, 2) with jax.numpy_rank_promotion('allow'): @@ -1071,7 +1071,7 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertGreater((r == 255).sum(), 0) def test_large_prng(self): - # https://github.com/google/jax/issues/11010 + # https://github.com/jax-ml/jax/issues/11010 def f(): return random.uniform( self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) @@ -1086,7 +1086,7 @@ class LaxRandomTest(jtu.JaxTestCase): logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) def test_categorical_shape_argument(self, shape, logits_shape_base, axis): - # https://github.com/google/jax/issues/13124 + # https://github.com/jax-ml/jax/issues/13124 logits_shape = list(logits_shape_base) logits_shape.insert(axis % (len(logits_shape_base) + 1), 10) assert logits_shape[axis] == 10 diff --git a/tests/random_test.py b/tests/random_test.py index 941172f75..da182dbcc 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -436,7 +436,7 @@ class PrngTest(jtu.JaxTestCase): @skipIf(not config.threefry_partitionable.value, 'enable after upgrade') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): - # See https://github.com/google/jax/issues/7708 + # See https://github.com/jax-ml/jax/issues/7708 with jax.default_prng_impl('threefry2x32'): key = make_key(72) f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')), @@ -450,7 +450,7 @@ class PrngTest(jtu.JaxTestCase): @skipIf(config.threefry_partitionable.value, 'changed random bit values') def test_loggamma_nan_corner_case(self): - # regression test for https://github.com/google/jax/issues/17922 + # regression test for https://github.com/jax-ml/jax/issues/17922 # This particular key previously led to NaN output. # If the underlying implementation ever changes, this test will no longer # exercise this corner case, so we compare to a particular output value @@ -545,7 +545,7 @@ class PrngTest(jtu.JaxTestCase): @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_key_output_vjp(self, make_key): - # See https://github.com/google/jax/issues/14856 + # See https://github.com/jax-ml/jax/issues/14856 def f(seed): return make_key(seed) jax.vjp(f, 1) # doesn't crash @@ -578,7 +578,7 @@ class ThreefryPrngTest(jtu.JaxTestCase): partial(random.PRNGKey, impl='threefry2x32'), partial(random.key, impl='threefry2x32')]]) def test_seed_no_implicit_transfers(self, make_key): - # See https://github.com/google/jax/issues/15613 + # See https://github.com/jax-ml/jax/issues/15613 with jax.transfer_guard('disallow'): make_key(jax.device_put(42)) # doesn't crash @@ -922,14 +922,14 @@ class KeyArrayTest(jtu.JaxTestCase): self.assertEqual(ys.shape, (3, 2)) def test_select_scalar_cond(self): - # regression test for https://github.com/google/jax/issues/16422 + # regression test for https://github.com/jax-ml/jax/issues/16422 ks = self.make_keys(3) ys = lax.select(True, ks, ks) self.assertIsInstance(ys, prng_internal.PRNGKeyArray) self.assertEqual(ys.shape, (3,)) def test_vmap_of_cond(self): - # See https://github.com/google/jax/issues/15869 + # See https://github.com/jax-ml/jax/issues/15869 def f(x): keys = self.make_keys(*x.shape) return lax.select(x, keys, keys) @@ -1126,7 +1126,7 @@ class KeyArrayTest(jtu.JaxTestCase): self.assertEqual(repr(spec), f"PRNGSpec({name!r})") def test_keyarray_custom_vjp(self): - # Regression test for https://github.com/google/jax/issues/18442 + # Regression test for https://github.com/jax-ml/jax/issues/18442 @jax.custom_vjp def f(_, state): return state diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 701b7c570..dd34a99a7 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -129,7 +129,7 @@ class NdimageTest(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_op, lsp_op, args_maker) def testContinuousGradients(self): - # regression test for https://github.com/google/jax/issues/3024 + # regression test for https://github.com/jax-ml/jax/issues/3024 def loss(delta): x = np.arange(100.0) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index ad988bba6..983cb6bdc 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -314,7 +314,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): rtol={np.float32: 2e-3, np.float64: 1e-4}) def testBetaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7645 + # Regression test for https://github.com/jax-ml/jax/issues/7645 a = b = 1. x = np.array([0., 1.]) self.assertAllClose( @@ -539,7 +539,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7256 + # Regression test for https://github.com/jax-ml/jax/issues/7256 self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @@ -710,7 +710,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): - # Regression test for https://github.com/google/jax/issues/10219 + # Regression test for https://github.com/jax-ml/jax/issues/10219 self.assertAllClose( np.array([-100, -100], np.float32), lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), @@ -855,7 +855,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker) def testNormSfNearZero(self): - # Regression test for https://github.com/google/jax/issues/17199 + # Regression test for https://github.com/jax-ml/jax/issues/17199 value = np.array(10, np.float32) self.assertAllClose(osp_stats.norm.sf(value).astype('float32'), lsp_stats.norm.sf(value), @@ -1208,7 +1208,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) def testBinomPmfOutOfRange(self): - # Regression test for https://github.com/google/jax/issues/19150 + # Regression test for https://github.com/jax-ml/jax/issues/19150 self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0) def testBinomLogPmfZerokZeron(self): diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 27199c874..77e5273d1 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -2973,7 +2973,7 @@ _POLY_SHAPE_TEST_HARNESSES = [ (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) PolyHarness("reduce_window", "add_monoid_strides_window_size=static", diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index ae22eeca0..20bc33475 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -724,7 +724,7 @@ class ShardMapTest(jtu.JaxTestCase): self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) def test_nested_vmap_with_capture_spmd_axis_name(self): - self.skipTest('https://github.com/google/jax/issues/23476') + self.skipTest('https://github.com/jax-ml/jax/issues/23476') mesh = jtu.create_mesh((2, 2), ('x', 'y')) def to_map_with_capture(x, y): @@ -902,7 +902,7 @@ class ShardMapTest(jtu.JaxTestCase): @jax.legacy_prng_key('allow') def test_prngkeyarray_eager(self): - # https://github.com/google/jax/issues/15398 + # https://github.com/jax-ml/jax/issues/15398 mesh = jtu.create_mesh((4,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) @@ -1069,7 +1069,7 @@ class ShardMapTest(jtu.JaxTestCase): self.assertEqual(out, 1.) def test_jaxpr_shardings_with_no_outputs(self): - # https://github.com/google/jax/issues/15385 + # https://github.com/jax-ml/jax/issues/15385 mesh = jtu.create_mesh((4,), ('i',)) @jax.jit @@ -1109,7 +1109,7 @@ class ShardMapTest(jtu.JaxTestCase): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_key_array_with_replicated_last_tile_dim(self): - # See https://github.com/google/jax/issues/16137 + # See https://github.com/jax-ml/jax/issues/16137 mesh = jtu.create_mesh((2, 4), ('i', 'j')) @@ -1690,7 +1690,7 @@ class ShardMapTest(jtu.JaxTestCase): self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) def test_repeated_psum_allowed(self): - # https://github.com/google/jax/issues/19175 + # https://github.com/jax-ml/jax/issues/19175 mesh = jtu.create_mesh((4,), 'i') @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) @@ -1927,7 +1927,7 @@ class ShardMapTest(jtu.JaxTestCase): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_vmap_grad_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21032 + # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( @@ -1944,7 +1944,7 @@ class ShardMapTest(jtu.JaxTestCase): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21056 + # https://github.com/jax-ml/jax/pull/21056 mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) @@ -1962,7 +1962,7 @@ class ShardMapTest(jtu.JaxTestCase): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_grad_shmap_residuals_axis_names_in_mesh_order(self): - # https://github.com/google/jax/issues/21236 + # https://github.com/jax-ml/jax/issues/21236 mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) @partial( @@ -2108,7 +2108,7 @@ class ShardMapTest(jtu.JaxTestCase): )((object(), object()), x) def test_custom_linear_solve_rep_rules(self): - # https://github.com/google/jax/issues/20162 + # https://github.com/jax-ml/jax/issues/20162 mesh = jtu.create_mesh((1,), ('i',)) a = jnp.array(1).reshape(1, 1) b = jnp.array(1).reshape(1) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 545d73bff..38fde72f0 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -310,7 +310,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]])) def test_bcoo_extract_batching(self): - # https://github.com/google/jax/issues/9431 + # https://github.com/jax-ml/jax/issues/9431 indices = jnp.zeros((4, 1, 1), dtype=int) mat = jnp.arange(4.).reshape((4, 1)) @@ -353,7 +353,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertEqual(hess.shape, data.shape + 2 * M.shape) def test_bcoo_extract_zero_nse(self): - # Regression test for https://github.com/google/jax/issues/13653 + # Regression test for https://github.com/jax-ml/jax/issues/13653 # (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2 args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3)) @@ -974,7 +974,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertEqual(out.nse, expected_nse) def test_bcoo_spdot_general_ad_bug(self): - # Regression test for https://github.com/google/jax/issues/10163 + # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0]) A_shape = (2, 3) @@ -1287,7 +1287,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertEqual(y2.nse, x.nse) def test_bcoo_sum_duplicates_padding(self): - # Regression test for https://github.com/google/jax/issues/8163 + # Regression test for https://github.com/jax-ml/jax/issues/8163 size = 3 data = jnp.array([1, 0, 0]) indices = jnp.array([1, size, size])[:, None] @@ -1606,7 +1606,7 @@ class BCOOTest(sptu.SparseTestCase): self._CheckAgainstDense(operator.mul, operator.mul, args_maker, tol=tol) def test_bcoo_mul_sparse_with_duplicates(self): - # Regression test for https://github.com/google/jax/issues/8888 + # Regression test for https://github.com/jax-ml/jax/issues/8888 indices = jnp.array([[0, 1, 0, 0, 1, 1], [1, 0, 1, 2, 0, 2]]).T data = jnp.array([1, 2, 3, 4, 5, 6]) @@ -1940,7 +1940,7 @@ class BCSRTest(sptu.SparseTestCase): self._CheckGradsSparse(dense_func, sparse_func, args_maker) def test_bcoo_spdot_abstract_eval_bug(self): - # Regression test for https://github.com/google/jax/issues/21921 + # Regression test for https://github.com/jax-ml/jax/issues/21921 lhs = sparse.BCOO( (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), shape=(10, 10)) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 616396222..eb8d70be1 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -323,7 +323,7 @@ class cuSparseTest(sptu.SparseTestCase): self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL) def test_coo_matmat_layout(self): - # Regression test for https://github.com/google/jax/issues/7533 + # Regression test for https://github.com/jax-ml/jax/issues/7533 d = jnp.array([1.0, 2.0, 3.0, 4.0]) i = jnp.array([0, 0, 1, 2]) j = jnp.array([0, 2, 0, 0]) diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 46086511d..46c2f5aaf 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -610,7 +610,7 @@ class SparsifyTest(jtu.JaxTestCase): self.assertArraysEqual(jit(func)(Msp).todense(), expected) def testWeakTypes(self): - # Regression test for https://github.com/google/jax/issues/8267 + # Regression test for https://github.com/jax-ml/jax/issues/8267 M = jnp.arange(12, dtype='int32').reshape(3, 4) Msp = BCOO.fromdense(M) self.assertArraysEqual( diff --git a/tests/stax_test.py b/tests/stax_test.py index 6850f36a0..e21300ddd 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -216,7 +216,7 @@ class StaxTest(jtu.JaxTestCase): def testBatchNormShapeNCHW(self): key = random.PRNGKey(0) - # Regression test for https://github.com/google/jax/issues/461 + # Regression test for https://github.com/jax-ml/jax/issues/461 init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3)) input_shape = (4, 5, 6, 7) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index bc741702c..f8792a263 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -343,7 +343,7 @@ class TreeTest(jtu.JaxTestCase): self.assertEqual(h.args, (3,)) def testPartialFuncAttributeHasStableHash(self): - # https://github.com/google/jax/issues/9429 + # https://github.com/jax-ml/jax/issues/9429 fun = functools.partial(print, 1) p1 = tree_util.Partial(fun, 2) p2 = tree_util.Partial(fun, 2) @@ -359,7 +359,7 @@ class TreeTest(jtu.JaxTestCase): self.assertEqual([c0, c1], tree.children()) def testTreedefTupleFromChildren(self): - # https://github.com/google/jax/issues/7377 + # https://github.com/jax-ml/jax/issues/7377 tree = ((1, 2, (3, 4)), (5,)) leaves, treedef1 = tree_util.tree_flatten(tree) treedef2 = tree_util.treedef_tuple(treedef1.children()) @@ -368,7 +368,7 @@ class TreeTest(jtu.JaxTestCase): self.assertEqual(treedef1.num_nodes, treedef2.num_nodes) def testTreedefTupleComparesEqual(self): - # https://github.com/google/jax/issues/9066 + # https://github.com/jax-ml/jax/issues/9066 self.assertEqual(tree_util.tree_structure((3,)), tree_util.treedef_tuple((tree_util.tree_structure(3),))) @@ -978,7 +978,7 @@ class RavelUtilTest(jtu.JaxTestCase): self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testDtypePolymorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x = jnp.arange(10, dtype=jnp.float32) x_flat, unravel = flatten_util.ravel_pytree(x) y = x_flat < 5.3 @@ -987,7 +987,7 @@ class RavelUtilTest(jtu.JaxTestCase): @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testDtypeMonomorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x1 = jnp.arange(10, dtype=jnp.float32) x2 = jnp.arange(10, dtype=jnp.int32) x_flat, unravel = flatten_util.ravel_pytree((x1, x2)) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 58cf4a2ba..d4403b7e5 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -128,7 +128,7 @@ class X64ContextTests(jtu.JaxTestCase): @unittest.skip("test fails, see #8552") def test_convert_element_type(self): - # Regression test for part of https://github.com/google/jax/issues/5982 + # Regression test for part of https://github.com/jax-ml/jax/issues/5982 with enable_x64(): x = jnp.int64(1) self.assertEqual(x.dtype, jnp.int64)