mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
DOC: add FAQ entry on jit-compiling methods
This commit is contained in:
parent
452de3f465
commit
c4836aa507
205
docs/faq.rst
205
docs/faq.rst
@ -113,6 +113,211 @@ easier to write.
|
||||
If your functions are slow to compile for another reason, please open an issue
|
||||
on GitHub.
|
||||
|
||||
.. _faq-jit-class-methods:
|
||||
|
||||
How to use ``jit`` with methods?
|
||||
--------------------------------
|
||||
Most examples of :func:`jax.jit` concern decorating stand-alone Python functions,
|
||||
but decorating a method within a class introduces some complication. For example,
|
||||
consider the following simple class, where we've used a standard :func:`~jax.jit`
|
||||
annotation on a method::
|
||||
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax import jit
|
||||
|
||||
>>> class CustomClass:
|
||||
... def __init__(self, x: jnp.ndarray, mul: bool):
|
||||
... self.x = x
|
||||
... self.mul = mul
|
||||
...
|
||||
... @jit # <---- How to do this correctly?
|
||||
... def calc(self, y):
|
||||
... if self.mul:
|
||||
... return self.x * y
|
||||
... return y
|
||||
|
||||
However, this approach will result in an error when you attempt to call this method::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> c.calc(3) # doctest: +SKIP
|
||||
---------------------------------------------------------------------------
|
||||
TypeError Traceback (most recent call last)
|
||||
File "<stdin>", line 1, in <module
|
||||
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
|
||||
|
||||
The problem is that the first argument to the function is ``self``, which has type
|
||||
``CustomClass``, and JAX does not know how to handle this type.
|
||||
There are three basic strategies we might use in this case, and we'll discuss
|
||||
them below.
|
||||
|
||||
Strategy 1: JIT-compiled helper function
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
The most straightforward approach is to create a helper function external to the class
|
||||
that can be JIT-decorated in the normal way. For example::
|
||||
|
||||
>>> from functools import partial
|
||||
|
||||
>>> class CustomClass:
|
||||
... def __init__(self, x: jnp.ndarray, mul: bool):
|
||||
... self.x = x
|
||||
... self.mul = mul
|
||||
...
|
||||
... def calc(self, y):
|
||||
... return _calc(self.mul, self.x, y)
|
||||
|
||||
>>> @partial(jit, static_argnums=0)
|
||||
... def _calc(mul, x, y):
|
||||
... if mul:
|
||||
... return x * y
|
||||
... return y
|
||||
|
||||
The result will work as expected::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> print(c.calc(3))
|
||||
6
|
||||
|
||||
The benefit of such an approach is that it is simple, explicit, and it avoids the need
|
||||
to teach JAX how to handle objects of type ``CustomClass``. However, you may wish to
|
||||
keep all the method logic in the same place.
|
||||
|
||||
Strategy 2: Marking ``self`` as static
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Another common pattern is to use ``static_argnums`` to mark the ``self`` argument as static.
|
||||
But this must be done with care to avoid unexpected results.
|
||||
You may be tempted to simply do this::
|
||||
|
||||
>>> class CustomClass:
|
||||
... def __init__(self, x: jnp.ndarray, mul: bool):
|
||||
... self.x = x
|
||||
... self.mul = mul
|
||||
...
|
||||
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
|
||||
... @partial(jit, static_argnums=0)
|
||||
... def calc(self, y):
|
||||
... if self.mul:
|
||||
... return self.x * y
|
||||
... return y
|
||||
|
||||
If you call the method, it will no longer raise an error::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> print(c.calc(3))
|
||||
6
|
||||
|
||||
However, there is a catch: if you mutate the object after the first method call, the
|
||||
subsequent method call may return an incorrect result::
|
||||
|
||||
>>> c.mul = False
|
||||
>>> print(c.calc(3)) # Should print 3
|
||||
6
|
||||
|
||||
What's happening here? The issue is that ``static_argnums`` relies on the hash of the object
|
||||
to determine whether it has changed between calls, and the default ``__hash__`` method
|
||||
for a user-defined class will not take into account the values of class attributes. That means
|
||||
that on the second function call, JAX has no way of knowing that the class attribues have
|
||||
changed, and uses the cached static value from the previous compilation.
|
||||
|
||||
For this reason, if you are marking ``self`` arguments as static, it is important that you
|
||||
define an appropriate ``__hash__`` method for your class.
|
||||
For example, you might proceed like this::
|
||||
|
||||
>>> class CustomClass:
|
||||
... def __init__(self, x: jnp.ndarray, mul: bool):
|
||||
... self.x = x
|
||||
... self.mul = mul
|
||||
...
|
||||
... @partial(jit, static_argnums=0)
|
||||
... def calc(self, y):
|
||||
... if self.mul:
|
||||
... return self.x * y
|
||||
... return y
|
||||
...
|
||||
... def __hash__(self):
|
||||
... return hash((self.x, self.mul))
|
||||
...
|
||||
... def __eq__(self, other):
|
||||
... return (isinstance(other, CustomClass) and
|
||||
... (self.x, self.mul) == (other.x, other.mul))
|
||||
|
||||
Note that we've defined the ``__hash__`` method so that it depends on the hash of
|
||||
relevant class attributes, and we've also defined the ``__eq__`` method because it's
|
||||
good practice to do so any time you override ``__hash__`` (see
|
||||
`Python Data Model: __hash__ <https://docs.python.org/3/reference/datamodel.html#object.__hash__>`_
|
||||
for more information on this). With this addition, the example works correctly::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> print(c.calc(3))
|
||||
6
|
||||
>>> c.mul = False
|
||||
>>> print(c.calc(3))
|
||||
3
|
||||
|
||||
A downside of marking ``self`` as static is that it does not allow ``self`` to contain
|
||||
array-like attributes, since arrays are not hashable. For example, this will break because
|
||||
JAX arrays are not hashable::
|
||||
|
||||
>>> c = CustomClass(jnp.array(2), True)
|
||||
>>> c.calc(3) # doctest: +SKIP
|
||||
---------------------------------------------------------------------------
|
||||
ValueError Traceback (most recent call last)
|
||||
File "<stdin>", line 1, in <module
|
||||
ValueError: Non-hashable static arguments are not supported. An error occured during a call to 'calc' while trying to hash an object of type <class '__main__.CustomClass'>
|
||||
|
||||
Additionally, this also has the downside that ``calc`` will be re-compiled any time the values
|
||||
within ``myfunc`` change, which could be costly depending on your program.
|
||||
|
||||
Strategy 3: Making ``CustomClass`` a PyTree
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
The most flexible approach to correctly JIT-compiling a class method is to register the
|
||||
type as a custom PyTree object; see :ref:`extending-pytrees`. This lets you specify
|
||||
exactly which components of the class should be treated as static and which should be
|
||||
treated as dynamic. Here's how it might look::
|
||||
|
||||
>>> class CustomClass:
|
||||
... def __init__(self, x: jnp.ndarray, mul: bool):
|
||||
... self.x = x
|
||||
... self.mul = mul
|
||||
...
|
||||
... @jit
|
||||
... def calc(self, y):
|
||||
... if self.mul:
|
||||
... return self.x * y
|
||||
... return y
|
||||
...
|
||||
... def _tree_flatten(self):
|
||||
... children = (self.x,) # arrays / dynamic values
|
||||
... aux_data = {'mul': self.mul} # static values
|
||||
... return (children, aux_data)
|
||||
...
|
||||
... @classmethod
|
||||
... def _tree_unflatten(cls, aux_data, children):
|
||||
... return cls(*children, **aux_data)
|
||||
|
||||
>>> from jax import tree_util
|
||||
>>> tree_util.register_pytree_node(CustomClass,
|
||||
... CustomClass._tree_flatten,
|
||||
... CustomClass._tree_unflatten)
|
||||
|
||||
This is certainly more involved, but it solves all the issues associated with the simpler
|
||||
apporaches used above::
|
||||
|
||||
>>> c = CustomClass(2, True)
|
||||
>>> print(c.calc(3))
|
||||
6
|
||||
|
||||
>>> c.mul = False # mutation is detected
|
||||
>>> print(c.calc(3))
|
||||
3
|
||||
|
||||
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
|
||||
>>> print(c.calc(3))
|
||||
6
|
||||
|
||||
So long as your ``tree_flatten`` and ``tree_unflatten`` functions correctly handle all
|
||||
relevant attributes in the class, you should be able to use objects of this type directly
|
||||
as arguments to JIT-compiled functions, without any special annotations.
|
||||
|
||||
.. _faq-data-placement:
|
||||
|
||||
Controlling data and computation placement on devices
|
||||
|
Loading…
x
Reference in New Issue
Block a user