Peter Hawkins
62e66b684b
Don't monkey-patch functions in test_utils to count events for tests.
...
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Matthew Johnson
11fdda9583
add checkify rule for remat
...
fixes #23867
2024-10-01 02:01:18 +00:00
jax authors
6c52ddc97f
[Checkify] Add checks for shard_map.
...
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
Yash Katariya
abc9ba00e9
Rename count_jit_and_pmap_compiles
to count_jit_and_pmap_lowerings
...
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
2023-10-12 17:32:15 +01:00
Jake Hall
f59a4163fa
Test changes for out-of-tree backend.
2023-09-14 12:18:37 +01:00
Jake VanderPlas
2f878a7168
Tests: set jax_legacy_prng_key='error'
2023-08-28 10:56:09 -07:00
Lena Martens
55da62ff75
Better pprint rule for check_p primitive.
...
PiperOrigin-RevId: 539703344
2023-06-12 10:58:40 -07:00
Matthew Johnson
01fa7e07dd
fix checkify + custom_vjp after symbolic zeros change
...
Co-authored-by: Lena Martens <lenamartens@google.com>
2023-06-07 20:32:21 -07:00
lenamartens
ee6cbafa85
Checkify: Fix closing over Tracer in while_loop cond_f.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-05-17 18:43:23 +01:00
Matthew Johnson
f55de18933
[checkify] fix closed_call_p handling
...
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-05-10 22:00:16 -07:00
Matthew Johnson
391e95a683
fix checkify custom_jvp rule to handle symbolic zeros
...
likely broken in #15426 , or maybe not quite right before either
Co-authored-by: Roy Frostig <frostig@google.com>
2023-05-09 14:12:53 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Jake VanderPlas
8d1cf99825
checkify: dynamic_update_slice OOB index check
2023-04-17 13:43:26 -07:00
Lena Martens
d1438a1205
Checkify: close over all arguments.
...
This means you don't have to worry about passing in non-jax-types (like
strings) or marking arguments as static.
Fixes #15504 .
2023-04-12 18:32:37 +01:00
Jake VanderPlas
46297dccaf
checkify: catch OOB errors in dynamic_slice
...
This will allow checkify tests to continue working properly after #15377
2023-04-04 08:16:59 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Jake VanderPlas
f7dec15375
checkify_test: avoid passing argument to at[i].get()
2023-03-10 12:37:33 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Matthew Johnson
684846bd0f
checkify: cache jaxpr formation so we don't always retrace
2023-02-01 10:19:47 -08:00
lenamartens
641b61b164
Checkify: Validate format arguments to make sure they're arrays
2023-01-25 10:03:07 +00:00
lenamartens
5ebd81f573
Fix scan and name_stack, rewrite cond to use jaxpr_to_checkify_jaxpr.
...
Fix map primitives, test pmap more.
2023-01-05 16:35:40 +00:00
lenamartens
0bce1cf129
Checkify: switch to initial-style.
2023-01-05 16:35:02 +00:00
lenamartens
3134797968
Add checkify.debug_check which is a noop outside of checkify.
2022-12-14 11:15:34 +00:00
Jake VanderPlas
f09fd8a4e9
[x64] minor test-only updates for better type safety
2022-11-30 15:18:40 -08:00
lenamartens
e4757e8410
Rewrite Checkify to support tracking different error types.
...
In general, behavior should remain the same and this is not a breaking
change.
There are some minor changes to the API:
- checkify.ErrorCategory has changed type: it's no longer an Enum, but
the JaxException type. These have not been exposed as part of the
public API.
- some attributes on Error have changed and made private
- The raised error has changed type (JaxRuntimeError), and will have a
different traceback (pointing to the origin of the error + where the
error value was raised).
- `checkify.check` now supports formating error message with variable
size runtime info!
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-11-25 15:31:54 +00:00
Lena Martens
3116ed52a9
Checkify: fix nan check when primitive has multiple results.
...
PiperOrigin-RevId: 488653856
2022-11-15 07:35:54 -08:00
Yash Katariya
c42bad85ef
Make MeshPspecSharding
an alias for NamedSharding
(it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
lenamartens
053b8b5bcd
Checkify: fix nan_checks+PRNGKeys - a PRNGKey is never NaN!
...
Add a guard to the nan_error_rule to not call jnp.isnan on keys.
2022-11-09 17:08:21 +00:00
lenamartens
c2a00a0526
Disallow checkify-of-vmap-of-while.
2022-10-17 23:01:43 +01:00
Peter Hawkins
0d3277b5c3
Port more tests from jtu.cases_from_list to jtu.sample_product.
2022-10-11 21:06:08 +00:00
lenamartens
0639aced5b
Raise cond index into tracing context in case of effects.
...
So even if the cond is not data dependent at all, it's included in the
dynamic trace, and effects can be discharged.
2022-09-29 11:36:04 +01:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
lenamartens
27e3981d52
lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00
jax authors
9c66569514
Merge pull request #12468 from LenaMartens:checkify-but-better
...
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
lenamartens
7078f81dd0
Checkify: misc improvements.
...
- err.throw == check_error(err) -> meaning they have the same behavior
under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
lenamartens
018e700ead
Checkify: support batched while.
2022-09-20 17:59:46 +01:00
Lena Martens
b3393e3b60
Checkify: add jvp rule for assert_p.
...
PiperOrigin-RevId: 472961963
2022-09-08 05:26:02 -07:00
Yash Katariya
acdae7c237
Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0
test for now until I investigate.
...
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Yash Katariya
4fc3518e5f
Make checkify tests pass with Array and add methods on Array that are present on DA.
...
PiperOrigin-RevId: 468058909
2022-08-16 16:52:06 -07:00
lenamartens
1ace5d351b
Checkify: support checkify-of-pjit.
2022-07-29 19:25:22 +01:00
Lena Martens
740fe6926a
Checkify: add (checkify-of-)vmap-of-check.
2022-06-27 10:34:26 +01:00
Lena Martens
9167f7248a
Checkify: support discharging checks from control-flow through effects.
...
Currently supports scan and while-loop.
2022-06-20 18:28:03 +01:00
Jake VanderPlas
305f5e0491
[x64] make checkify compatible with strict dtype promotion
2022-06-17 12:48:20 -07:00
Jake VanderPlas
d2f80ef117
[x64] deprecate unsafe type casting in scatter-update operations
2022-06-09 15:21:49 -07:00