16981 Commits

Author SHA1 Message Date
Frederic Bastien
f212634454 Add a link to JAX Toolbox containers 2023-08-10 12:38:25 -07:00
John QiangZhang
cf026ce745 Instead of tf.Graph protobuf, we switch to tf.Saved_model for back_compat_tf_test.
PiperOrigin-RevId: 555500398
2023-08-10 08:22:34 -07:00
John QiangZhang
1ddc340a1a Fix typo here. 'gpu' should be 'cuda' or 'rocm' here.
PiperOrigin-RevId: 555495314
2023-08-10 08:08:10 -07:00
George Necula
c3aa3a4c31 [jax2tf] Disable some graph serialization tests on GPU
We recently increased the test coverage of testing for dot_general with different dtype for lhs and rhs. Some of the new combinations of dtypes are not supported by XLA:GPU, and we disable those tests now.

PiperOrigin-RevId: 555465495
2023-08-10 06:30:40 -07:00
Malcolm Reynolds
7012a05497 Rollback, breaks internal project
Reverts 6b8bb7bd5990c5207c8b4f793f8ce0702060c8da

PiperOrigin-RevId: 555455350
2023-08-10 05:39:36 -07:00
jax authors
eb076c4c44 Explicitly set AutoFDO profile version in CompileOptions.
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.

Testing: updated unit tests.
PiperOrigin-RevId: 555333728
2023-08-09 18:24:56 -07:00
jax authors
aac4cdad56 Set up plumbing for adding new compilation-cache-key generation algorithm.
The new cache-key generation algorithm will coexist with the original
version until the new one is fully deployed. While they coexist,
--jax_use_original_compilation_cache_key_generation will determine which
one is used. Once the new algorithm is deployed, the original algorithm
and this flag will be removed.

This change sets up the plumbing. Later changes will implement the new
algorithm.

Testing: test workload.
PiperOrigin-RevId: 555333628
2023-08-09 18:16:22 -07:00
Jevin Jiang
5d8f5c20fa [Mosaic] Remove multiple results check in apply layout.
PiperOrigin-RevId: 555320679
2023-08-09 17:17:25 -07:00
Parker Schuh
74bcd65bbd Make mesh available to custom_partitioning lowering rules.
PiperOrigin-RevId: 555319896
2023-08-09 17:08:57 -07:00
Hyeontaek Lim
97b96bbd4b [JAX] Introduce DeviceList backed by C++ xla::ifrt::DeviceList
This change adds `xla_client.DeviceList` that is implemented in C++
`jax::PyDeviceList`. `jax::PyDeviceList` implements the features of
`pxla._DeviceAssignment` as a functional drop-in replacement.
`jax::PyDeviceList` internally has `xla::ifrt::DeviceList`, which will be used
when using IFRT APIs without having to construct a new copy of a potentially
large device list.

`pxla._DeviceAssignment`'s interface is changed slightly to encourage avoiding
conversion to tuple.

Note that for the backward compatibility (and fast `xla_client.Device`
conversion), `jax::PyDeviceList` still uses a Python tuple whose element can be
any Python object matches `xla_client.Device` interface with duck typing. This
duck typing support will be removed when such use case is deprecated.
Eventually, we can try to avoid any type conversion to remove a shadow copy of
device list in JAX.

PiperOrigin-RevId: 555317152
2023-08-09 16:58:01 -07:00
jax authors
22a005c2a3 Support (u)int4 for jax.dtypes.scalar_type_of()
PiperOrigin-RevId: 555265053
2023-08-09 13:52:04 -07:00
jax authors
1e32fd598d Merge pull request #17042 from cottrell:me
PiperOrigin-RevId: 555226539
2023-08-09 11:42:57 -07:00
Peter Hawkins
9529ed0f4d Remove workarounds for MLIR constant construction.
* https://reviews.llvm.org/D155209 added support to the MLIR Python bindings for passing types like bfloat16 directly if an explicit IR type is provided.
* the crash for non-splat size 1 constants appears fixed at head, although I don't know which change fixed it.

PiperOrigin-RevId: 555225604
2023-08-09 11:34:28 -07:00
jax authors
6b8bb7bd59 avoid _multi_slice for the broadcast of fully replicated arrays
PiperOrigin-RevId: 555220204
2023-08-09 11:17:54 -07:00
jax authors
4ecc85b691 Merge pull request #17047 from skye:tpu_ci_pjrt
PiperOrigin-RevId: 555209384
2023-08-09 10:51:53 -07:00
jax authors
45c1af2c61 Add functions to unregister event duration listeners.
Add private functions _unregister_event_duration_listener_by_callback and _unregister_event_duration_listener_by_index to remove registered event duration listeners. The functions are supposed to be called in test only.

PiperOrigin-RevId: 555208764
2023-08-09 10:43:02 -07:00
Skye Wanderman-Milne
0f30685dac Remove StreamExecutor-based TPU runtime from Cloud TPU CI
The old StreamExecutor-based backend is no longer supported as of
3e50fea29e
2023-08-09 10:05:46 -07:00
jax authors
be543f020d Merge pull request #17041 from mtsokol:update-ninf-usage
PiperOrigin-RevId: 555185385
2023-08-09 09:26:12 -07:00
David Cottrell
40d0d40b6c Fix for log-normal prior in example. 2023-08-09 13:38:37 +01:00
Mateusz Sokół
1fedf04ed5 API: Remove NINF and PINF usages 2023-08-09 14:16:33 +02:00
Peter Hawkins
c9cf6b4423 Remove allowlist for multihost collectives.
This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more.

