258 Commits

Author SHA1 Message Date
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
jax authors
750487f2cf Adjusts error tolerance for lax_control_flow_test
PiperOrigin-RevId: 620343970
2024-03-29 14:40:39 -07:00
jax authors
e03f1d4fd1 Allows for splitting the transpose of a scan into a scan and a map.
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.

If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.

Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
2024-03-28 10:54:50 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Philip Pham
3fc72d1f44 Fix jax.lax.fori_loop(..., unroll=True) with non-positive length 2024-01-26 17:06:30 +00:00
Lu Teng
633ddca560 Fix error caused by too many devices in lax_control_flow_test.py 2023-12-21 16:45:49 +08:00
Sharad Vikram
b04fd317c2 Add option to pass in unroll=True/False into scan and fori_loop.
PiperOrigin-RevId: 587795364
2023-12-04 11:54:50 -08:00
Sharad Vikram
54e3b7611a Add support for unrolling to lax.fori_loop
PiperOrigin-RevId: 587767613
2023-12-04 10:34:53 -08:00
George Necula
2d9da6c8fb Cleanup the code to picking lowering rules based on platform.
Previously, we had special-cased the code to pick the lowering
rule for a primitive based on the lowering platform, and separately
we had the code to handle multi-platform lowering. The latter,
called `mlir.lower_multi_platform` had its own special case for
when a single lowering rule applied.

We rename `mlir.lower_multi_platform` to `mlir.lower_per_platform`
to not imply that it is only for multi-platform. We simplify
its API (takes a dictionary instead of a list of tuples).
2023-11-19 18:39:59 +02:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Peter Hawkins
79469fd9ed Fix test failure in lax_control_flow_test on Mac ARM.
Fixes https://github.com/google/jax/issues/14793 (the other issues in that bug no longer reproduce).

https://github.com/jax-ml/ml_dtypes/pull/112 is needed to fix some spurious cast warnings.
2023-10-09 10:34:55 -04:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Roy Frostig
14f32653a1 resolve conditionals to default "shared operand form" more often
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.

This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.

PiperOrigin-RevId: 543912445
2023-06-27 18:49:16 -07:00
Peter Hawkins
9867f0dc9c Fix test failure in lax_control_flow_test.py on Windows due to line ending differences. 2023-06-16 10:33:46 -04:00
Joey Teng
c2c8314984
remove unused variable in lax.scan test
The variable was used to parameterize tests, now replaced by another.
2023-06-10 14:05:12 +01:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
4bbe26031b fori_loop: allow promotion of scalar limits.
Adapted from https://github.com/google/jax/pull/13494 being careful not to prevent the use of the scan() implementation.

PiperOrigin-RevId: 523683140
2023-04-12 06:32:16 -07:00
Matthew Johnson
ba2ff519ca improve scan error messages 2023-03-23 14:53:05 -07:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
jax authors
78599e65d1 Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
PiperOrigin-RevId: 504628432
2023-01-25 12:15:36 -08:00
jax authors
d14e144651 Use pareto optimal step size for computing numerical Jacobians in JAX. This allows us to tighten the tolerances in gradient unit testing significantly, especially for float64 and complex128.
PiperOrigin-RevId: 504579516
2023-01-25 09:12:52 -08:00
Rasmus Munk Larsen
c798fcaefc
Remove more uses of tan() in reduction tests.
This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
2023-01-19 15:20:02 -08:00
Rasmus Munk Larsen
2ee33a0728
Update lax_control_flow_test.py
Fix brittle scan test. Adding tan(randn) is numerically brittle because evaluating tan() near its singularities is ill-conditioned.
2023-01-18 13:11:56 -08:00
Jake VanderPlas
6376dc9616 Fix excessive recompiles in lax.cond 2023-01-18 10:17:01 -08:00
Matthew Johnson
e516d41180 cond transpose, use UndefinedPrimal not linear for transpose inputs 2023-01-16 10:39:19 -08:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
f26f1e8afc Add support for closing over Refs in nested for loops 2022-09-13 13:32:44 -07:00
Sharad Vikram
e5725f1df1 Split for_loop_test out of lax_control_flow_test
PiperOrigin-RevId: 473848277
2022-09-12 14:46:07 -07:00
Sharad Vikram
6967c7ef51 Add sound loop invariance detection 2022-09-08 10:42:19 -07:00
Sharad Vikram
b2a5d2c3bb Add partial_eval_custom rule for for_loop 2022-09-06 11:00:26 -07:00
Matthew Johnson
bbb8048d2e Add batching rules for state primitives and for_loop
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Peter Hawkins
c247b9b6da Relax test tolerance to fix GPU CI failure. 2022-08-22 08:50:57 -04:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Sharad Vikram
49b7729f6b More tests for transpose 2022-08-18 18:06:21 -07:00
Sharad Vikram
72dbe31172 Initial transpose implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:23:04 -07:00
Sharad Vikram
8b7daa8095 Refactor state out of for_loop 2022-08-01 15:26:55 -07:00
Matthew Johnson
7f3aa12142 add while_loop custom-policy partial eval rule 2022-07-28 18:04:49 -07:00
Matthew Johnson
ec9f9c3c07 add cond dce rule and custom-policy partial eval rule 2022-07-28 15:50:47 -07:00
jax authors
016c6df65e Merge pull request #11618 from mattjj:scan-partial-eval-custom-fix
PiperOrigin-RevId: 463499406
2022-07-26 21:46:02 -07:00
Matthew Johnson
c44dfce571 fix bug noticed by @levskaya 2022-07-26 21:04:18 -07:00
Roy Frostig
4cd0c68136 err on None predicate to lax.cond 2022-07-26 13:12:16 -07:00
Sharad Vikram
9d610e2de6 Add loop invariant residual fixpoint test 2022-07-11 13:10:03 -07:00