981 Commits

Author SHA1 Message Date
jax authors
d3850e7fdd Support optimization_level and memory_fitting_level XLA compilation options.
PiperOrigin-RevId: 727070422
2025-02-14 14:46:11 -08:00
George Necula
000b92f539 [better_errors] Continue adding debug info to Jaxprs (step 5)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
2025-02-07 08:23:10 +02:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078 , we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
George Necula
e4d5427d13 [better_errors] Add more debug info test coverage
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info,
marked with TODO. Fixes will come separately.

Had to expand the helper functions _check_tracers_and_jaxprs to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
2025-01-26 08:12:29 +02:00
Dan Foreman-Mackey
28d573354b Add DCE rules for custom_jvp and custom_vjp. 2025-01-23 15:22:43 -05:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
2025-01-20 20:21:01 +01:00
Zachary Garrett
f7d097f7cc Make utils for reporting function name work with functools.partial by using the inner .func attribute if the object doesn't have a __name__ attribute. functools.partial objects do not have __name__ attributes by default.
PiperOrigin-RevId: 715881812
2025-01-15 11:40:59 -08:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
Matthew Johnson
9acd4a95b6 improve checkpoint / remat concreteness error with static_argnums 2024-12-18 04:24:54 +00:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Dan Foreman-Mackey
092d2a0db5 Add error message when using custom_vmap with reverse-mode AD, and add docstrings.
The `custom_vmap` API is discussed in https://github.com/jax-ml/jax/issues/9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs.

One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem.

PiperOrigin-RevId: 704353963
2024-12-09 11:17:44 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
Dougal
b1d1dcf607 Add linearization rule for pjit_p 2024-11-22 14:24:46 -08:00
Dougal
170718c8d4 Change signature of linearization rules.
Give the rule the nonzero tangent pattern up-front. This is needed to make a
linearization rule for pjit_p. Also make the rules return the nonzero tangents
out, an explicit residual, and a closed tangent function. Add a rule for sin_p
to test it out. We still need to figure out how to avoid having to precompute
`cos(x)`. I think we need to update our backward pass code.
2024-11-21 19:03:42 -08:00
Dougal
d0f17c0c04 Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace then we
get staged linearization. If we make it the same as the primal trace then we get
primal and tangent computations occurring in step (JVP). This is a neat trick
enabled by stackless which now lives up to its name. With two parent traces we
have a tree of traces not a linked list stack.

Primitive ops can have their own linearization rules but as a fallback we can
derive a linearization rule for a single op using jvp/partial-eval.

For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state which are awkward to express via partial eval.
2024-11-20 10:03:00 -08:00
Peter Hawkins
f32505169f Filter custom dtypes by supported_dtypes in _LazyDtypes.
The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not.

Change in preparation for landing https://github.com/jax-ml/jax/pull/23585 without breaking existing tests.

PiperOrigin-RevId: 697752034
2024-11-18 14:07:52 -08:00
jax authors
cea8176756 Merge pull request #24751 from Stella-S-Yan:feature/default_device_str
PiperOrigin-RevId: 696560063
2024-11-14 10:00:18 -08:00
Stella-S-Yan
afa518aa0e Allow setting default_device with platform names. 2024-11-11 22:46:57 +00:00
Matthew Johnson
0f3ba4250d support exec_time_optimization_effort and memory_fitting_effort xla compilation
options

PiperOrigin-RevId: 692322944
2024-11-01 16:25:50 -07:00
Yash Katariya
fff33f90b2 Add compiler_options argument to jax.jit.
This exists on `Compiled` object via AOT too i.e. `jit(f).lower(*args).compile(compiler_options={})`

PiperOrigin-RevId: 692283964
2024-11-01 14:01:19 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
George Necula
5ccfc8d716 Reverts c3b4b76080dbedfebfed978c812338e2f680ee23
PiperOrigin-RevId: 690990311
2024-10-29 06:07:15 -07:00
Matthew Johnson
86a47a7d4e fix jax.custom_gradient to allow closing over non-autodiff tracers 2024-10-29 00:32:01 +00:00
jax authors
47bacfab5e Merge pull request #24031 from garymm:garymm/vmap-error-msg
PiperOrigin-RevId: 689940504
2024-10-25 15:59:57 -07:00
Gary Miguel
9f7f08eccb Fix vmap error message when args passed by keyword
See the new test for a case that used to produce the wrong message.

