18 Commits

Author SHA1 Message Date
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Peter Hawkins
b6abd738d9 Relax some test tolerances in for_loop_test.py.
This PR attempts to fix some CI failures on Mac ARM.

PiperOrigin-RevId: 672312564
2024-09-08 12:09:45 -07:00
Peter Hawkins
bc415f9153 Relax test tolerances to fix CI failures on Mac ARM. 2024-09-03 09:45:28 -04:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
George Karpenkov
de14e3b32e Reverts 49bd4d6f01d6cda00f9b1bdfbda156636baae928
PiperOrigin-RevId: 633221195
2024-05-13 08:35:40 -07:00
Peter Hawkins
49bd4d6f01 Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a
PiperOrigin-RevId: 632818524
2024-05-11 12:24:06 -07:00
George Karpenkov
586568f4fe Simplify JAX lowering rules for cumulative sum
Rely on XLA decomposition.

# JAX GPU microbenchmarks

285us for cumsum over 1e8 elements

449us for cumsum over 1e8 elements.

# JAX CPU microbenchmarks:

1.8s vs. 0.7s for 50 iterations over cumsum over 1e7 elements

PiperOrigin-RevId: 632547166
2024-05-10 11:03:28 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
08c5753ee2 Implement unrolling for for_loop 2022-09-14 18:32:37 -07:00
Matthew Johnson
71b0968f70 skip some for_loop test cases on gpu due to flakey timeouts
PiperOrigin-RevId: 474168747
2022-09-13 17:51:55 -07:00
Sharad Vikram
f26f1e8afc Add support for closing over Refs in nested for loops 2022-09-13 13:32:44 -07:00
Sharad Vikram
ad326b99da Use cases_from_list to subsample enumerated cases in for_loop_test
PiperOrigin-RevId: 474093596
2022-09-13 12:34:10 -07:00
Sharad Vikram
e5725f1df1 Split for_loop_test out of lax_control_flow_test
PiperOrigin-RevId: 473848277
2022-09-12 14:46:07 -07:00