946 Commits

Author SHA1 Message Date
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter 2024-09-24 05:22:18 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
3b89a2e573 Add a utility function to create a tangent zero value from a primal value.
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Dan Foreman-Mackey
dbb34f56dd Raise a clearer error message when closure_converted function is
called with inputs with the wrong structure.

Fixes https://github.com/google/jax/issues/23588
2024-09-17 15:04:09 -04:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Yash Katariya
dd6f0e2e2e Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss.

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
2024-08-29 08:35:42 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Dan Foreman-Mackey
850edee36e Fix bug in custom_vjp with optimize_remat and custom_vmap.
When used with a `custom_vmap` that introduces a new const the previous
implementation of `optimize_remat` would error in its DCE rule because
of unexpected consts when closing the fwd jaxpr. This shouldn't have
ever been hit, but there was a bug in the batching rule for
`remat_opt_p` where we weren't properly converting constvars to invars.
This fixes this bug and should unbreak internal users.
2024-08-13 09:06:57 +01:00
Dan Foreman-Mackey
69fc8bb419 Consolidate handling of input argument resolution in custom_* APIs.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.

Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.

The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.

In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.

PiperOrigin-RevId: 662402158
2024-08-13 00:30:23 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Dan Foreman-Mackey
efb7721671 Remove unnecessary constraint on keyword-only arguments in custom_vjp with optimize_remat=True.
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
0309adf2a5 Merge pull request #22937 from dfm:custom-vmap-errors
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
Matthew Johnson
44ae9b30ec fix #22944 2024-08-08 16:19:19 +00:00
Dan Foreman-Mackey
595ca0affa Improve error message for missing vmap rule in custom_vmap.
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
2024-08-08 14:08:51 +01:00
Jake VanderPlas
551f72979c Rollback of #22869
This is causing breakages due to overly-restrictive checks on kwargs

Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be

PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Dan Foreman-Mackey
1d425b2d30 Small tweaks to custom_vmap UI.
I'm working on some extensions to `custom_vmap` and came across these
small UI improvements (I think!). This includes the two changes:

1. A weakening of the `kwargs` check to be consistent with the one in
   `custom_vjp`/`custom_jvp`, and
2. An improved error message when `def_vmap` isn't called.
2024-08-05 17:51:47 +01:00
Dan Foreman-Mackey
80cfe83ddc Fix issue with multiple arguments when using custom_vjp with optimize_remat.
This was a silly bug in how we were handling the fact that the `fwd`
function expects `bool` entries for symbolic zeros. At least now I've
added a test!
2024-08-01 08:44:35 -04:00
Dan Foreman-Mackey
30d5a78b1c Add optional automatic remat optimization to custom_vjp.
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.

In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.

This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:

1. This feature is currently implemented in "initial style", which means
   that the `fwd` function is traced to a jaxpr when it is initially
   called. This means that when `optimize_remat=True`, the `custom_vjp`
   function doesn't support data dependent conditionals within `fwd`.
   This isn't a fundamental limitation of the method, but this
   implementation is much simpler so it seemed like a good place to
   start, and much of the complexity of the "final style" version of
   this logic should be simplified by work that @dougalm is doing.
   Furthermore, for the immediate use case of opaque calls, initial
   style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
   this isn't a required restriction, but I chose to start without this
   added complexity and we can add support for symbolic zeros as needed
   in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
   currently implement rules for composing with the AD system. This
   means that a `custom_vjp` constructed with `optimize_remat=True`
   won't currently work with some approaches to higher-order AD. I
   expect I know how to fix that and will either include that here or in
   a follow-up.
2024-07-31 10:48:29 -04:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
f887b66d5d Remove the unaccelerate_deprecation utility 2024-07-23 05:07:49 -07:00
Robert Dyro
eb3f538c7e Correctly counting cache miss logs
PiperOrigin-RevId: 654860872
2024-07-22 12:53:09 -07:00
jax authors
c86b5f7281 Merge pull request #22244 from mattjj:remat-reduce-precision
PiperOrigin-RevId: 651504424
2024-07-11 12:53:58 -07:00
Matthew Johnson
f498d7a3ba make remat reduce precision of saved values to avoid xla excess precision
problem: f(x) != value_and_grad(f)(x)[0] ??

Co-authored-by: Peter Hawkins <phawkins@google.com>
2024-07-11 18:21:28 +00:00
Hyeontaek Lim
ad8e6713ba [JAX] Check if an array is deleted when resharding it with a different device order
`jax.device_put()` that changes only the device order did not have a check on
whether the input array has been deleted or donated. When this is the case, it
would generate an error "TypeError: 'NoneType' object is not iterable" when
attempting to access `array._arrays`". Calling an explicit check makes the
error message more understandable.

PiperOrigin-RevId: 651122614
2024-07-10 12:57:06 -07:00
Sergei Lebedev
ec7dd0fac1 `debug_info no longer requires non-None func_src_info`
I suspect in the past lack of source info meant that the function also has
no signature, but this is no longer the case.

I also removed an unused parameter from ``explain_tracing_cache_miss`` as
a drive by change.

This is a follow up to #22269.
2024-07-05 20:08:53 +01:00
jax authors
061ccd4e73 Merge pull request #22269 from superbobry:main
PiperOrigin-RevId: 649395181
2024-07-04 06:31:08 -07:00
Sergei Lebedev
ffa39c0858 Handle missing `debug_info in explain_tracing_cache_miss` 2024-07-04 14:07:10 +01:00
jax authors
a8f22f6e34 Merge pull request #19614 from cgarciae:batch_map
PiperOrigin-RevId: 649208823
2024-07-03 15:01:52 -07:00
Cristian Garcia
557b273707 support axes and batching in map 2024-07-03 17:46:10 +01:00
jax authors
d067255d83 Merge pull request #21766 from cgarciae:improve-vmap-outaxis-error
PiperOrigin-RevId: 649078839
2024-07-03 08:14:49 -07:00
Cristian Garcia
e45a95d96c Update jax/_src/api.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-07-02 16:10:36 +01:00
Cristian Garcia
a2524b1c8b fix from_elt cont handler 2024-07-01 21:55:54 +01:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Matthew Johnson
ad4e5ec43c fix error message path since #22069 allowed Literals in jaxpr builder eqns 2024-06-26 00:38:23 +00:00
Matthew Johnson
8564b55ee2 remove double dots from an error message 2024-06-26 00:04:42 +00:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
Jake VanderPlas
33465274da fix some additional warnings related to #21834 2024-06-13 16:06:14 -07:00
Jake VanderPlas
3f210c63a0 avoid globally silencing the jit backend/device warning 2024-06-12 14:43:14 -07:00
Yash Katariya
956226c929 Raise an error if device_put sees an invalid value.
PiperOrigin-RevId: 642053543
2024-06-10 16:07:44 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Matthew Johnson
2f4cc6e9cb [custom_vjp] allow bwd rule to produce top-level list (not just tuple) 2024-05-31 21:49:06 +00:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00