DOC: add FAQ entry on jit-compiling methods

Jake VanderPlas 2022-05-06 12:39:12 -07:00
parent 452de3f465
commit c4836aa507

@@ -113,6 +113,211 @@ easier to write.
If your functions are slow to compile for another reason, please open an issue
on GitHub.
.. _faq-jit-class-methods:
How to use ``jit`` with methods?
--------------------------------
Most examples of :func:`jax.jit` concern decorating stand-alone Python functions,
but decorating a method within a class introduces some complication. For example,
consider the following simple class, where we've used a standard :func:`~jax.jit`
annotation on a method::

  >>> import jax.numpy as jnp
  >>> from jax import jit

  >>> class CustomClass:
  ...   def __init__(self, x: jnp.ndarray, mul: bool):
  ...     self.x = x
  ...     self.mul = mul
  ...
  ...   @jit  # <---- How to do this correctly?
  ...   def calc(self, y):
  ...     if self.mul:
  ...       return self.x * y
  ...     return y

However, this approach will result in an error when you attempt to call this method::

  >>> c = CustomClass(2, True)
  >>> c.calc(3)  # doctest: +SKIP
  ---------------------------------------------------------------------------
  TypeError                                 Traceback (most recent call last)
    File "<stdin>", line 1, in <module>
  TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

The problem is that the first argument to the function is ``self``, which has type
``CustomClass``, and JAX does not know how to handle this type.
There are three basic strategies we might use in this case, and we'll discuss
them below.
Strategy 1: JIT-compiled helper function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The most straightforward approach is to create a helper function external to the class
that can be JIT-decorated in the normal way. For example::

  >>> from functools import partial

  >>> class CustomClass:
  ...   def __init__(self, x: jnp.ndarray, mul: bool):
  ...     self.x = x
  ...     self.mul = mul
  ...
  ...   def calc(self, y):
  ...     return _calc(self.mul, self.x, y)

  >>> @partial(jit, static_argnums=0)
  ... def _calc(mul, x, y):
  ...   if mul:
  ...     return x * y
  ...   return y

The result will work as expected::

  >>> c = CustomClass(2, True)
  >>> print(c.calc(3))
  6

The benefit of such an approach is that it is simple, explicit, and it avoids the need
to teach JAX how to handle objects of type ``CustomClass``. However, you may wish to
keep all the method logic in the same place.
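If keeping the logic close to the class matters, one possible variant (a sketch, not something the text above prescribes) is to place the same helper inside the class as a static method. The ``_calc`` name mirrors the module-level helper above; making it a ``staticmethod`` is an assumption made here for illustration:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  def calc(self, y):
    # Only plain values are passed on; `self` never reaches the jitted function.
    return CustomClass._calc(self.mul, self.x, y)

  # Illustrative variant: the helper lives in the class namespace as a
  # static method instead of at module level.
  @staticmethod
  @partial(jit, static_argnums=0)
  def _calc(mul, x, y):
    if mul:
      return x * y
    return y


c = CustomClass(2, True)
result = c.calc(3)

# `mul` is passed by value as a static argument, so a changed value is
# seen on the next call.
c.mul = False
result_after_mutation = c.calc(3)
```

Because ``mul`` is a plain ``bool`` passed as a static argument, it is hashed by value, so this variant behaves like the module-level helper.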
Strategy 2: Marking ``self`` as static
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Another common pattern is to use ``static_argnums`` to mark the ``self`` argument as static.
But this must be done with care to avoid unexpected results.
You may be tempted to simply do this::

  >>> class CustomClass:
  ...   def __init__(self, x: jnp.ndarray, mul: bool):
  ...     self.x = x
  ...     self.mul = mul
  ...
  ...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
  ...   @partial(jit, static_argnums=0)
  ...   def calc(self, y):
  ...     if self.mul:
  ...       return self.x * y
  ...     return y

If you call the method, it will no longer raise an error::

  >>> c = CustomClass(2, True)
  >>> print(c.calc(3))
  6

However, there is a catch: if you mutate the object after the first method call, the
subsequent method call may return an incorrect result::

  >>> c.mul = False
  >>> print(c.calc(3))  # Should print 3
  6

What's happening here? The issue is that ``static_argnums`` relies on the hash of the object
to determine whether it has changed between calls, and the default ``__hash__`` method
for a user-defined class does not take into account the values of the object's attributes. That means
that on the second function call, JAX has no way of knowing that the attributes have
changed, and it reuses the cached static value from the previous compilation.
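The hashing behavior described above can be checked in plain Python, without JAX. The following is a minimal sketch; the ``Plain`` class name is made up for illustration:

```python
class Plain:
  """No __hash__ or __eq__ override, so both fall back to object identity."""

  def __init__(self, x, mul):
    self.x = x
    self.mul = mul


obj = Plain(2, True)
hash_before = hash(obj)
obj.mul = False  # mutate an attribute
hash_after = hash(obj)

# The default hash ignores attribute values, so the mutation is invisible
# to anything that keys a cache on it.
hash_unchanged = (hash_before == hash_after)

# Default equality is also identity-based: two objects with identical
# attributes still compare unequal.
same_values_equal = (Plain(2, True) == Plain(2, True))
```

Here ``hash_unchanged`` is ``True`` and ``same_values_equal`` is ``False``, which is why a cache keyed on the default hash cannot notice the change to ``mul``.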
For this reason, if you are marking ``self`` arguments as static, it is important that you
define an appropriate ``__hash__`` method for your class.
For example, you might proceed like this::

  >>> class CustomClass:
  ...   def __init__(self, x: jnp.ndarray, mul: bool):
  ...     self.x = x
  ...     self.mul = mul
  ...
  ...   @partial(jit, static_argnums=0)
  ...   def calc(self, y):
  ...     if self.mul:
  ...       return self.x * y
  ...     return y
  ...
  ...   def __hash__(self):
  ...     return hash((self.x, self.mul))
  ...
  ...   def __eq__(self, other):
  ...     return (isinstance(other, CustomClass) and
  ...             (self.x, self.mul) == (other.x, other.mul))

Note that we've defined the ``__hash__`` method so that it depends on the hash of
relevant class attributes, and we've also defined the ``__eq__`` method because it's
good practice to do so any time you override ``__hash__`` (see
`Python Data Model: __hash__ <https://docs.python.org/3/reference/datamodel.html#object.__hash__>`_
for more information on this). With this addition, the example works correctly::

  >>> c = CustomClass(2, True)
  >>> print(c.calc(3))
  6
  >>> c.mul = False
  >>> print(c.calc(3))
  3

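As an aside, hand-written ``__hash__`` and ``__eq__`` methods are not the only way to get value-based hashing. One alternative, sketched here as an assumption rather than something this FAQ entry recommends, is a frozen ``dataclass``, which generates both methods from the declared fields and forbids in-place mutation. The ``FrozenCustomClass`` name is made up for illustration:

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from functools import partial

from jax import jit


@dataclass(frozen=True)  # generates __eq__ and __hash__ from the fields
class FrozenCustomClass:
  x: float
  mul: bool

  @partial(jit, static_argnums=0)
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y


c = FrozenCustomClass(2.0, True)
out_mul = c.calc(3.0)

# In-place mutation is rejected, so a cached compilation cannot go stale.
try:
  c.mul = False
  mutation_blocked = False
except FrozenInstanceError:
  mutation_blocked = True

# Changes are expressed by building a new object instead.
c2 = replace(c, mul=False)
out_plain = c2.calc(3.0)
```

The trade-off is the same as for any static ``self``: every field must be hashable, so array-valued attributes are still ruled out.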
A downside of marking ``self`` as static is that ``self`` can no longer contain
array-valued attributes, because JAX arrays are not hashable. For example, this call
fails::

  >>> c = CustomClass(jnp.array(2), True)
  >>> c.calc(3)  # doctest: +SKIP
  ---------------------------------------------------------------------------
  ValueError                                Traceback (most recent call last)
    File "<stdin>", line 1, in <module>
  ValueError: Non-hashable static arguments are not supported. An error occured during a call to 'calc' while trying to hash an object of type <class '__main__.CustomClass'>

Additionally, this approach has the downside that ``calc`` will be re-compiled any time the
hashed attribute values of ``self`` change, which could be costly depending on your program.
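One way to observe this re-compilation is to record a Python side effect inside the method, since Python-level code in a jitted function runs only while JAX traces it. The sketch below assumes a hypothetical ``HashableClass`` (the hashable class from above under a new name) and a ``trace_log`` list introduced purely for counting:

```python
from functools import partial

from jax import jit

trace_log = []  # grows by one entry each time `calc` is traced


class HashableClass:
  def __init__(self, x, mul):
    self.x = x
    self.mul = mul

  @partial(jit, static_argnums=0)
  def calc(self, y):
    trace_log.append((self.x, self.mul))  # runs at trace time only
    if self.mul:
      return self.x * y
    return y

  def __hash__(self):
    return hash((self.x, self.mul))

  def __eq__(self, other):
    return (isinstance(other, HashableClass) and
            (self.x, self.mul) == (other.x, other.mul))


c = HashableClass(2, True)
c.calc(3)
c.calc(4)  # same static `self` and same argument type: cache hit
traces_before_mutation = len(trace_log)

c.x = 5    # changes the hash of `self`
c.calc(3)  # cache miss: the method is traced and compiled again
traces_after_mutation = len(trace_log)
```

Every distinct combination of hashed attribute values triggers its own trace and compilation, so frequently changing attributes are better treated as dynamic values.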
Strategy 3: Making ``CustomClass`` a PyTree
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The most flexible approach to correctly JIT-compiling a class method is to register the
type as a custom PyTree object; see :ref:`extending-pytrees`. This lets you specify
exactly which components of the class should be treated as static and which should be
treated as dynamic. Here's how it might look::

  >>> class CustomClass:
  ...   def __init__(self, x: jnp.ndarray, mul: bool):
  ...     self.x = x
  ...     self.mul = mul
  ...
  ...   @jit
  ...   def calc(self, y):
  ...     if self.mul:
  ...       return self.x * y
  ...     return y
  ...
  ...   def _tree_flatten(self):
  ...     children = (self.x,)  # arrays / dynamic values
  ...     aux_data = {'mul': self.mul}  # static values
  ...     return (children, aux_data)
  ...
  ...   @classmethod
  ...   def _tree_unflatten(cls, aux_data, children):
  ...     return cls(*children, **aux_data)

  >>> from jax import tree_util
  >>> tree_util.register_pytree_node(CustomClass,
  ...                                CustomClass._tree_flatten,
  ...                                CustomClass._tree_unflatten)

This is certainly more involved, but it solves all the issues associated with the simpler
approaches used above::

  >>> c = CustomClass(2, True)
  >>> print(c.calc(3))
  6

  >>> c.mul = False  # mutation is detected
  >>> print(c.calc(3))
  3

  >>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
  >>> print(c.calc(3))
  6

So long as your ``_tree_flatten`` and ``_tree_unflatten`` methods correctly handle all
relevant attributes in the class, you should be able to use objects of this type directly
as arguments to JIT-compiled functions, without any special annotations.
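As a sketch of what using such objects "directly as arguments" can look like, the example below passes a registered object to a free-standing jitted function. It uses the ``jax.tree_util.register_pytree_node_class`` decorator, which expects methods named ``tree_flatten`` and ``tree_unflatten``. The ``PyTreeClass`` and ``scale`` names are made up for illustration, and storing ``mul`` in a tuple rather than a dict is a choice made here so that the auxiliary data is itself hashable:

```python
import jax.numpy as jnp
from jax import jit, tree_util


@tree_util.register_pytree_node_class
class PyTreeClass:
  def __init__(self, x, mul):
    self.x = x
    self.mul = mul

  def tree_flatten(self):
    children = (self.x,)    # dynamic values, traced by JAX
    aux_data = (self.mul,)  # static values, part of the tree structure
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children, *aux_data)


@jit  # no static_argnums needed: `obj` is flattened into its leaves
def scale(obj, y):
  if obj.mul:  # `mul` is a plain Python bool here, not a tracer
    return obj.x * y
  return y


on = PyTreeClass(jnp.array(2.0), True)
off = PyTreeClass(jnp.array(2.0), False)
result_on = scale(on, 3.0)
result_off = scale(off, 3.0)

# Only the dynamic attribute shows up as a leaf of the tree.
leaves = tree_util.tree_leaves(on)
```

Because ``x`` is a leaf, changing its value does not trigger re-compilation; only a change to the static ``mul`` (or to the shape or dtype of ``x``) does.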
.. _faq-data-placement:
Controlling data and computation placement on devices