1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

upgrade docs from jax.core to jax.extend.core where needed to fix doc build

This commit is contained in:
Roy Frostig 2025-04-01 22:17:41 -07:00
parent f139192201
commit 6fe6d80506
3 changed files with 4 additions and 4 deletions

@ -21,7 +21,7 @@ kernelspec:
A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).
For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below.
For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below.
And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either:
@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.
```{code-cell}
from jax import core
from jax.extend import core
multiply_add_p = core.Primitive("multiply_add") # Create the primitive

@ -215,8 +215,8 @@
"# Importing Jax functions useful for tracing/interpreting.\n",
"from functools import wraps\n",
"\n",
"from jax import core\n",
"from jax import lax\n",
"from jax.extend import core\n",
"from jax._src.util import safe_map"
]
},

@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr.
# Importing Jax functions useful for tracing/interpreting.
from functools import wraps
from jax import core
from jax import lax
from jax.extend import core
from jax._src.util import safe_map
```