Clarify that functions passed to jax.jit must be weakly referenceable.

This commit is contained in:
Peter Hawkins 2022-06-10 12:21:23 -07:00
parent 859883cfae
commit b32f83d84d

View File

@ -237,13 +237,19 @@ def jit(
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
fun: Function to be jitted. Should be a pure function, as side-effects may
only be executed once. Its arguments and return value should be arrays,
fun: Function to be jitted. ``fun`` should be a pure function, as
side-effects may only be executed once.
The arguments and return value of ``fun`` should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Positional arguments indicated by ``static_argnums`` can be anything at
all, provided they are hashable and have an equality operation defined.
Static arguments are included as part of a compilation cache key, which is
why hash and equality operators must be defined.
JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
objects will already satisfy this requirement.
static_argnums: An optional int or collection of ints that specify which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded in
@ -258,8 +264,9 @@ def jit(
If neither ``static_argnums`` nor ``static_argnames`` is provided, no
arguments are treated as static. If ``static_argnums`` is not provided but
``static_argnames`` is, or vice versa, JAX uses ``inspect.signature(fun)``
to find any positional arguments that correspond to ``static_argnames``
``static_argnames`` is, or vice versa, JAX uses
:code:`inspect.signature(fun)` to find any positional arguments that
correspond to ``static_argnames``
(or vice versa). If both ``static_argnums`` and ``static_argnames`` are
provided, ``inspect.signature`` is not used, and only actual
parameters listed in either ``static_argnums`` or ``static_argnames`` will