Skye Wanderman-Milne
cc5171034f
Add new config jax_persistent_cache_min_compile_time_secs
.
...
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798 , since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).
I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) numbers from
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
jax authors
89b240ba02
Merge pull request #13012 from mattjj:rng-part-overgenerate
...
PiperOrigin-RevId: 484567918
2022-10-28 10:41:35 -07:00
Roy Frostig
c8b9280fb3
partitionable threefry PRNG random bits implementation
...
the cost is 2x overgeneration of bits
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Peter Hawkins
320d531521
Increase the minimum jaxlib version to 0.3.22.
...
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
jax authors
c1c8462371
Merge pull request #12798 from skye:cache_min_instr_count
...
PiperOrigin-RevId: 482349949
2022-10-19 17:54:03 -07:00
Skye Wanderman-Milne
81eb3fca55
Add new config jax_persistent_cache_min_instruction_count
.
...
This can be used to limit the number of entries written to the
persistent compilation cache.
I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) and logging
the instruction counts and complilation time:
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new int config functionality.
Fixes #12583
2022-10-20 00:17:24 +00:00
Nicholas Junge
efd61b73f6
Migrate JAX internals to builtin Python logging
...
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Skye Wanderman-Milne
15e5f38a16
Make persistent compilation cache warn instead of raise an error on cache read/write failures
...
Fixes #12582 . Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
jax authors
4f90af91d3
Remove unused jax_unique_mhlo_module_names flag.
...
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
lenamartens
27e3981d52
lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00
lenamartens
78ecc1442c
Lowerable checks!!
2022-09-26 16:54:18 +01:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
c7f2712e74
Flip default value of jax_unique_mhlo_module_names to False.
...
This should help avoid unnecessary cache misses.
PiperOrigin-RevId: 475852954
2022-09-21 09:48:01 -07:00
Kuangyuan Chen
0400db959b
Introduce class PyArray that contains the data members of python Array.
...
A few key methods is implemented in C++ while the rest are still implmemented in python and added to the class later. A class decorator, @use_cpp_array, is added to add python methods to xc.Array.
PiperOrigin-RevId: 473075244
2022-09-08 13:48:28 -07:00
Yash Katariya
37089ec1b8
Add upgrade=True
to jax_array flag so that its marked as transient flag which will eventually be set to True.
2022-08-22 13:07:53 -07:00
Yash Katariya
78cfbebfba
Add config.jax_array to _trace_context so that in can be used in lu.cache key.
...
PiperOrigin-RevId: 468824719
2022-08-19 18:22:38 -07:00
Sharad Vikram
87e3898a9f
Set jax_eager_pmap
to True
...
PiperOrigin-RevId: 468265661
2022-08-17 12:31:56 -07:00
Matthew Johnson
d19e34fa4a
delete old remat implementation
...
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
Jake VanderPlas
a44fef4c70
Fix JIT cacheing context defaults
2022-08-16 09:30:14 -07:00
Sharad Vikram
fe040cc01e
Cleaning up eager pmap implementation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 11:10:16 -07:00
Matthew Johnson
be6f6bfe9f
set new jax.remat / jax.checkpoint to be on-by-default
2022-08-10 10:29:38 -07:00
Matthew Johnson
cbcfe95e80
fix ad_checkpoint.checkpoint caching issue
...
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
jax authors
97a9b12790
Turn on coordination service by default for all JAX users.
...
Coordination service is the new implementation of JAX's distributed service. The API remains the same, and eventually will be expanded for newer features such as error reporting.
PiperOrigin-RevId: 463870794
2022-07-28 10:34:07 -07:00
jax authors
be6db2e619
Merge pull request #10775 from pschuh:mlir-caching
...
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
George Necula
ee50140701
[jax2tf] A new experimental version with JAX native lowering.
...
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.
This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
2022-07-19 10:50:04 +02:00
Parker Schuh
704f125c88
Add caching to trace_to_subjaxpr_dynamic2
.
...
This allows the MLIR lowering code to cache call lowerings.
example output:
```
module @jit_fun.0 {
func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
%1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
return %1 : tensor<4x8xf32>
}
func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
}
```
If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
jax authors
d98d5ddce5
[JAX] Add jax_unique_mhlo_module_names
flag to control if MHLO should be made unique.
...
Some clients of JAX expect module names to not be altered so that they can cache XLA compilations.
PiperOrigin-RevId: 461648129
2022-07-18 10:05:44 -07:00
Peter Hawkins
88c1e7dce2
Flip after_neurips flag to True.
...
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
George Necula
5983d385da
[dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
...
Also add more tests.
2022-07-05 12:14:15 +03:00
Parker Schuh
6c5d204d7e
Jax caches should depend on axis env.
2022-06-29 14:25:14 -07:00
Peter Hawkins
dbae3e5ed1
Remove long-deprecated omnistaging flag.
...
PiperOrigin-RevId: 457794581
2022-06-28 12:35:08 -07:00
jax authors
406a61cf52
Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
...
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
Shiva Shahrokhi
df8c6263de
Change JAX_PLATFORMS to raise an exception when platform initialization fails
2022-06-24 21:54:53 +00:00
Ian McKenzie
0cc2ada432
Fix broken links for moved design_notes folder
2022-06-24 12:18:11 -07:00
Kuangyuan Chen
dc1c519547
Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
...
Previously the thread local state might be updated, leading to expensive python compare logic during compilation cache lookup. This CL adds a thread local cache for the state.
PiperOrigin-RevId: 456667829
2022-06-22 20:04:40 -07:00
Skye Wanderman-Milne
7098088f4e
Add jax.config.jax_default_device to jax in-memory cache key
...
This fixes a case where we'd get a cache hit when evaluating a
primitive (e.g. jnp.ones) even if the default device was changed,
causing the default device to not take effect.
PiperOrigin-RevId: 454986939
2022-06-14 16:44:19 -07:00
Hyeontaek Lim
4f8881539d
Make the transfer guard documentation easier to find
...
Move the main documentation of the transfer guard from "Notes" to "Reference documentation" section for better visibility.
Add a link to the main documentation to the docstring of jax.transfer_guard(), which currently shows up as the top result when searching for "jax transfer_guard".
2022-06-09 18:13:03 +00:00
jax authors
6c89e90808
Allow JAX OSS users to switch between experimental coordination service and default PjRT distributed runtime service via a flag.
...
PiperOrigin-RevId: 452625054
2022-06-02 14:41:40 -07:00
jax authors
ea54754c49
Merge pull request #9118 from skye:device_context_manager
...
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
Sharad Vikram
a29f96b01a
Enable new name stack by default
...
PiperOrigin-RevId: 452407951
2022-06-01 16:18:34 -07:00
Jean-Baptiste Lespiau
bab8520d0c
Initialize the thread-local compilation context when undefined in new threads.
...
PiperOrigin-RevId: 452119314
2022-05-31 12:57:48 -07:00
Tianjian Lu
ae9f9f77ee
[sparse] Set jax_bcoo_cusparse_lowering
default to true.
...
PiperOrigin-RevId: 451314487
2022-05-26 22:23:00 -07:00
Jake VanderPlas
ceae6fe5e2
Add jax_numpy_dtype_promotion='strict' mode
2022-05-26 10:56:09 -07:00
Sharad Vikram
c6c230eb48
Enable new name stack by default
...
PiperOrigin-RevId: 450837809
2022-05-24 21:23:32 -07:00
Sharad Vikram
30a506bf72
Enable new name stack by default
...
PiperOrigin-RevId: 450566748
2022-05-23 18:01:31 -07:00
jax authors
6110be40dc
Merge pull request #10678 from JeppeKlitgaard:precommit-pyupgrade
...
PiperOrigin-RevId: 449561541
2022-05-18 13:25:52 -07:00
Yash Katariya
633ab51a8e
Add a new flag to prepare for the new pjit behavior and introduction of jax.Array.
...
PiperOrigin-RevId: 449326394
2022-05-17 15:11:35 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Sharad Vikram
b8a523f977
Enable colors when we are using a terminal or IPython
2022-05-14 14:54:11 -07:00
Roy Frostig
a62ca21b15
use upgrade
option in defining our latest upgrade flag
...
PiperOrigin-RevId: 447783540
2022-05-10 11:25:23 -07:00