78 Commits

Author SHA1 Message Date
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
3b89a2e573 Add a utility function to create a tangent zero value from a primal value.
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dan Foreman-Mackey
30d5a78b1c Add optional automatic remat optimization to custom_vjp.
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.

In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.

This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:

1. This feature is currently implemented in "initial style", which means
   that the `fwd` function is traced to a jaxpr when it is initially
   called. This means that when `optimize_remat=True`, the `custom_vjp`
   function doesn't support data dependent conditionals within `fwd`.
   This isn't a fundamental limitation of the method, but this
   implementation is much simpler so it seemed like a good place to
   start, and much of the complexity of the "final style" version of
   this logic should be simplified by work that @dougalm is doing.
   Furthermore, for the immediate use case of opaque calls, initial
   style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
   this isn't a required restriction, but I chose to start without this
   added complexity and we can add support for symbolic zeros as needed
   in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
   currently implement rules for composing with the AD system. This
   means that a `custom_vjp` constructed with `optimize_remat=True`
   won't currently work with some approaches to higher-order AD. I
   expect I know how to fix that and will either include that here or in
   a follow-up.
2024-07-31 10:48:29 -04:00
Roy Frostig
d51b8e6839 custom_vjp symbolic zeros support, take two
This change re-introduces symbolic zero support for `custom_vjp`.

This time:

* The forward rule API is slightly different, accepting two-field
  records at pytree leaves rather than pairs.

* In the default setting where symbolic_zeros is not set, there are no
  new requirements from pytree node definitions that are involved in
  the primal arguments. This avoids any change in behavior on the
  default path. In particular, custom pytree node definitions that
  aren't completely polymorphic in unflattening can remain as is.

* There is an additional test involving a custom pytree node.
2023-04-05 11:17:05 -07:00
Matthew Johnson
5c4525cb10 custom_jvp symbolic zeros support
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-02-24 07:33:49 -08:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Peter Hawkins
da1b819f26 Move contents of jax.custom_derivatives to jax._src.custom_derivatives.
PiperOrigin-RevId: 369340983
2021-04-19 17:51:49 -07:00
Neil Girdhar
ba2a7920d9 Annotate custom_vjp and custom_jvp 2021-04-13 16:53:51 -04:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Jake VanderPlas
8e789c7380 Run doctest on all source files except jax2tf 2021-04-05 10:39:59 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
4253c929ec use correct batching util in custom_vjp_call_jaxpr
fixes #5832
2021-03-28 11:17:41 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Lena Martens
276645dd59 Fix error when running a jvp of a jit of a custom_vjp.
Error before:
NotImplementedError: XLA translation rule for primitive 'custom_lin' not found

Error after:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
2021-03-18 20:14:43 +00:00
Roy Frostig
049b83a4ca document what closure_convert returns 2021-03-02 09:18:46 -08:00
Roy Frostig
912cc87a3d introduce linear_call for custom transposition.
This adds a primitive with a corresponding traceable function in
`custom_derivatives` that takes a callee and its transpose, both
functions. When the primitive is encountered during transposition, the
given transpose function is invoked instead of transpose-transforming
the callee. The invocation of the custom transposition is itself done
via a `linear_call`, with the original callee set as the transpose.
This maintains, in particular, that transposing twice is an identity.
2021-02-26 10:46:54 -08:00
George Necula
617d77e037 Improve error message for when backward function in custom_vjp does not return
a tuple.

Prior to this we got an assertion that `py_cts_in is not iterable`.
2021-01-31 22:16:36 +02:00
Sharad Vikram
6061b0979a Allow jax.custom_gradient to return vjp with singleton return value 2021-01-26 13:40:23 -08:00
Matthew Johnson
9787894d94 refactor batching transform logic, fix leak checks
See PR description in #5492 for details.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2021-01-22 20:17:03 -08:00
Matthew Johnson
203af4517b revive the leak checker, as a debug mode
Co-authored-by: James Bradbury <jekbradbury@google.com>
2021-01-22 18:31:00 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Roy Frostig
0ad5d2f0c2 update changelog with closure_convert 2020-12-30 17:24:20 -08:00
Roy Frostig
2271651bc1 add closure_convert to api, write a docstring for it 2020-12-30 11:03:41 -08:00
Roy Frostig
85b0fa700d factor convert_closure from ode to custom_derivatives. 2020-12-30 11:03:41 -08:00
8bitmp3
b0b447646d
Nit: fix typo in custom_derivatives.py ("separate") 2020-12-29 16:35:19 +00:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
Jamie Townsend
931b2ddbcb Improve docs for custom_jvp and custom_vjp
Correct the custom_jvp docstring to include the defjvps instance method. Add the
defjvp/defvjp instance methods to the sphinx doc.
2020-11-06 11:40:16 +00:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
jax authors
d158647c83 Merge pull request #4706 from apaszke:vmap-collectives-in-scan
PiperOrigin-RevId: 339646941
2020-10-29 05:11:23 -07:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
Matthew Johnson
3c6cdcfc8f add jax.custom_gradient wrapper for jax.custom_vjp
There was a deprecatd version of this wrapper implemented in terms of
jax.custom_transforms (which itself is deprecated, and hopefully soon to
be removed), but this commit adds an implementation in terms of
jax.custom_vjp. One drawback it has relative to jax.custom_vjp is that
it doesn't support Python control flow in the backward-pass function.
2020-10-23 22:32:19 -07:00
Matthew Johnson
a46d0028cc fix a custom_jvp vmap bug from @dpfau 2020-10-20 21:08:59 -07:00
Matthew Johnson
79e3db5508 fixes on #4008 (thanks @apaszke) 2020-10-20 17:51:51 -07:00
jax authors
4a20eea828 Copybara import of the project:
--
609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689
2020-10-16 00:21:32 -07:00
Matthew Johnson
f3b4f43c20 temporarily work around a bug that #4008 will fix 2020-10-15 21:58:27 -07:00
Matthew Johnson
3a75145fd2 allow custom_vjp bwd to return Nones for zeros
This change sets up some internal users so that we can then land #4008.
2020-10-15 10:53:16 -07:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Lena Martens
e3d622cc67 Recast int/bool tangents to float0 in custom_jvp/vjps
(also in the initial_style path).
2020-10-08 17:37:35 +01:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00