184 Commits

Author SHA1 Message Date
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Matthew Johnson
89f26db36d start adding EArray, a jax.Array analog that can contain extended dtypes 2024-04-06 13:09:25 -07:00
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Yue Sheng
291a5cd3e0 [PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts
PiperOrigin-RevId: 617889924
2024-03-21 10:30:57 -07:00
Tomás Longeri
99fadcbcec [Mosaic] Restore Python pipeline and add a CLI flag to run it.
We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA.

The original one was removed in cl/596464480 and cl/597332393.

PiperOrigin-RevId: 617291995
2024-03-19 14:18:33 -07:00
Yue Sheng
1cef1d9503 jax.clear_backends() is not doing what it is intended to do, users should try to avoid using it.
We decided to move it into `jax.extend`. This CL is the first step, which adds a new module `jax.extend.backend`.

PiperOrigin-RevId: 615934218
2024-03-14 16:11:31 -07:00
jax authors
2e83fed0b3 Merge pull request #20026 from mattjj:mutable-arrays
PiperOrigin-RevId: 611707543
2024-02-29 22:18:05 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Qiao Zhang
9fcf9e52b5 Add Pallas attention kernel for GPU serving.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 607404565
2024-02-15 11:44:20 -08:00
jax authors
0b33eb7c68 Merge pull request #19588 from jakevdp:jax-tree
PiperOrigin-RevId: 606665122
2024-02-13 10:18:29 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
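To illustrate the aliasing that 6934a4b76b describes, here is a minimal sketch, assuming a JAX release that ships the `jax.tree` module. It shows `jax.tree.map` producing the same result as `jax.tree_util.tree_map`. The sample pytree and the `double` helper are made up for the example.

```python
import jax

# Hypothetical sample pytree: a dict holding a scalar and a list.
params = {"a": 1, "b": [2, 3]}


def double(x):
  return x * 2


# Long-standing spelling from jax.tree_util.
via_tree_util = jax.tree_util.tree_map(double, params)

# Shorter alias from the jax.tree module.
via_tree = jax.tree.map(double, params)
```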
Cjkkkk
916e53a8a2 add keyword-only argument & fix scale issue 2024-02-09 09:05:09 -08:00
jax authors
9b27d43e70 Import submodules from jax._src explicitly, instead of relying on import side-effects. Relying on side-effects leads to missing x-refs in code search according to go/pywald-sawmill-analysis.
PiperOrigin-RevId: 604788105
2024-02-06 15:47:16 -08:00
jax authors
0d152dcfab Merge pull request #19528 from superbobry:strict-abc
PiperOrigin-RevId: 602392902
2024-01-29 08:18:50 -08:00
Sergei Lebedev
078bb00fdb Replaced most usages of abc.ABC with util.StrictABC
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__subclasscheck__ defined in
abc.ABCMeta.

The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
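The idea behind 078bb00fdb can be sketched as follows. This is an illustrative reconstruction, not necessarily the code in `jax._src.util`: a metaclass derived from `abc.ABCMeta` rejects `register` and falls back to the plain `type` checks. The `Shape` and `Square` classes are hypothetical.

```python
import abc


class StrictABCMeta(abc.ABCMeta):
  """ABCMeta variant that forbids virtual subclasses (illustrative sketch)."""

  def register(cls, subclass):
    raise NotImplementedError(
        f"{cls.__name__} does not support virtual subclassing")

  # Without virtual subclasses, the cheap built-in checks are sufficient.
  __instancecheck__ = type.__instancecheck__
  __subclasscheck__ = type.__subclasscheck__


class StrictABC(metaclass=StrictABCMeta):
  __slots__ = ()


# Hypothetical user classes: abstract methods are still enforced.
class Shape(StrictABC):
  @abc.abstractmethod
  def area(self):
    ...


class Square(Shape):
  def __init__(self, side):
    self.side = side

  def area(self):
    return self.side ** 2
```

Real subclasses still pass `isinstance`, while `Shape.register(...)` raises instead of silently creating a virtual subclass.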
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
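The shape change described in fc6df3218c can be worked through with a small pure-Python sketch. `shard_shape` is a hypothetical helper, not a JAX API. It assumes the leading axis is the mapped axis and that each device receives exactly one slice.

```python
def shard_shape(global_shape, no_rank_reduction):
  """Per-device shard shape when the leading axis is mapped over devices."""
  rest = tuple(global_shape[1:])
  # "Chunked" axis keeps a size-1 leading dimension; "unstacked" drops it.
  return (1, *rest) if no_rank_reduction else rest


global_shape = (8, 100)  # sharded across 8 accelerators on the first axis
old_shard = shard_shape(global_shape, no_rank_reduction=False)
new_shard = shard_shape(global_shape, no_rank_reduction=True)
```

Here `old_shard` is `(100,)` and `new_shard` is `(1, 100)`, matching the example in the commit message.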
jax authors
7418f55987 Merge pull request #19007 from sshahrokhi:enhanced
PiperOrigin-RevId: 598002075
2024-01-12 17:12:16 -08:00
Shiva Shahrokhi
65f3e4fffd making sure enhanced barrier only turns on when there is a supported TPU available. 2024-01-12 23:47:37 +00:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
Sharad Vikram
598b46aab5 [Pallas/TPU] Open source "Splash Attention" (Sparse Flash Attention), a general purpose attention kernel where you can specify an attention mask using NumPy.
PiperOrigin-RevId: 597658315
2024-01-11 14:47:39 -08:00
Hyeontaek Lim
f2e526dc78 Internal code cleanup for reducing private API access.
PiperOrigin-RevId: 597397571
2024-01-10 17:25:05 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
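The counting rule stated in da96633f11 might be sketched like this. The function and parameter names are hypothetical. The strict comparisons follow the wording of the message ("greater than"), and the real code may treat the boundary values differently.

```python
def should_count_cache_miss(compile_time_secs, entry_size_bytes,
                            min_compile_time_secs, min_entry_size_bytes):
  """True only if the entry clears BOTH minimum thresholds (sketch)."""
  return (compile_time_secs > min_compile_time_secs
          and entry_size_bytes > min_entry_size_bytes)


# Hypothetical thresholds: 1 second and 1024 bytes.
counted = should_count_cache_miss(2.5, 4096, 1.0, 1024)
skipped_small = should_count_cache_miss(2.5, 100, 1.0, 1024)
skipped_fast = should_count_cache_miss(0.1, 4096, 1.0, 1024)
```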
Sharad Vikram
836563fadf [Pallas] Refactor indexing primitives to use NDIndexer abstraction
Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
  This allows doing things like `ref.at[0].at[0]` to recursively create a view
  of a `Ref`. `RefView`s therefore encapsulate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
  but does *not* generalize their rules. Most of the rules will raise a
  NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
  done in a future CL.
* With the above in mind, this change only preserves existing public facing APIs
  and adding actual support will involve updating the rules.

PiperOrigin-RevId: 595229523
2024-01-02 15:53:40 -08:00
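The chained-view idea in 836563fadf (`ref.at[0].at[0]`) can be illustrated with a toy model. This is not Pallas's `RefView` or `NDIndexer`: `View` and `_At` are hypothetical classes, and plain nested lists stand in for a `Ref`. Each `.at[...]` appends one more indexer instead of reading data immediately.

```python
class _At:
  """Helper returned by `View.at`; indexing it yields a deeper view."""

  def __init__(self, view):
    self._view = view

  def __getitem__(self, idx):
    return View(self._view.base, self._view.indexers + (idx,))


class View:
  """Toy view that stores a base object plus a sequence of indexers."""

  def __init__(self, base, indexers=()):
    self.base = base
    self.indexers = tuple(indexers)

  @property
  def at(self):
    return _At(self)

  def get(self):
    # Apply the stored indexers one after another, outermost first.
    out = self.base
    for idx in self.indexers:
      out = out[idx]
    return out


data = [[1, 2], [3, 4]]
view = View(data).at[1].at[0]
```

The view records both indexers, `(1, 0)`, and only resolves them when `get()` is called.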
Yash Katariya
72fbdb2eb5 Expose shard_alike via jax.experimental. The API is `x, y = shard_alike(x, y)`.
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.

The flow of sharding is bidirectional, i.e. SPMD will choose what the sharding of `x` and `y` should be depending on its propagation algorithm. It might end up being that the sharding chosen is not that of `x` or `y` but something better. At the end of propagation `x` and `y` will be sharded alike.

The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.

Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
2023-12-19 16:31:33 -08:00
George Necula
bb84e6c22e Improve support for JAX_DUMP_IR_TO.
Previously the environment variable JAX_DUMP_IR_TO controlled
whether and where to dump the MLIR module prior to compilation. Now we move the code for that support from
compiler.py to mlir.py, so that it can be used in other
parts of the code. We also add support for logging to Sponge.

Using this support we now log the module on errors from
refine_polymorphic_shapes.

PiperOrigin-RevId: 592099633
2023-12-18 21:25:45 -08:00
George Necula
eed61f68aa Move export backwards compatibility tests out of jax2tf. Step 1.
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.

PiperOrigin-RevId: 591927484
2023-12-18 09:49:52 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
Peter Hawkins
7fa0f464fd [bazel] Add a BUILD file for jax/extend, and add more granular targets for individual pieces of extend.
In general we'd like to use more granular BUILD targets rather than larger monolithic targets. If nothing else, they interact better with pytype.

This change is in preparation for adding the JAX MLIR bindings to jax.extend, since they are something that JAX users sometimes need especially for defining custom ops.

PiperOrigin-RevId: 587893573
2023-12-04 17:48:50 -08:00
George Necula
c1f54d447e Move back_compat_test_util.py to jax._src.internal_test_util.
Until now the backwards compatibility tests for exporting JAX functions with custom calls were part of the jax2tf test suite. But these tests are independent of TF, and we need to write such tests for Pallas and other projects that should not depend on jax2tf.

Here we move the test utilities out of jax2tf.
This is needed to enable writing Pallas backwards compatibility tests.

We rename back_compat_test_util.py to export_back_compat_test_util.py for clarity.

In a subsequent move we will move the actual backwards compatibility tests themselves out of jax2tf.

PiperOrigin-RevId: 583312085
2023-11-17 02:05:30 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
Jieying Luo
43732e3fd4 Change the definition of the config to run bazel test for cuda plugin to match //jax:build_jaxlib.
When build_cuda_plugin_from_source is true, the cuda plugin is built from source. This is used for running `bazel test` without preinstalled jax cuda packages.

PiperOrigin-RevId: 583057751
2023-11-16 08:44:22 -08:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00
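To make the `minor_to_major` field mentioned in 5c3da219c0 concrete, here is a small sketch that derives element strides from it. It assumes the usual XLA convention that `minor_to_major` lists dimensions from fastest-varying to slowest-varying, and it ignores tiling. `strides_from_minor_to_major` is a hypothetical helper, not part of the private API.

```python
def strides_from_minor_to_major(shape, minor_to_major):
  """Element strides for an untiled layout (illustrative sketch)."""
  strides = [0] * len(shape)
  stride = 1
  for dim in minor_to_major:  # fastest-varying dimension first
    strides[dim] = stride
    stride *= shape[dim]
  return tuple(strides)


# A 2x3 array: (1, 0) is row-major, (0, 1) is column-major.
row_major = strides_from_minor_to_major((2, 3), (1, 0))
col_major = strides_from_minor_to_major((2, 3), (0, 1))
```

For the 2x3 example, `row_major` is `(3, 1)` and `col_major` is `(1, 2)`.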
George Necula
5001a21bad Move primitive_harness.py to jax._src.internal_test_util.test_harnesses.
The primitive_harness.py defines a set of about 7000 test harnesses, each with a JAX callable and a recipe for generating the arguments for the callable. Note that the test harness does not define any expected behavior. The test harnesses can be used in several kinds of tests.

Initially these harnesses were designed to test the completeness of the jax2tf lowering: for each test harness we convert it to TF and then we test that the result of invoking it is the same as for JAX native. Since then we have found other uses of test harnesses.

  * E.g., shape_poly_test.py tests that we can apply `jax.vmap` to each test harness and that we get a JAX callable that can be traced shape polymorphically, using a dimension variable for the batch dimension.
  * E.g., multi_platform_lowering_test.py tests that we can generate multi-platform lowering for each test harness.
  * E.g., the TFLite team is using the test harnesses to check the completeness of the TFLite lowering.

Since the test harnesses are useful for non-jax2tf uses we hereby moved them to jax._src.internal_test_util.test_harnesses. (We also renamed the module from primitive_harness to test_harnesses.)

This change is necessary to move some tests out of jax2tf: multi_platform_lowering_test.py, shape_poly_test.py.

PiperOrigin-RevId: 581016785
2023-11-09 13:58:00 -08:00
Jake VanderPlas
2932c7eb91 Set public module for exported jax.dtypes APIs 2023-10-17 15:07:28 -07:00
Adam Paszke
b84ae9821f Make sure we don't filter stack frames of packages that start with a jax prefix
This ended up accidentally setting up filters for jax_triton. This change additionally
adds an opt-in mechanism for paths, that overrides exclusions. We use this to avoid
treating pallas ops implementations as JAX-internal.

PiperOrigin-RevId: 574167963
2023-10-17 09:06:30 -07:00
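The problem fixed in b84ae9821f looks like a raw string-prefix match: a filter on `.../jax` also matches `.../jax_triton`. The sketch below shows one way to avoid that, plus an opt-in list that overrides exclusions. It is not the JAX implementation. The function name, parameters, and POSIX-style paths are all hypothetical.

```python
def is_internal_frame(filename, excluded_dirs, included_dirs=()):
  """Decide whether a stack frame should be filtered out (sketch)."""

  def under(path, root):
    # Compare against "root/" so "/site/jax" does not match "/site/jax_triton".
    return path.startswith(root.rstrip("/") + "/")

  # Opt-in paths win over exclusions.
  if any(under(filename, d) for d in included_dirs):
    return False
  return any(under(filename, d) for d in excluded_dirs)


excluded = ["/site/jax"]
naive_match = "/site/jax_triton/ops.py".startswith("/site/jax")  # the bug
core_hidden = is_internal_frame("/site/jax/core.py", excluded)
triton_hidden = is_internal_frame("/site/jax_triton/ops.py", excluded)
pallas_hidden = is_internal_frame(
    "/site/jax/pallas/ops/attention.py", excluded,
    included_dirs=["/site/jax/pallas/ops"])
```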
Sergei Lebedev
5ab05e42c9 MAINT Clean up leftover Array = Any aliases in jax/_src/**.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of raggedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Adam Paszke
fe0f12e00e [Pallas] Wire up cost estimates in Mosaic params
We could probably estimate the cost by running the standard HLO analysis
on the kernel body and scaling by the grid size, but that would require
more work, so for now I only exposed the manual knob.

PiperOrigin-RevId: 566351498
2023-09-18 10:54:05 -07:00
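The alternative that fe0f12e00e mentions but does not implement (estimate the kernel body once, then scale by the grid size) amounts to simple arithmetic. A rough sketch follows. It assumes the body cost is identical for every grid step. `scaled_cost`, the cost keys, and the numbers are made up for illustration.

```python
import math


def scaled_cost(body_cost, grid):
  """Scale a per-invocation cost dict by the number of grid steps (sketch)."""
  steps = math.prod(grid)
  return {name: value * steps for name, value in body_cost.items()}


# Hypothetical per-body estimate and a 4x8 grid (32 invocations).
body = {"flops": 1000, "bytes_accessed": 256}
total = scaled_cost(body, grid=(4, 8))
```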
jax authors
871c9f4d76 Merge pull request #17307 from froystig:wrap-key
PiperOrigin-RevId: 560536131
2023-08-27 12:58:50 -07:00
Roy Frostig
a69f134cde add jax.extend.random.wrap_key_data 2023-08-26 11:39:25 -07:00
jax authors
841baabd3f Adds Pallas flash attention TPU kernel. Implementation based on https://arxiv.org/pdf/2205.14135.pdf.
PiperOrigin-RevId: 560346791
2023-08-26 08:03:48 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00