134 Commits

Author SHA1 Message Date
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Peter Hawkins
ee14ca2628 Add option jax_include_full_tracebacks_in_locations.
If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment.

PiperOrigin-RevId: 534060827
2023-05-22 07:41:29 -07:00
Peter Hawkins
26f2711aeb Fix typo in config.py.
Fixes #16066
2023-05-19 10:07:12 -04:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py

PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
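
The effect on the cache key can be illustrated with a minimal sketch. It is not the MLIR pass itself: the regex, the `cache_key` helper, and the one-line module text are hypothetical stand-ins, and only show why two modules that differ in location metadata hash to the same key once locations are stripped.

```python
import hashlib
import re

# Toy stand-in for MLIR location attributes such as loc("model.py":12:3).
# The real change uses an MLIR pass; this regex is only an illustration.
_LOC_PATTERN = re.compile(r'\s*loc\([^)]*\)')


def cache_key(module_text: str, strip_metadata: bool) -> str:
    """Hash the serialized module, optionally ignoring location metadata."""
    if strip_metadata:
        module_text = _LOC_PATTERN.sub("", module_text)
    return hashlib.sha256(module_text.encode("utf-8")).hexdigest()


# The same computation, with the source line moved from 12 to 40.
before = '%0 = stablehlo.add %arg0, %arg1 : tensor<f32> loc("model.py":12:3)'
after = '%0 = stablehlo.add %arg0, %arg1 : tensor<f32> loc("model.py":40:3)'

same_without_stripping = cache_key(before, False) == cache_key(after, False)
same_with_stripping = cache_key(before, True) == cache_key(after, True)
```

With stripping enabled the two keys match, which corresponds to a cache hit after a pure line-number change.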

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
George Necula
2ac2dc65b1 Remove jax2tf experimental_native_lowering.
Users should use native_serialization.

PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
Matthew Johnson
6b4262d9f6 add experimental jax_log_checkpoint_residuals option
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2023-03-22 16:26:56 -07:00
Parker Schuh
e89235ffdd Delete the C++ GetEnableJaxArray() flag.
PiperOrigin-RevId: 518119698
2023-03-20 17:18:16 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
George Necula
1a9f49963c [jax2tf] Rename experimental_native_lowering to native_serialization
We refer to the feature as serialization rather than just lowering,
because the former is both more widely understood and is actually
more accurate because jax2tf will both lower to StableHLO and then
serialize to StableHLO with compatibility guarantees.

This is part of launching the new version of jax2tf with native
serialization.

For now we also keep the parameter `experimental_native_lowering` and
the flag `jax2tf_default_experimental_native_lowering`, until we transition
projects using these flags to the new ones (separate change).

PiperOrigin-RevId: 516864636
2023-03-15 10:31:25 -07:00
George Necula
dbd2033461 Stop using version 1 of XlaCallModuleOp
Also remove configuration flag jax2tf_use_stablehlo.

PiperOrigin-RevId: 515409050
2023-03-09 12:32:14 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
Yash Katariya
8b77bde0f4 Retrace jaxpr again if the same function with the same avals is run under different meshes. Add pxla.thread_resources to the global trace_context which takes care of the caching of global variables.
PiperOrigin-RevId: 514457821
2023-03-06 10:46:46 -08:00
Peter Hawkins
055fa6b90f Remove pytype suppression for jax/_src/config.py
This file no longer seems to make pytype unhappy.

PiperOrigin-RevId: 512668863
2023-02-27 10:39:55 -08:00
jax authors
71775720a7 Merge pull request #14615 from skye:restore_opt_barrier
PiperOrigin-RevId: 511935964
2023-02-23 18:08:08 -08:00
Skye Wanderman-Milne
6572fac49e Add back opt-barrier fallback, since the fallback sometimes prevents OOMs.
This reverts 4d418fb45e, and updates for a lax change.
2023-02-21 23:45:28 +00:00
Peter Hawkins
37d4ad910a Remove uses of jax.xla_computation from metadata_test.py
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexes.
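
The caching problem can be sketched in a few lines. Everything here is a hypothetical simplification: the `config` dict, the `source_path_regex` key, and `lowered_source_path` are not JAX APIs. They only show how a cache that ignores a config value returns stale results once that value changes.

```python
import re

# Hypothetical, simplified stand-ins for JAX's config and lowering cache.
config = {"source_path_regex": None}
_lowering_cache = {}


def lowered_source_path(path, include_config_in_key):
    """Return the canonicalized path, memoized in a toy lowering cache."""
    regex = config["source_path_regex"]
    key = (path, regex) if include_config_in_key else (path,)
    if key not in _lowering_cache:
        _lowering_cache[key] = re.sub(regex, "", path) if regex else path
    return _lowering_cache[key]


path = "/tmp/build/project/model.py"

# Without the regex in the key, the first result is reused after the
# config changes, so the prefix is never removed.
stale_first = lowered_source_path(path, include_config_in_key=False)
config["source_path_regex"] = r"^/tmp/build/"
stale_second = lowered_source_path(path, include_config_in_key=False)

# With the regex in the key, the changed config produces a fresh entry.
fresh = lowered_source_path(path, include_config_in_key=True)
```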

PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementations of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Yash Katariya
2b093f1c9a Fix the warning being raised when jax.Array is True about using jax.Array
PiperOrigin-RevId: 505149151
2023-01-27 10:20:44 -08:00
Yash Katariya
cfdba777fb Add jax_jit_pjit_api_merge config to help the transition to merge jit and pjit.
PiperOrigin-RevId: 499295712
2023-01-03 12:59:46 -08:00
Tom Hennigan
4f75ad66be Revert #13747 since config values already have paired env vars.
PiperOrigin-RevId: 496935378
2022-12-21 09:06:08 -08:00
Tom Hennigan
76f92c47a6 Set default value for jax_platforms based on JAX_PLATFORMS env var.
PiperOrigin-RevId: 496872386
2022-12-21 02:38:14 -08:00
Jake VanderPlas
0a2d1cd45e Set bcoo_cusparse_lowering to False by default
This was causing out-of-bound writes on some CUDA backends

PiperOrigin-RevId: 494280591
2022-12-09 15:41:49 -08:00
Jongwook Choi
cd225853f7 Fix a false-positive typing warning on jax.default_device
Consider the following code where static type checkers can report an
error:

```python
CPU = jax.devices('cpu')[0]
with jax.default_device(CPU):
  ...                 # ^^^
```

Error message:
```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
  "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```

This is because `_StateContextManager.__call__` does not have a proper
type annotation on the parameter, unlike the attribute `_default_value`
which has a type annotation. Adding an `Any` annotation to the parameter
makes the error disappear.
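
A minimal sketch of the annotation pattern, assuming a much simplified class: `StateContextManager` below is a hypothetical stand-in, not the actual `_StateContextManager` from `config.py`, and the device is represented by a plain string.

```python
import contextlib
from typing import Any, Iterator


class StateContextManager:
    """Hypothetical, simplified stand-in for the class named in the commit."""

    def __init__(self, default_value: Any):
        self._default_value = default_value
        self.value = default_value

    @contextlib.contextmanager
    def __call__(self, new_val: Any) -> Iterator[None]:
        # Annotating new_val as Any lets static checkers accept any argument,
        # instead of inferring a narrow type from a sentinel default.
        prev, self.value = self.value, new_val
        try:
            yield
        finally:
            self.value = prev


default_device = StateContextManager(default_value=None)
with default_device("cpu:0"):
    inside = default_device.value
outside = default_device.value
```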
2022-12-06 21:05:35 -05:00
Peter Hawkins
1f9c988e63 Use _thread_local_state.__dict__.get() instead of getattr(_thread_local_state, ...).
`getattr` turns out to be a tiny bit slower than `.get()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.

Improves the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)

name        old time/op             new time/op             delta
device_put  51.4µs ± 1%             48.9µs ± 3%  -4.87%          (p=0.000 n=8+9)
```
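
The two lookup styles are interchangeable in behavior, as the following sketch shows. The attribute name `enable_x64` and the helper functions are illustrative only, and the sketch makes no claim about timings.

```python
import threading

_thread_local_state = threading.local()


def read_with_getattr(name, default=None):
    # Attribute lookup with a fallback default.
    return getattr(_thread_local_state, name, default)


def read_with_dict_get(name, default=None):
    # Direct lookup in the per-thread attribute dictionary.
    return _thread_local_state.__dict__.get(name, default)


# Attribute absent: both styles fall back to the default.
missing_a = read_with_getattr("enable_x64", False)
missing_b = read_with_dict_get("enable_x64", False)

# Attribute present: both styles return the stored value.
_thread_local_state.enable_x64 = True
present_a = read_with_getattr("enable_x64", False)
present_b = read_with_dict_get("enable_x64", False)
```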

PiperOrigin-RevId: 493108288
2022-12-05 14:09:47 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```
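
The one-way relationship described above can be sketched as follows. The `Config` class and the dict standing in for parsed ABSL flags are hypothetical; the sketch only illustrates that flags populate the config once and later flag changes are not observed.

```python
class Config:
    """Hypothetical sketch: flags populate the config once, then are ignored."""

    def __init__(self):
        self.values = {}

    def populate_from_flags(self, parsed_flags):
        # Copy values out of the (slow to read) flag object a single time.
        self.values.update(parsed_flags)

    def update(self, name, value):
        # After parsing, this is the only supported way to change a value.
        self.values[name] = value

    def read(self, name):
        # Reads touch only the plain dict, never the flags.
        return self.values[name]


parsed_flags = {"jax_enable_x64": False}  # stand-in for parsed ABSL flags
config = Config()
config.populate_from_flags(parsed_flags)

# A later change to the flag is not reflected in the config.
parsed_flags["jax_enable_x64"] = True
after_flag_change = config.read("jax_enable_x64")

# Changes after parsing go through update().
config.update("jax_enable_x64", True)
after_update = config.read("jax_enable_x64")
```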

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Roy Frostig
6a52339dcc include jax_threefry_partitionable setting in staging cache key
2022-11-22 15:20:01 -08:00
Yash Katariya
b6fa77cb60 Fix forward (Add deprecation warnings to DA, SDA and GDA): By raising the warnings in the hook of the jax_array config.
PiperOrigin-RevId: 489503583
2022-11-18 10:12:40 -08:00
Parker Schuh
4d418fb45e Remove opt-barrier fallbacks.
PiperOrigin-RevId: 489285590
2022-11-17 12:57:57 -08:00
Yash Katariya
ea930e1d8d Default jax.Array to True globally. See https://jax.readthedocs.io/en/latest/jax_array_migration.html for migration to jax.Array.
PiperOrigin-RevId: 488764287
2022-11-15 14:50:05 -08:00
George Necula
e4751d4b02 [jax2tf] Enable StableHLO in jax2tf native lowering.
PiperOrigin-RevId: 488654050
2022-11-15 07:42:49 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
ef63f75e39 Merge pull request #13039 from skye:cache_compile_time_heuristic
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
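
To make the comparison concrete, here is a small sketch that applies both gating rules to a few rows from the table above. The `>=` comparisons and the helper names are assumptions for illustration; the thresholds (1 second, and an instruction count of 6 from the earlier change) are the defaults mentioned in the commit messages.

```python
# (name, instruction_count, compile_time_secs), a subset of the table rows.
ROWS = [
    ("broadcast_in_dim", 2, 0.01633763313),
    ("_multi_slice", 9, 0.09065580368),
    ("_threefry_split", 232, 0.09037566185),
    ("<unnamed wrapped function>", 2668, 7.701903343),
    ("init", 60361, 26.35722399),
    ("<unnamed wrapped function>", 78010, 3.879326344),
]

MIN_COMPILE_TIME_SECS = 1.0
MIN_INSTRUCTION_COUNT = 6


def cached_by_time(rows, min_secs=MIN_COMPILE_TIME_SECS):
    # Assumed rule: write an entry only if compilation took at least min_secs.
    return [name for name, _, secs in rows if secs >= min_secs]


def cached_by_count(rows, min_count=MIN_INSTRUCTION_COUNT):
    # Assumed rule: write an entry only if it has at least min_count instructions.
    return [name for name, count, _ in rows if count >= min_count]


by_time = cached_by_time(ROWS)
by_count = cached_by_count(ROWS)
```

On these rows the time-based rule skips `_multi_slice` and `_threefry_split`, which compile in under 0.1 seconds, while the count-based rule would still write them.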
2022-11-02 00:56:19 +00:00
Yash Katariya
ca1f58e37b Add a new jax.spmd_mode config for preventing unintentional hangs and incorrect results when users pass jax.Arrays that span across multiple processes (i.e. not fully addressable) to jit or jnp operations (that are jitted by default).
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since it's a new behavior for `jit` (previously jit only worked on single device arrays).
* Over time (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.

PiperOrigin-RevId: 485075693
2022-10-31 09:51:42 -07:00
Matthew Johnson
213d2c8592 integrate new (partitionable, count-space-exhaustive) counts generation
2022-10-29 00:05:49 -07:00
jax authors
89b240ba02 Merge pull request #13012 from mattjj:rng-part-overgenerate
PiperOrigin-RevId: 484567918
2022-10-28 10:41:35 -07:00
Roy Frostig
c8b9280fb3 partitionable threefry PRNG random bits implementation
the cost is 2x overgeneration of bits

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-10-28 10:07:14 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
jax authors
c1c8462371 Merge pull request #12798 from skye:cache_min_instr_count
PiperOrigin-RevId: 482349949
2022-10-19 17:54:03 -07:00
Skye Wanderman-Milne
81eb3fca55 Add new config jax_persistent_cache_min_instruction_count.
This can be used to limit the number of entries written to the
persistent compilation cache.

I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) and logging
the instruction counts and compilation time:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality.

Fixes #12583
2022-10-20 00:17:24 +00:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
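
A downstream user could rely on that propagation roughly as in the sketch below. It uses only the standard `logging` module; the module name `jax._src.example_module` is a hypothetical placeholder for a module inside the JAX package.

```python
import io
import logging

# Attach a handler to the top-level "jax" logger only; the root logger
# is left untouched.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))

jax_logger = logging.getLogger("jax")
jax_logger.setLevel(logging.DEBUG)
jax_logger.addHandler(handler)

# Hypothetical module-level logger, as created by logging.getLogger(__name__)
# inside a JAX module. Its records propagate up to the "jax" logger.
module_logger = logging.getLogger("jax._src.example_module")
module_logger.debug("compiling %s", "f")

captured = stream.getvalue()
```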

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Skye Wanderman-Milne
15e5f38a16 Make persistent compilation cache warn instead of raise an error on cache read/write failures
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising an exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
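
The warn-or-raise behavior might look roughly like the sketch below. `read_cache_entry` and `broken_read` are hypothetical helpers, and the way the environment variable is parsed here (a case-insensitive comparison with `"true"`) is an assumption, not JAX's actual parsing.

```python
import os
import warnings


def read_cache_entry(read_fn, key):
    """Hypothetical wrapper: warn on cache failures unless asked to raise."""
    raise_errors = (
        os.environ.get("JAX_RAISE_PERSISTENT_CACHE_ERRORS", "").lower() == "true"
    )
    try:
        return read_fn(key)
    except Exception as ex:
        if raise_errors:
            raise
        warnings.warn(f"Error reading persistent compilation cache entry {key!r}: {ex}")
        return None  # treated as a cache miss


def broken_read(key):
    # Stand-in for a failing cache backend.
    raise OSError("disk unavailable")


# Default behavior: the failure becomes a warning and a cache miss.
os.environ.pop("JAX_RAISE_PERSISTENT_CACHE_ERRORS", None)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = read_cache_entry(broken_read, "module-abc")
```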
2022-09-30 18:38:22 +00:00
jax authors
4f90af91d3 Remove unused jax_unique_mhlo_module_names flag.
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
lenamartens
27e3981d52 lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00