144 Commits

Author SHA1 Message Date
George Necula
8d80e2587b [jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
5e276d0935 Tracebacks no longer have JAX-internal frames prepended by default 2023-08-03 11:38:38 -07:00
Peter Hawkins
a40f900e23 Fix exception if additional ABSL flags were registered after config_with_absl() was called.
PiperOrigin-RevId: 552584357
2023-07-31 13:59:01 -07:00
Skye Wanderman-Milne
8b58e38ec5 Add jax_debug_log_modules config option.
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).

Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-07-28 18:11:12 +00:00
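A minimal sketch of the idea behind commit 8b58e38ec5, assuming a comma-separated module list is mapped onto Python's standard `logging` hierarchy. The helper name `enable_debug_logging` is hypothetical and this is not JAX's actual implementation; it only illustrates why naming a package (e.g. `jax`) also covers every module beneath it.

```python
import logging


def enable_debug_logging(module_spec: str) -> list[str]:
    """Set DEBUG level on each logger named in a comma-separated spec.

    Hypothetical helper for illustration only.
    """
    names = [name.strip() for name in module_spec.split(",") if name.strip()]
    for name in names:
        logging.getLogger(name).setLevel(logging.DEBUG)
    return names


# Enable two specific modules, as in the first example of the commit message.
enabled = enable_debug_logging("jax._src.xla_bridge,jax._src.dispatch")
```

Because child loggers inherit the effective level of their parent, passing a single package name enables debug logging for all of its submodules.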
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
71e2d289a4 [jax2tf] Document the JAX serialization version numbers. 2023-07-21 08:11:44 +03:00
Peter Hawkins
59509dc2b3 Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
George Necula
603eeb1901 Copybara import of the project:
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Peter Hawkins
651f87733b Remove jax_jit_pjit_api_merge.
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Peter Hawkins
ee14ca2628 Add option jax_include_full_tracebacks_in_locations.
If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment.

PiperOrigin-RevId: 534060827
2023-05-22 07:41:29 -07:00
Peter Hawkins
26f2711aeb Fix typo in config.py.
Fixes #16066
2023-05-19 10:07:12 -04:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
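The remark in commit e0d2736e37 about not saving the `jnp.exp(...)` value follows from the fact that the softmax tangent can be written purely in terms of the output `y`: the tangent is `y * (x_dot - sum(y * x_dot))`. The pure-Python sketch below shows that formula for a 1-D input; it is a stand-in for illustration, not the code in `jax.nn`.

```python
import math


def softmax(x):
    """Numerically stable softmax of a list of floats."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jvp(y, x_dot):
    """Tangent of softmax expressed only through the output y.

    No exponentials are needed here, so only y has to be kept around.
    """
    inner = sum(yi * di for yi, di in zip(y, x_dot))
    return [yi * (di - inner) for yi, di in zip(y, x_dot)]
```

The result can be checked against a central finite difference of `softmax` along the direction `x_dot`.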
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
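The effect of metadata stripping on cache hits (commit 3bb7386149) can be sketched with a toy cache key. The actual change uses an MLIR pass over a StableHLO module; the regular expression and the `cache_key` function below are assumptions made purely for illustration and handle only simple trailing `loc("file":line:col)` annotations.

```python
import hashlib
import re

# Toy pattern for location annotations such as loc("model.py":12:3).
_LOC_PATTERN = re.compile(r'\s*loc\("[^"]*":\d+:\d+\)')


def cache_key(module_text: str, strip_metadata: bool = True) -> str:
    """Hash the module text, optionally removing location metadata first."""
    if strip_metadata:
        module_text = _LOC_PATTERN.sub("", module_text)
    return hashlib.sha256(module_text.encode("utf-8")).hexdigest()


# The same operation, emitted from two different source lines.
before = 'stablehlo.add %a, %b : tensor<f32> loc("model.py":12:3)'
after = 'stablehlo.add %a, %b : tensor<f32> loc("model.py":40:3)'
```

With stripping enabled the two keys match, so moving a function within a file no longer causes a cache miss; with stripping disabled they differ.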
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
George Necula
2ac2dc65b1 Remove jax2tf experimental_native_lowering.
Users should use native_serialization.

PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
Matthew Johnson
6b4262d9f6 add experimental jax_log_checkpoint_residuals option
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2023-03-22 16:26:56 -07:00
Parker Schuh
e89235ffdd Delete the C++ GetEnableJaxArray() flag.
PiperOrigin-RevId: 518119698
2023-03-20 17:18:16 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
George Necula
1a9f49963c [jax2tf] Rename experimental_native_lowering to native_serialization
We refer to the feature as serialization rather than just lowering,
because the former is both more widely understood and is actually
more accurate because jax2tf will both lower to StableHLO and then
serialize to StableHLO with compatibility guarantees.

This is part of launching the new version of jax2tf with native
serialization.

For now we also keep the parameter `experimental_native_lowering` and
the flag `jax2tf_default_experimental_native_lowering`, until we transition
projects using these flags to the new ones (separate change).

PiperOrigin-RevId: 516864636
2023-03-15 10:31:25 -07:00
George Necula
dbd2033461 Stop using version 1 of XlaCallModuleOp
Also remove configuration flag jax2tf_use_stablehlo.

PiperOrigin-RevId: 515409050
2023-03-09 12:32:14 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
Yash Katariya
8b77bde0f4 Retrace jaxpr again if the same function with the same avals is run under different meshes. Add pxla.thread_resources to the global trace_context which takes care of the caching of global variables.
PiperOrigin-RevId: 514457821
2023-03-06 10:46:46 -08:00
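The retracing behavior in commit 8b77bde0f4 amounts to making the mesh part of the cache key. A minimal sketch, assuming a hypothetical `TraceCache` class and using plain strings in place of functions, avals, and meshes (none of this mirrors JAX's internal data structures):

```python
class TraceCache:
    """Toy trace cache keyed on function, avals, and the active mesh."""

    def __init__(self):
        self._entries = {}
        self.misses = 0

    def lookup(self, fn_name, avals, mesh):
        key = (fn_name, avals, mesh)
        if key not in self._entries:
            # A different mesh yields a different key, forcing a retrace.
            self.misses += 1
            self._entries[key] = f"traced {fn_name} under {mesh}"
        return self._entries[key]


cache = TraceCache()
cache.lookup("f", ("f32[8]",), "mesh_a")  # miss: first trace
cache.lookup("f", ("f32[8]",), "mesh_a")  # hit: same function, avals, mesh
cache.lookup("f", ("f32[8]",), "mesh_b")  # miss: same avals, different mesh
```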
Peter Hawkins
055fa6b90f Remove pytype suppression for jax/_src/config.py
This file no longer seems to make pytype unhappy.

PiperOrigin-RevId: 512668863
2023-02-27 10:39:55 -08:00
jax authors
71775720a7 Merge pull request #14615 from skye:restore_opt_barrier
PiperOrigin-RevId: 511935964
2023-02-23 18:08:08 -08:00
Skye Wanderman-Milne
6572fac49e Add back opt-barrier fallback, since the fallback sometimes prevents OOMs.
This reverts 4d418fb45e, and updates for a lax change.
2023-02-21 23:45:28 +00:00
Peter Hawkins
37d4ad910a Remove uses of jax.xla_computation from metadata_test.py
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexes.

PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementations of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features such as dynamic shapes easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Yash Katariya
2b093f1c9a Fix the warning being raised when jax.Array is True about using jax.Array
PiperOrigin-RevId: 505149151
2023-01-27 10:20:44 -08:00
Yash Katariya
cfdba777fb Add jax_jit_pjit_api_merge config to help the transition to merge jit and pjit.
PiperOrigin-RevId: 499295712
2023-01-03 12:59:46 -08:00
Tom Hennigan
4f75ad66be Revert #13747 since config values already have paired env vars.
PiperOrigin-RevId: 496935378
2022-12-21 09:06:08 -08:00
Tom Hennigan
76f92c47a6 Set default value for jax_platforms based on JAX_PLATFORMS env var.
PiperOrigin-RevId: 496872386
2022-12-21 02:38:14 -08:00
Jake VanderPlas
0a2d1cd45e Set bcoo_cusparse_lowering to False by default
This was causing out-of-bound writes on some CUDA backends

PiperOrigin-RevId: 494280591
2022-12-09 15:41:49 -08:00
Jongwook Choi
cd225853f7 Fix a false-positive typing warning on jax.default_device
Consider the following code where static type checkers can report an
error:

```python
CPU = jax.devices('cpu')[0]
with jax.default_device(CPU):
  ...                 # ^^^
```

Error message:
```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
  "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```

This is because `_StateContextManager.__call__` does not have a proper
type annotation on the parameter, unlike the attribute `_default_value`
which has a type annotation. Adding an `Any` annotation to the parameter
makes the error disappear.
2022-12-06 21:05:35 -05:00
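The shape of the fix in commit cd225853f7 can be sketched with a simplified state context manager whose `__call__` parameter is annotated as `Any`. The class below is a hypothetical stand-in written for illustration; only the names `NoDefault` and `new_val` come from the error message quoted above.

```python
import contextlib
from typing import Any


class NoDefault:
    """Sentinel type meaning 'no value supplied'."""


no_default = NoDefault()


class StateContextManager:
    """Simplified stand-in for a config state usable as a context manager."""

    def __init__(self, default_value: Any = no_default):
        self._default_value = default_value
        self.value = default_value

    @contextlib.contextmanager
    def __call__(self, new_val: Any = no_default):
        # Annotating new_val as Any keeps a type checker from inferring
        # the parameter type from the NoDefault sentinel.
        previous = self.value
        self.value = new_val
        try:
            yield
        finally:
            self.value = previous


default_device = StateContextManager(None)
```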
Peter Hawkins
1f9c988e63 Use _thread_local_state.__dict__.get() instead of getattr(_thread_local_state, ...).
`getattr` turns out to be a tiny bit slower than `get()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.

Improves the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)

name        old time/op             new time/op             delta
device_put  51.4µs ± 1%             48.9µs ± 3%  -4.87%          (p=0.000 n=8+9)
```

PiperOrigin-RevId: 493108288
2022-12-05 14:09:47 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
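The one-way flow described in commit f9b5312149 can be sketched as follows. The `Config` class, its methods other than `update`, and the option name `demo_option` are hypothetical; a plain dictionary stands in for parsed ABSL flags.

```python
class Config:
    """Toy config store populated once from parsed flags."""

    def __init__(self):
        self._values = {}

    def populate_from_flags(self, parsed_flags):
        # One-way copy at parse time; no link back to the flag objects.
        self._values.update(parsed_flags)

    def update(self, name, value):
        self._values[name] = value

    def read(self, name):
        # A plain dict lookup, with no flag machinery on the read path.
        return self._values[name]


parsed_flags = {"demo_option": False}
config = Config()
config.populate_from_flags(parsed_flags)

# Changing the flag after parsing is not reflected in the config...
parsed_flags["demo_option"] = True
value_after_flag_change = config.read("demo_option")

# ...so later changes have to go through update().
config.update("demo_option", True)
```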
Roy Frostig
6a52339dcc include jax_threefry_partitionable setting in staging cache key 2022-11-22 15:20:01 -08:00
Yash Katariya
b6fa77cb60 Fix forward (Add deprecation warnings to DA, SDA and GDA): By raising the warnings in the hook of the jax_array config.
PiperOrigin-RevId: 489503583
2022-11-18 10:12:40 -08:00
Parker Schuh
4d418fb45e Remove opt-barrier fallbacks.
PiperOrigin-RevId: 489285590
2022-11-17 12:57:57 -08:00
Yash Katariya
ea930e1d8d Default jax.Array to True globally. See https://jax.readthedocs.io/en/latest/jax_array_migration.html for migration to jax.Array.
PiperOrigin-RevId: 488764287
2022-11-15 14:50:05 -08:00
George Necula
e4751d4b02 [jax2tf] Enable StableHLO in jax2tf native lowering.
PiperOrigin-RevId: 488654050
2022-11-15 07:42:49 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
jax authors
ef63f75e39 Merge pull request #13039 from skye:cache_compile_time_heuristic
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
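Applying the 1-second threshold from commit cc5171034f to a few rows of the table above gives a feel for what would be cached. The function `should_cache` and the choice of `>=` for the comparison are assumptions for illustration, not the actual cache logic.

```python
MIN_COMPILE_TIME_SECS = 1.0  # default threshold stated in the commit message

# (name, instruction_count, compile_time_secs): rows copied from the table.
measurements = [
    ("broadcast_in_dim", 2, 0.01633763313),
    ("_threefry_split", 232, 0.09037566185),
    ("<unnamed wrapped function>", 2668, 7.701903343),
    ("init", 60361, 26.35722399),
    ("<unnamed wrapped function>", 78010, 3.879326344),
]


def should_cache(compile_time_secs, threshold=MIN_COMPILE_TIME_SECS):
    """Return True if the compilation took long enough to be worth caching."""
    return compile_time_secs >= threshold


cached = [name for name, _, secs in measurements if should_cache(secs)]
```

In the table, the 78010-instruction function compiled in about 3.9 seconds while a 46580-instruction function took about 166 seconds, so instruction count is a weak proxy for compile cost.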
Yash Katariya
ca1f58e37b Add a new jax.spmd_mode config for preventing unintentional hangs and incorrect results when users pass jax.Arrays that span across multiple processes (i.e. not fully addressable) to jit or jnp operations (that are jitted by default).
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since it's a new behavior for `jit` (previously jit only worked on single device arrays).
* Over time (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.

PiperOrigin-RevId: 485075693
2022-10-31 09:51:42 -07:00