Yash Katariya
a32ed7e002
Bump shard_count for shard_map_test to fix the asan failures
...
PiperOrigin-RevId: 569520202
2023-09-29 10:02:38 -07:00
Matthew Johnson
a9dc3c1ea3
[shard_map] internal change to shard_map CI testing
...
PiperOrigin-RevId: 569036873
2023-09-27 20:06:24 -07:00
Hyeontaek Lim
f0bde75dd3
[JAX] Export shard_map_test for testing on additional JAX backends
...
PiperOrigin-RevId: 567522898
2023-09-21 22:52:36 -07:00
Sharad Vikram
afb6691885
Disable msan/tsan for xmap_tests thanks to timeouts
...
PiperOrigin-RevId: 567412260
2023-09-21 14:24:08 -07:00
Junwhan Ahn
6a551a1efa
Add memories_test.py to the list of exported tests
...
PiperOrigin-RevId: 567375604
2023-09-21 11:57:09 -07:00
Peter Hawkins
f863cfbaad
Relax some test tolerances to fix failures on Linux aarch64.
...
PiperOrigin-RevId: 565930178
2023-09-16 06:55:22 -07:00
Peter Hawkins
bbfba9ace8
Remove code that disabled tests on "stream_executor" backends.
...
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.
Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.
PiperOrigin-RevId: 565367453
2023-09-14 07:52:43 -07:00
Peter Hawkins
306c60d4c7
Remove references to deprecated "tpu_se" build configuration.
...
PiperOrigin-RevId: 565156675
2023-09-13 14:10:30 -07:00
Yash Katariya
76a5dc3cac
Move memories_test.py to JAX
...
PiperOrigin-RevId: 564551723
2023-09-11 17:41:55 -07:00
George Necula
660a015652
[export] Move jax_export and shape_poly out of jax2tf.
...
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.
We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.
PiperOrigin-RevId: 562988740
2023-09-05 22:15:59 -07:00
George Necula
e0a6230214
[host_callback] Delete unused code paths.
...
This is part of deprecating host_callback and moving to io_callback.
PiperOrigin-RevId: 561851494
2023-08-31 22:08:23 -07:00
Roy Frostig
a71c0e6ecc
create jax.extend.random
as a copy of jax.prng
...
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Richard Levasseur
f891cbf64b
Load Python rules from rules_python
...
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
George Necula
26f091e446
[callback] Disable stream_executor tests.
...
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
Peter Hawkins
a259df0d76
Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
...
Refactoring only, no user-visible changes intended.
PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Jake VanderPlas
ad8e719b82
Add jnp.ufunc and jnp.frompyfunc
2023-08-10 14:58:18 -07:00
Peter Hawkins
afd56c15d9
Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
...
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.
PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Adam Paszke
0228bf7d3c
Fix MSAN errors in cache_key_test
...
The device_assignment array was never initialized, causing MSAN errors.
Replacing it with np.arange fixes the issue.
PiperOrigin-RevId: 553469463
2023-08-03 07:28:32 -07:00
jax authors
9e7502ce60
Disable MSAN testing for cache-key unit tests.
...
This is an existing issue with the compilation cache tests.
The refactoring of the cache key generation part into a separate
file requires tagging the refactored tests also.
PiperOrigin-RevId: 551972670
2023-07-28 16:02:12 -07:00
Skye Wanderman-Milne
8b58e38ec5
Add jax_debug_log_modules
config option.
...
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-07-28 18:11:12 +00:00
jax authors
3b28d4e180
Refactor Jax compilation cache key generation.
...
This is in preparation for introducing a more robust key-generation
algorithm.
This refactoring does not introduce any change in behavior.
Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892
2023-07-27 23:01:00 -07:00
Sharad Vikram
3d556b7a19
Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
...
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00
Tao Wang
b7686f41aa
Enable passing fdo_profile in compiler_options in pxla.py
...
PiperOrigin-RevId: 549109629
2023-07-18 14:18:28 -07:00
Peter Hawkins
3fcf72af32
Disable tests that are flaky in CI.
...
PiperOrigin-RevId: 547464866
2023-07-12 05:13:38 -07:00
jax authors
2fa6a9c9bf
Allow other backends to run the array_test.py
test.
...
PiperOrigin-RevId: 547191886
2023-07-11 08:05:25 -07:00
Chris Flesher
5be17ed90c
Added scipy.spatial.transform Rotation and Slerp classes
2023-06-08 07:51:32 -05:00
Peter Hawkins
32026ad18b
Disable random_test_with_custom_prng on CPU under msan.
...
This test flakily times out in CI.
PiperOrigin-RevId: 535293997
2023-05-25 10:10:01 -07:00
Peter Hawkins
1d20d2f301
Increase sharding of host_callback_test on TPU to fix CI flakiness.
...
PiperOrigin-RevId: 533451822
2023-05-19 07:44:53 -07:00
Peter Hawkins
e6628e2e72
Disable tests that time out in CI.
...
PiperOrigin-RevId: 532792740
2023-05-17 08:16:07 -07:00
Peter Hawkins
9471bb3045
Disable sparsify_test on CPU under tsan.
...
Under tsan this test times out in CI.
PiperOrigin-RevId: 531210930
2023-05-11 08:33:35 -07:00
Peter Hawkins
e8c735125c
Disable more tests that are flaky in CI.
...
PiperOrigin-RevId: 529724306
2023-05-05 08:33:33 -07:00
pizzud
40d730be49
aot_test: Stop forcing XLA to assume a certain number of devices.
...
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.
PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
Peter Hawkins
09fce87f54
Increase sharding of or disable some flaky CI tests.
...
PiperOrigin-RevId: 529405705
2023-05-04 07:41:56 -07:00
Peter Hawkins
57e62ca03c
Reenable scipy_stats_test in CI.
...
Disable testTruncNormPdf on CPU, which is failing after an LLVM update.
PiperOrigin-RevId: 528884880
2023-05-02 14:11:08 -07:00
Skye Wanderman-Milne
70cac773f7
Exclude scipy_fft_test from msan as well as t/asan.
...
PiperOrigin-RevId: 528562775
2023-05-01 13:42:24 -07:00
Skye Wanderman-Milne
fa68c1f882
Bump up lax_test TPU sharding to avoid asan timeouts
...
PiperOrigin-RevId: 528559870
2023-05-01 13:31:22 -07:00
Skye Wanderman-Milne
c662fd216d
Disable tsan CI for random_test_with_custom_prng to avoid timeouts.
...
asan is already disabled, and the comment and "cpu" case indicates
that tsan should already have been disabled as well.
PiperOrigin-RevId: 528000458
2023-04-28 15:26:46 -07:00
Skye Wanderman-Milne
67d80c21cb
Increase sharding count on nn_test and svd_test to avoid ASAN timeouts.
...
PiperOrigin-RevId: 527387005
2023-04-26 15:11:29 -07:00
Jake VanderPlas
1c7f8efce6
Add test framework for module attribute
2023-04-21 13:20:16 -07:00
jax authors
db2cbd4ae8
Merge pull request #15665 from hawkinsp:sourceinfo
...
PiperOrigin-RevId: 525581713
2023-04-19 16:30:23 -07:00
Peter Hawkins
3bb7386149
[JAX] Improve handling of metadata in compilation cache.
...
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
a3b262c379
Use the traceback of the call site when assigning a source location to an inlined function.
...
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Peter Hawkins
017548c40b
Move implementation of compilation cache out of jax/experimental and into jax/_src.
...
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.
No functional change intended.
PiperOrigin-RevId: 524855263
2023-04-17 08:35:53 -07:00
Jake VanderPlas
8c8f50f688
Fix tolerance and shard_count for experimental_rnn_test
...
This should fix the current GPU test timeout.
PiperOrigin-RevId: 522167894
2023-04-05 15:19:19 -07:00
jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Parker Schuh
c2b15a1eb8
Break out aot_test from array_test (for serialization and other aot APIs).
...
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Skye Wanderman-Milne
4cb3b011a0
Remove PJRT C API bypass.
...
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.
PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Jieying Luo
b403c2a083
[PJRT C API] Add parsing PJRT client create options from json file.
...
PiperOrigin-RevId: 518418760
2023-03-21 16:57:34 -07:00
Yash Katariya
207cc10058
Error if jax_array or jax_jit_pjit_api_merge is set to False.
...
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00