mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass
PiperOrigin-RevId: 691984161
This commit is contained in:
commit
5a3ed6c792
@ -192,6 +192,8 @@ def g_inner_jitted(x, n):
|
||||
g_inner_jitted(10, 20)
|
||||
```
|
||||
|
||||
(jit-marking-arguments-as-static)=
|
||||
|
||||
## Marking arguments as static
|
||||
|
||||
If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`.
|
||||
|
@ -272,6 +272,49 @@ jax.tree.leaves([
|
||||
|
||||
Notice that the `name` field now appears as a leaf, because all tuple elements are children. This is what happens when you don't have to register the class the hard way.
|
||||
|
||||
Unlike `NamedTuple` subclasses, classes decorated with `@dataclass` are not automatically pytrees. However, they can be registered as pytrees using the {func}`jax.tree_util.register_dataclass` decorator:
|
||||
|
||||
```{code-cell}
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
|
||||
@functools.partial(jax.tree_util.register_dataclass,
|
||||
data_fields=['a', 'b', 'c'],
|
||||
meta_fields=['name'])
|
||||
@dataclass
|
||||
class MyDataclassContainer(object):
|
||||
name: str
|
||||
a: Any
|
||||
b: Any
|
||||
c: Any
|
||||
|
||||
# MyDataclassContainer is now a pytree node.
|
||||
jax.tree.leaves([
|
||||
MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
|
||||
MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
|
||||
])
|
||||
```
|
||||
|
||||
Notice that the `name` field does not appear as a leaf. This is because we included it in the `meta_fields` argument to {func}`jax.tree_util.register_dataclass`, indicating that it should be treated as metadata/auxiliary data, just like `aux_data` in `RegisteredSpecial` above. Now instances of `MyDataclassContainer` can be passed into JIT-ed functions, and `name` will be treated as static (see {ref}`jit-marking-arguments-as-static` for more information on static args):
|
||||
|
||||
```{code-cell}
|
||||
@jax.jit
|
||||
def f(x: MyDataclassContainer | MyOtherContainer):
|
||||
return x.a + x.b
|
||||
|
||||
# Works fine! `mdc.name` is static.
|
||||
mdc = MyDataclassContainer('mdc', 1, 2, 3)
|
||||
y = f(mdc)
|
||||
```
|
||||
|
||||
Contrast this with `MyOtherContainer`, the `NamedTuple` subclass. Since the `name` field is a pytree leaf, JIT expects it to be convertible to {class}`jax.Array`, and the following raises an error:
|
||||
|
||||
```{code-cell}
|
||||
:tags: [raises-exception]
|
||||
|
||||
moc = MyOtherContainer('moc', 1, 2, 3)
|
||||
y = f(moc)
|
||||
```
|
||||
|
||||
(pytree-and-jax-transformations)=
|
||||
## Pytrees and JAX transformations
|
||||
|
Loading…
x
Reference in New Issue
Block a user