Update references to JAX's GitHub repo

JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
This commit is contained in:
Jake Harmon 2025-03-04 18:30:34 -08:00 committed by jax authors
parent 1a19d5594a
commit cdeeacabcf
4 changed files with 4 additions and 4 deletions

View File

@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly,
and elsewhere.
At the heart of the project is the [JAX
core](http://github.com/google/jax) library, which focuses on the
core](http://github.com/jax-ml/jax) library, which focuses on the
fundamentals of machine learning and numerical computing, at scale.
When [developing](#development) the core, we want to maintain agility

View File

@ -1255,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
return False

View File

@ -14,7 +14,7 @@
"""Colocated Python API."""
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
# pylint: disable=useless-import-alias
from jax.experimental.colocated_python.api import (

View File

@ -2120,7 +2120,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
def testIssue804(self):
# https://github.com/google/jax/issues/804
# https://github.com/jax-ml/jax/issues/804
num_devices = jax.device_count()
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash