mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
upgrade docs from jax.core
to jax.extend.core
where needed to fix doc build
This commit is contained in:
parent
f139192201
commit
6fe6d80506
@ -21,7 +21,7 @@ kernelspec:
|
||||
|
||||
A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).
|
||||
|
||||
For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below.
|
||||
For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below.
|
||||
|
||||
And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either:
|
||||
|
||||
@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in
|
||||
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.
|
||||
|
||||
```{code-cell}
|
||||
from jax import core
|
||||
from jax.extend import core
|
||||
|
||||
multiply_add_p = core.Primitive("multiply_add") # Create the primitive
|
||||
|
||||
|
@ -215,8 +215,8 @@
|
||||
"# Importing Jax functions useful for tracing/interpreting.\n",
|
||||
"from functools import wraps\n",
|
||||
"\n",
|
||||
"from jax import core\n",
|
||||
"from jax import lax\n",
|
||||
"from jax.extend import core\n",
|
||||
"from jax._src.util import safe_map"
|
||||
]
|
||||
},
|
||||
|
@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr.
|
||||
# Importing Jax functions useful for tracing/interpreting.
|
||||
from functools import wraps
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.extend import core
|
||||
from jax._src.util import safe_map
|
||||
```
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user