213 Commits

Author SHA1 Message Date
George Necula
24b42eed5e [export] Clean up BUILD targets for jax.experimental.export
jax.experimental.export is deprecated and will be removed in a future version of JAX.

See migration guide at: https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export

PiperOrigin-RevId: 647562073
2024-06-27 23:08:48 -07:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Justin Fu
8ba8f3bf65 [Pallas] Implement block-invariant sampling.
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
jax authors
ce4a56a137 Merge pull request #21394 from ayaka14732:lru-cache
PiperOrigin-RevId: 642333998
2024-06-11 11:29:18 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Sergei Lebedev
f8473509cf Removed kernel_regeneration_util from Mosaic
It was only used for persisting kernel metadata, and that can be done via
jax.named_scope instead.

PiperOrigin-RevId: 642195336
2024-06-11 02:36:41 -07:00
Justin Fu
9439f63645 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
Sergei Lebedev
5e7ad600e2 Removed the double re-exporting of Pallas GPU/TPU APIs
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.

PiperOrigin-RevId: 641875127
2024-06-10 05:59:09 -07:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Yash Katariya
9e3f290de3 Delete XLACompatibleSharding and replace with jax.sharding.Sharding.
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.

Why do this?

* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!

* Having these 2 types makes things very confusing. One example is:
  * `jax.jit` only accepts XLACompatibleShardings.
  * `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.

PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Sergei Lebedev
40f107e5a5 Moved Pallas GPU ops into pallas/ops/gpu
PiperOrigin-RevId: 640439838
2024-06-05 01:34:46 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Yazhou Zu
91d68b5564 creat jax config api to allow custom pjrt client create option settings. this allows a device platform's pjrt client be aware of the calling (customer) ml framework
PiperOrigin-RevId: 638009713
2024-05-28 13:43:06 -07:00
jax authors
93170d9c80 Add JAX version to TPU_ML_PLATFORM_VERSION environment variable.
This will allow us to track the JAX version that is being used on Cloud TPUs

PiperOrigin-RevId: 637025132
2024-05-24 13:56:19 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Sergei Lebedev
071a48719d Added pl.debug_print() -- a new primitive for printing from Pallas kernels
The primitive is currently only support in Pallas GPU when lowering to Triton.
See documentation inline for the Triton-specific restrictions.

PiperOrigin-RevId: 636120214
2024-05-22 04:41:42 -07:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
c4559115ec Internal BUILD file change
PiperOrigin-RevId: 634713068
2024-05-17 04:30:21 -07:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
Sergei Lebedev
e2918ca138 Added a very rough sketch of Mosaic GPU lowering for Pallas
Almost nothing is supported, including

* PyTree inputs/outputs
* indexers
* non-trivial grids
* block specs
* any primitives beyond the ones added here
* etc etc

PiperOrigin-RevId: 633713366
2024-05-14 14:48:09 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
George Necula
b40a31006c [export] Add backwards compatibility test for Pallas call on GPUs.
Note that this adds the minimum of safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.

Also enable serialization of such calls.

PiperOrigin-RevId: 630033989
2024-05-02 05:38:33 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Matthew Johnson
89f26db36d start adding EArray, a jax.Array analog that can contain extended dtypes 2024-04-06 13:09:25 -07:00
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Yue Sheng
291a5cd3e0 [PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts
PiperOrigin-RevId: 617889924
2024-03-21 10:30:57 -07:00
Tomás Longeri
99fadcbcec [Mosaic] Restore Python pipeline and add a CLI flag to run it.
We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA.

The original one was removed in cl/596464480 and cl/597332393.

PiperOrigin-RevId: 617291995
2024-03-19 14:18:33 -07:00
Yue Sheng
1cef1d9503 jax.clear_backends() is not doing what it is intended to do, users should try to avoid using it.
We decide to move it into `jax.extend`. This CL is the first step which adds a new module `jax.extend.backend`.

PiperOrigin-RevId: 615934218
2024-03-14 16:11:31 -07:00
jax authors
2e83fed0b3 Merge pull request #20026 from mattjj:mutable-arrays
PiperOrigin-RevId: 611707543
2024-02-29 22:18:05 -08:00
Matthew Johnson
ab0f7061ad [mutable-arrays] allow state effects in jit by building in run_state
with help from @sharadmv, @yashkatariya, @dougalm, and others

The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
   handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
   refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.

As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
2024-02-29 21:50:19 -08:00
Qiao Zhang
9fcf9e52b5 Add Pallas attention kernel for GPU serving.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 607404565
2024-02-15 11:44:20 -08:00
jax authors
0b33eb7c68 Merge pull request #19588 from jakevdp:jax-tree
PiperOrigin-RevId: 606665122
2024-02-13 10:18:29 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Cjkkkk
916e53a8a2 add keyword-only argument & fix scale issue 2024-02-09 09:05:09 -08:00
jax authors
9b27d43e70 Import submodules from jax._src explicitly, instead of relying on import side-effects. It will lead to the missing x-refs in code search according to go/pywald-sawmill-analysis.
PiperOrigin-RevId: 604788105
2024-02-06 15:47:16 -08:00
jax authors
0d152dcfab Merge pull request #19528 from superbobry:strict-abc
PiperOrigin-RevId: 602392902
2024-01-29 08:18:50 -08:00
Sergei Lebedev
078bb00fdb Replaced most usages of abc.ABC with util.StrictABC
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.

The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00