295 Commits

Author SHA1 Message Date
jax authors
45e6808bb5 Merge pull request #27084 from danielsuo:switch-fwd
PiperOrigin-RevId: 743452172
2025-04-03 01:07:50 -07:00
Daniel Suo
e364abe961 Prune passthrough outputs in lax.switch. 2025-03-26 18:53:14 +00:00
Matthew Johnson
a092df90ba fix a linearize-of-remat-of-while_loop-fixpoint bug
We were using the original unknown-carries-in rather than the fixpoint-updated ones.
2025-03-23 03:50:55 +00:00
Peter Hawkins
67aa997f84 Increase the number of iterations in a test that compares rolled versus unrolled HLO for length.
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.

PiperOrigin-RevId: 735875835
2025-03-11 13:45:19 -07:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
Peter Hawkins
33bbd5f119 Fix failures in TSAN free threading CI. 2025-02-26 06:04:26 -05:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
2025-01-20 20:21:01 +01:00
George Necula
3faff78ca8 [better_errors] Ensure that tracer errors in for_loop points to use code
Fixes: 23637
2025-01-13 15:33:30 +00:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
George Necula
c2adfbf1c2 [better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140

Previously, the following code:
```
def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change the error message is more specific where the
difference is in the pytree structure:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
2025-01-10 08:03:33 +02:00
Jake VanderPlas
74e9275bf2 Fix incorrect capitalization in scan error message 2024-12-16 11:37:31 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
IvyZX
bd77a703fd Avoid index out of range error in carry structure check 2024-12-09 10:44:28 -08:00
Benjamin Chetioui
15a11365e4 Change the lowering rule for jax.lax.scan to avoid emitting a while loop
when the intent is to fully unroll the loop.

PiperOrigin-RevId: 691393597
2024-10-30 06:20:39 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
bc415f9153 Relax test tolerances to fix CI failures on Mac ARM. 2024-09-03 09:45:28 -04:00
Roy Frostig
b3e3115391 improve scan error message on non-concrete unroll argument 2024-08-24 23:09:12 -07:00
Roy Frostig
a9b41e9fe7 improve scan error message on non-concrete length argument
Specifically, make it speak concretely about the `length` argument.
2024-08-24 22:30:33 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Matthew Johnson
bdcd358b65 improve while_loop carry pytree/type mismatch errors
Now we call into the same error utility as we use in scan.
2024-08-03 21:57:29 +00:00
Matthew Johnson
f72a3f8ef4 deflake cond memory leak regression test 2024-07-25 23:12:21 +00:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Matthew Johnson
8db862c02e fix memory leak in cond jaxpr tracig
fixes #12719
2024-07-23 23:57:02 +00:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
jax authors
fe3c8e15a8 Merge pull request #21806 from cgarciae:cond-passthrough-outputs
PiperOrigin-RevId: 646970169
2024-06-26 09:13:07 -07:00
Cristian Garcia
dae7e41ade fix cond passthrough outputs 2024-06-26 16:17:45 +01:00
George Necula
6e3fc9a768 Fix the eager mode execution for lax.platform_dependent
When we use lax.platform_dependent in eager mode, and some
of the branches contain custom calls that are not recognized on
some platforms, we must eagerly pick the required branch.
In jit mode, the constant folding that the XLA compiler already
does will eliminate the unnecessary branches.
2024-06-21 17:07:48 +03:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
Yash Katariya
44a13c9d4b Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.

Since we can keep the existing behavior and still merge the implementation is a good cleanup!

Fixes https://github.com/google/jax/issues/21116

PiperOrigin-RevId: 641347140
2024-06-07 13:48:31 -07:00
Matthew Johnson
7c125701c5 make cond forward inputs to outputs, reduces vmap lifting
Co-authored-by: Cristian Garcia <cgarciae@google.com>
2024-06-05 16:39:55 +00:00
Matthew Johnson
a24b73802f avoid singleton dim in scan lowering when unroll==1 2024-05-25 19:07:49 +00:00
Dan Foreman-Mackey
09a4b38ae2 Add informative error for invalid unroll in scan
As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an
uninformative `ZeroDivisionError`. This PR adds a check which raises a
`ValueError` for `unroll<=0`.
2024-05-15 15:40:27 -04:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
carlosgmartin
2b332de9d7 Let xs=None by default in lax.scan. 2024-04-23 17:26:23 -04:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
jax authors
750487f2cf Adjusts error tolerance for lax_control_flow_test
PiperOrigin-RevId: 620343970
2024-03-29 14:40:39 -07:00
jax authors
e03f1d4fd1 Allows for splitting the transpose of a scan into a scan and a map.
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.

If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.

Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
2024-03-28 10:54:50 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Philip Pham
3fc72d1f44 Fix jax.lax.fori_loop(..., unroll=True) with non-positive length 2024-01-26 17:06:30 +00:00
Lu Teng
633ddca560 Fix error caused by too many devices in lax_control_flow_test.py 2023-12-21 16:45:49 +08:00
Sharad Vikram
b04fd317c2 Add option to pass in unroll=True/False into scan and fori_loop.
PiperOrigin-RevId: 587795364
2023-12-04 11:54:50 -08:00
Sharad Vikram
54e3b7611a Add support for unrolling to lax.fori_loop
PiperOrigin-RevId: 587767613
2023-12-04 10:34:53 -08:00
George Necula
2d9da6c8fb Cleanup the code to picking lowering rules based on platform.
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Peter Hawkins
79469fd9ed Fix test failure in lax_control_flow_test on Mac ARM.
Fixes https://github.com/google/jax/issues/14793 (the other issues in that bug no longer reproduce).

https://github.com/jax-ml/ml_dtypes/pull/112 is needed to fix some spurious cast warnings.
2023-10-09 10:34:55 -04:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00