PiperOrigin-RevId: 555124144
2023-08-09 04:43:51 -07:00
Yash Katariya
1bd5fd2a52 Add serialize_with_paths and deserialize_with_paths API to GlobalAsyncCheckpointManager
PiperOrigin-RevId: 555050522
2023-08-08 22:47:27 -07:00
jax authors
6615e23505 Merge pull request #17032 from jakevdp:extended-doc
PiperOrigin-RevId: 554999499
2023-08-08 18:02:40 -07:00
Bharath Ranganatha Mankalale
ca924cde2d Add visibility to jax2tf_internals.
PiperOrigin-RevId: 554994907
2023-08-08 17:40:01 -07:00
Jake VanderPlas
1d27576cfe jax.dtypes.extended: fix docstring example 2023-08-08 16:08:45 -07:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
jax authors
d01695c746 Change --jax_xla_profile_version definition to config.
Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.

Testing: test workload.
PiperOrigin-RevId: 554942573
2023-08-08 14:29:09 -07:00
Skye Wanderman-Milne
3e50fea29e Remove option to use StreamExecutor Cloud TPU client in JAX
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Peter Hawkins
f05f197874 Reverts changelist 554885148
PiperOrigin-RevId: 554930183
2023-08-08 13:50:03 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
6b07f5b32d Use core.valid_jaxtype() in xla.check_arg().
Fixes a TODO.

PiperOrigin-RevId: 554885023
2023-08-08 11:21:09 -07:00
Majid Dadashi
f37afb1ee1 Add JAX primitives test suite in TFLite.
PiperOrigin-RevId: 554883170
2023-08-08 11:12:26 -07:00
Peter Hawkins
e58f1ba86e Move some utilities out of dispatch.py next to their users, add more types.
Internal cleanups only, no user-visible changes intended.

PiperOrigin-RevId: 554876522
2023-08-08 10:52:11 -07:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Peter Hawkins
b024e01440 Improve dispatch.py typing.
Inline _xla_callable_uncached, which is trivial, into its only caller.

Cleanup only, no user-visible changes intended.

PiperOrigin-RevId: 554805210
2023-08-08 06:34:34 -07:00
John QiangZhang
dec2366c16 Create the failure test when tf.SavedModel miss the XLACallModule function_list after loading.
PiperOrigin-RevId: 554726455
2023-08-08 00:50:50 -07:00
George Necula
d17addea0b [jax2tf] Adjust tolerance for jax2tf graph serialization GPU tests.
Should fix test flakyness.

PiperOrigin-RevId: 554704965
2023-08-07 23:04:30 -07:00
jax authors
5efc681702 Cleanup comments for define_{int,float}_state.
There is no parameter enum_values in these functions.
Probably a copy/paste issue from define_enum_state.

PiperOrigin-RevId: 554644871
2023-08-07 17:41:09 -07:00
Ce Zheng
b80498874a [XLA:Client] Make HloSharding::iota_tile actually produce V2 shardings.
PiperOrigin-RevId: 554631780
2023-08-07 16:46:53 -07:00
jax authors
c9498692b6 Merge pull request #17015 from mattjj:hypothesis-skip
PiperOrigin-RevId: 554625928
2023-08-07 16:23:44 -07:00
Matthew Johnson
cdb946bf3f [pallas] skip indexing tests when hypothesis not available
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-07 15:28:50 -07:00
Peter Hawkins
c879f65aa6 [JAX] Remove the non-coordination service distributed service implementation from JAX.
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
jax authors
22285e69fb Merge pull request #16971 from apaszke:pallas-tpu-docs
PiperOrigin-RevId: 554587987
2023-08-07 14:10:07 -07:00
Sharad Vikram
bf8e550ad1 [Pallas] Flatten in_specs for PrefetchScalarGridSpec
PiperOrigin-RevId: 554574437
2023-08-07 13:28:44 -07:00
jax authors
9d67fb7e98 Merge pull request #16987 from sharadmv:pallas-docs
PiperOrigin-RevId: 554568088
2023-08-07 13:15:21 -07:00
Antonio Sanchez
a600020346 Update ducc to commit: 2b2cead005e08d2632478e831d7f45da754162dc
NOTE: this version of DUCC has a breaking change, where the fft.h header
no longer contains the definitions of many fft functions - instead they exist
within fft1d_impl.h and fftnd_impl.h.
PiperOrigin-RevId: 554567641
2023-08-07 13:06:43 -07:00
jax authors
079ecfbf20 Allow device mesh handler to return None and use the default logic
PiperOrigin-RevId: 554563151
2023-08-07 12:50:43 -07:00
jax authors
915241413b Merge pull request #17007 from jakevdp:callbacks-doc
PiperOrigin-RevId: 554559029
2023-08-07 12:35:48 -07:00
jax authors
e21945661f Merge pull request #16972 from mtsokol:update-np-exceptions-imports
PiperOrigin-RevId: 554548376
2023-08-07 11:58:59 -07:00
Yash Katariya
2fa0bb0d32 Add initialization_timeout as a parameter to allow users to increase/decreases the init_timeout parameter.
PiperOrigin-RevId: 554545535
2023-08-07 11:49:41 -07:00