Fixes: #24406
2024-10-25 15:17:03 -07:00
Matthew Johnson
4231128535 improve concreteness error message in remat 2024-10-24 18:13:42 +00:00
Yash Katariya
1efca33187 Add donate and may_alias as an argument to device_put to allow for donation and aliasing.
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.

**Definition:**

* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.

* may_alias: If True, we may return the original buffer depending on the implementation.

**What problem are we solving?**

Eventually, we want `device_put` to always copy so introducing `may_alias` as a transition state to help towards that goal. We might end up deciding to keep `may_alias` but now you have an explicit option to **always copy** i.e. set `may_alias=False` which is what some users want.

Adding `donate` allows users to avoid this pattern of code:

```
inp = ...
out = device_put(inp, sharding)
jax.block_until_ready(out)
jax.tree.map(lambda x: x.delete(), inp)
```

Now it can just be: `jax.device_put(inp, sharding, donate=True)`

**So what are the semantics of these 2 options?** Let's create a table:

| may-alias \= None (default) | donate \= False (default) | Result |
| :---- | :---- | :---- |
| True | True | Error |
| True | False | May return the original buffer. Input Array marked as deleted: No. Reuses input buffer for output: Maybe |
| False | True | Original buffer deleted i.e. Donation. Input Array marked as deleted: Yes. Reuses input buffer for output: Maybe |
| False | False | Pure copy. Input Array marked as deleted: No. Reuses input buffer for output: No |
| None | True | `may_alias` will be marked as False. See Row 2 i.e. may\_alias \= False, donate \= True |
| None | False | `may_alias` will be marked as True. See Row 1 i.e. may\_alias \= True, donate \= False |

`donate` is best effort for now until we fix the following things:

 * Delete input when `donate=True` regardless of whether XLA could donate or not. This will affect `jax.jit` too but it's a good thing to do.

 * Plumb donate to PJRT/IFRT APIs so we can donate where transfers are not happening via `jit`.

PiperOrigin-RevId: 681073828
2024-10-01 10:28:23 -07:00
jax authors
df042fded2 Merge pull request #23870 from Zantares:tenglu/flush_output
PiperOrigin-RevId: 679639244
2024-09-27 10:21:25 -07:00
Lu Teng
a31f79ce0b Flush stdout buffer before checking. 2024-09-25 10:30:42 +08:00
Yash Katariya
1fe0c5dad5 Fix printing of saved_residual for jit by looking for pjit as the primitive instead of xla_call which was removed 2 years ago
PiperOrigin-RevId: 678479141
2024-09-24 19:01:19 -07:00
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter 2024-09-24 05:22:18 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
3b89a2e573 Add a utility function to create a tangent zero value from a primal value.
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Dan Foreman-Mackey
dbb34f56dd Raise a clearer error message when closure_converted function is
called with inputs with the wrong structure.

Fixes https://github.com/google/jax/issues/23588
2024-09-17 15:04:09 -04:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Yash Katariya
dd6f0e2e2e Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss.

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
2024-08-29 08:35:42 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Dan Foreman-Mackey
850edee36e Fix bug in custom_vjp with optimize_remat and custom_vmap.
When used with a `custom_vmap` that introduces a new const the previous
implementation of `optimize_remat` would error in its DCE rule because
of unexpected consts when closing the fwd jaxpr. This shouldn't have
ever been hit, but there was a bug in the batching rule for
`remat_opt_p` where we weren't properly converting constvars to invars.
This fixes this bug and should unbreak internal users.
2024-08-13 09:06:57 +01:00
Dan Foreman-Mackey
69fc8bb419 Consolidate handling of input argument resolution in custom_* APIs.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.

Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.

The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.

In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.

PiperOrigin-RevId: 662402158
2024-08-13 00:30:23 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Dan Foreman-Mackey
efb7721671 Remove unnecessary constraint on keyword-only arguments in custom_vjp with optimize_remat=True.
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
0309adf2a5 Merge pull request #22937 from dfm:custom-vmap-errors
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
Matthew Johnson
44ae9b30ec fix #22944 2024-08-08 16:19:19 +00:00