mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 733536104
This commit is contained in:
parent
1a19d5594a
commit
cdeeacabcf
@ -10,7 +10,7 @@ DeepMind](https://deepmind.google/), Alphabet more broadly,
|
||||
and elsewhere.
|
||||
|
||||
At the heart of the project is the [JAX
|
||||
core](http://github.com/google/jax) library, which focuses on the
|
||||
core](http://github.com/jax-ml/jax) library, which focuses on the
|
||||
fundamentals of machine learning and numerical computing, at scale.
|
||||
|
||||
When [developing](#development) the core, we want to maintain agility
|
||||
|
@ -1255,7 +1255,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
|
||||
def _maybe_perturbed(x: Any) -> bool:
|
||||
# False if x can't represent an AD-perturbed value (i.e. a value
|
||||
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
|
||||
# See https://github.com/google/jax/issues/6415 for motivation.
|
||||
# See https://github.com/jax-ml/jax/issues/6415 for motivation.
|
||||
if not isinstance(x, core.Tracer):
|
||||
# If x is not a Tracer, it can't be perturbed.
|
||||
return False
|
||||
|
@ -14,7 +14,7 @@
|
||||
"""Colocated Python API."""
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
# pylint: disable=useless-import-alias
|
||||
from jax.experimental.colocated_python.api import (
|
||||
|
@ -2120,7 +2120,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
|
||||
|
||||
def testIssue804(self):
|
||||
# https://github.com/google/jax/issues/804
|
||||
# https://github.com/jax-ml/jax/issues/804
|
||||
num_devices = jax.device_count()
|
||||
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
|
||||
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash
|
||||
|
Loading…
x
Reference in New Issue
Block a user