Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass

PiperOrigin-RevId: 691984161
This commit is contained in:
jax authors 2024-10-31 17:16:31 -07:00
commit 5a3ed6c792
2 changed files with 45 additions and 0 deletions

View File

@ -192,6 +192,8 @@ def g_inner_jitted(x, n):
g_inner_jitted(10, 20)
```
(jit-marking-arguments-as-static)=
## Marking arguments as static
If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`.

View File

@ -272,6 +272,49 @@ jax.tree.leaves([
Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way.
Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator:
```{code-cell}
from dataclasses import dataclass
import functools
@functools.partial(jax.tree_util.register_dataclass,
data_fields=['a', 'b', 'c'],
meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
name: str
a: Any
b: Any
c: Any
# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
```
Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args):
```{code-cell}
@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
return x.a + x.b
# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)
```
Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error:
```{code-cell}
:tags: [raises-exception]
moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
```
(pytree-and-jax-transformations)=
## Pytrees and JAX transformations