From cdeeacabcf2e32d4ced4cf278b26d08f4d342bfd Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Tue, 4 Mar 2025 18:30:34 -0800 Subject: [PATCH] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 733536104 --- docs/about.md | 2 +- jax/_src/custom_derivatives.py | 2 +- jax/experimental/colocated_python/__init__.py | 2 +- tests/lax_control_flow_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/about.md b/docs/about.md index c4bc93140..58e170384 100644 --- a/docs/about.md +++ b/docs/about.md @@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly, and elsewhere. At the heart of the project is the [JAX -core](http://github.com/google/jax) library, which focuses on the +core](http://github.com/jax-ml/jax) library, which focuses on the fundamentals of machine learning and numerical computing, at scale. When [developing](#development) the core, we want to maintain agility diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2bcdb7b5c..32856106a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1255,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]: def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. return False diff --git a/jax/experimental/colocated_python/__init__.py b/jax/experimental/colocated_python/__init__.py index 5bd56e732..2d387b37c 100644 --- a/jax/experimental/colocated_python/__init__.py +++ b/jax/experimental/colocated_python/__init__.py @@ -14,7 +14,7 @@ """Colocated Python API.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 # pylint: disable=useless-import-alias from jax.experimental.colocated_python.api import ( diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 4c61426c6..15fc37805 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2120,7 +2120,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash def testIssue804(self): - # https://github.com/google/jax/issues/804 + # https://github.com/jax-ml/jax/issues/804 num_devices = jax.device_count() f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.) jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash