94 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
cc5171034f Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
jax authors
89b240ba02 Merge pull request #13012 from mattjj:rng-part-overgenerate
PiperOrigin-RevId: 484567918
2022-10-28 10:41:35 -07:00
Roy Frostig
c8b9280fb3 partitionable threefry PRNG random bits implementation
the cost is 2x overgeneration of bits

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
jax authors
c1c8462371 Merge pull request #12798 from skye:cache_min_instr_count
PiperOrigin-RevId: 482349949
2022-10-19 17:54:03 -07:00
Skye Wanderman-Milne
81eb3fca55 Add new config jax_persistent_cache_min_instruction_count.
This can be used to limit the number of entries written to the
persistent compilation cache.

I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) and logging
the instruction counts and complilation time:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality.

Fixes #12583
2022-10-20 00:17:24 +00:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Skye Wanderman-Milne
15e5f38a16 Make persistent compilation cache warn instead of raise an error on cache read/write failures
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
jax authors
4f90af91d3 Remove unused jax_unique_mhlo_module_names flag.
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
lenamartens
27e3981d52 lowerable errors behind a config flag. 2022-09-26 17:34:27 +01:00
lenamartens
78ecc1442c Lowerable checks!! 2022-09-26 16:54:18 +01:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
c7f2712e74 Flip default value of jax_unique_mhlo_module_names to False.
This should help avoid unnecessary cache misses.

PiperOrigin-RevId: 475852954
2022-09-21 09:48:01 -07:00
Kuangyuan Chen
0400db959b Introduce class PyArray that contains the data members of python Array.
A few key methods is implemented in C++ while the rest are still implmemented in python and added to the class later. A class decorator, @use_cpp_array, is added to add python methods to xc.Array.

PiperOrigin-RevId: 473075244
2022-09-08 13:48:28 -07:00
Yash Katariya
37089ec1b8
Add upgrade=True to jax_array flag so that its marked as transient flag which will eventually be set to True. 2022-08-22 13:07:53 -07:00
Yash Katariya
78cfbebfba Add config.jax_array to _trace_context so that in can be used in lu.cache key.
PiperOrigin-RevId: 468824719
2022-08-19 18:22:38 -07:00
Sharad Vikram
87e3898a9f Set jax_eager_pmap to True
PiperOrigin-RevId: 468265661
2022-08-17 12:31:56 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
Jake VanderPlas
a44fef4c70 Fix JIT cacheing context defaults 2022-08-16 09:30:14 -07:00
Sharad Vikram
fe040cc01e Cleaning up eager pmap implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 11:10:16 -07:00
Matthew Johnson
be6f6bfe9f set new jax.remat / jax.checkpoint to be on-by-default 2022-08-10 10:29:38 -07:00
Matthew Johnson
cbcfe95e80 fix ad_checkpoint.checkpoint caching issue
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
2022-07-29 19:59:28 -07:00
jax authors
97a9b12790 Turn on coordination service by default for all JAX users.
Coordination service is the new implementation of JAX's distributed service. The API remains the same, and eventually will be expanded for newer features such as error reporting.

PiperOrigin-RevId: 463870794
2022-07-28 10:34:07 -07:00
jax authors
be6db2e619 Merge pull request #10775 from pschuh:mlir-caching
PiperOrigin-RevId: 462263487
2022-07-20 17:10:40 -07:00
George Necula
ee50140701 [jax2tf] A new experimental version with JAX native lowering.
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.

This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
2022-07-19 10:50:04 +02:00
Parker Schuh
704f125c88 Add caching to trace_to_subjaxpr_dynamic2.
This allows the MLIR lowering code to cache call lowerings.

example output:

```
module @jit_fun.0 {
  func.func public @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = call @square(%arg0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    %1 = call @square(%0) : (tensor<4x8xf32>) -> tensor<4x8xf32>
    return %1 : tensor<4x8xf32>
  }
  func.func private @square(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
    %0 = mhlo.multiply %arg0, %arg0 : tensor<4x8xf32>
    return %0 : tensor<4x8xf32>
  }
}
```

If / when jaxprs support recursion, this approach will still work because the mlir lowering cache operates on Jaxpr object identity.
2022-07-18 17:51:05 -07:00
jax authors
d98d5ddce5 [JAX] Add jax_unique_mhlo_module_names flag to control if MHLO should be made unique.
Some clients of JAX expect module names to not be altered so that they can cache XLA compilations.

PiperOrigin-RevId: 461648129
2022-07-18 10:05:44 -07:00
Peter Hawkins
88c1e7dce2 Flip after_neurips flag to True.
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
Peter Hawkins
dbae3e5ed1 Remove long-deprecated omnistaging flag.
PiperOrigin-RevId: 457794581
2022-06-28 12:35:08 -07:00
jax authors
406a61cf52 Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
Ian McKenzie
0cc2ada432 Fix broken links for moved design_notes folder 2022-06-24 12:18:11 -07:00
Kuangyuan Chen
dc1c519547 Reduce jax.jit dispatch overhead by avoiding directly comparing python objects
Previously the thread local state might be updated, leading to expensive python compare logic during compilation cache lookup. This CL adds a thread local cache for the state.

PiperOrigin-RevId: 456667829
2022-06-22 20:04:40 -07:00
Skye Wanderman-Milne
7098088f4e Add jax.config.jax_default_device to jax in-memory cache key
This fixes a case where we'd get a cache hit when evaluating a
primitive (e.g. jnp.ones) even if the default device was changed,
causing the default device to not take effect.

PiperOrigin-RevId: 454986939
2022-06-14 16:44:19 -07:00
Hyeontaek Lim
4f8881539d Make the transfer guard documentation easier to find
Move the main documentation of the transfer guard from "Notes" to "Reference documentation" section for better visibility.

Add a link to the main documentation to the docstring of jax.transfer_guard(), which currently shows up as the top result when searching for "jax transfer_guard".
2022-06-09 18:13:03 +00:00
jax authors
6c89e90808 Allow JAX OSS users to switch between experimental coordination service and default PjRT distributed runtime service via a flag.
PiperOrigin-RevId: 452625054
2022-06-02 14:41:40 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
Sharad Vikram
a29f96b01a Enable new name stack by default
PiperOrigin-RevId: 452407951
2022-06-01 16:18:34 -07:00
Jean-Baptiste Lespiau
bab8520d0c Initialize the thread-local compilation context when undefined in new threads.
PiperOrigin-RevId: 452119314
2022-05-31 12:57:48 -07:00
Tianjian Lu
ae9f9f77ee [sparse] Set jax_bcoo_cusparse_lowering default to true.
PiperOrigin-RevId: 451314487
2022-05-26 22:23:00 -07:00
Jake VanderPlas
ceae6fe5e2 Add jax_numpy_dtype_promotion='strict' mode 2022-05-26 10:56:09 -07:00
Sharad Vikram
c6c230eb48 Enable new name stack by default
PiperOrigin-RevId: 450837809
2022-05-24 21:23:32 -07:00
Sharad Vikram
30a506bf72 Enable new name stack by default
PiperOrigin-RevId: 450566748
2022-05-23 18:01:31 -07:00
jax authors
6110be40dc Merge pull request #10678 from JeppeKlitgaard:precommit-pyupgrade
PiperOrigin-RevId: 449561541
2022-05-18 13:25:52 -07:00
Yash Katariya
633ab51a8e Add a new flag to prepare for the new pjit behavior and introduction of jax.Array.
PiperOrigin-RevId: 449326394
2022-05-17 15:11:35 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Sharad Vikram
b8a523f977 Enable colors when we are using a terminal or IPython 2022-05-14 14:54:11 -07:00
Roy Frostig
a62ca21b15 use upgrade option in defining our latest upgrade flag
PiperOrigin-RevId: 447783540
2022-05-10 11:25:23 -07:00