23928 Commits

Author SHA1 Message Date
Yash Katariya
fff33f90b2 Add compiler_options argument to jax.jit.
This exists on `Compiled` object via AOT too i.e. `jit(f).lower(*args).compile(compiler_options={})`

PiperOrigin-RevId: 692283964
2024-11-01 14:01:19 -07:00
Yash Katariya
07858fa98d [sharding_in_types] Allow device_put to reshard inputs. device_put is a good choice for resharding since it already handles transpose correctly because it tracks the src sharding too.
PiperOrigin-RevId: 692274137
2024-11-01 13:25:08 -07:00
jax authors
d606c24293 Update XLA dependency to use revision
8ec02b3611.

PiperOrigin-RevId: 692266047
2024-11-01 12:57:20 -07:00
jax authors
a0b0a8e5a1 Set minimum supported Python version to 3.10 for matplotlib.
Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version.

We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel.

PiperOrigin-RevId: 692260046
2024-11-01 12:34:28 -07:00
jax authors
7bc026e496 Merge pull request #24669 from jakevdp:fix-sig-test
PiperOrigin-RevId: 692229460
2024-11-01 10:55:07 -07:00
jax authors
453d6ff0f2 Merge pull request #24668 from Li-Jesse-Jiaze:fix-issue-#24661
PiperOrigin-RevId: 692219522
2024-11-01 10:24:50 -07:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Naums Mogers
f462d7e586 [Mosaic] Set TPU CustomCall device type based on the core_type attribute
This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation.
The latter is more fine-grained: it is applied on `func.FuncOp` instead of the entire module, supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`.

`device_type` of the TPU CustomCall HLO is set to `sparsecore` if `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC.

PiperOrigin-RevId: 692212644
2024-11-01 10:02:49 -07:00
Li-Jesse-Jiaze
5e1366c4ce Fix #24661: Add zsh support to conda install documentation 2024-11-01 17:57:18 +01:00
jax authors
bd7c301968 Merge pull request #24667 from jakevdp:fix-array-api
PiperOrigin-RevId: 692210603
2024-11-01 09:56:36 -07:00
Jake VanderPlas
e657a4b283 Fix array API tests.
This is currently causing failures on main.
2024-11-01 09:48:45 -07:00
jax authors
24f3a2bb79 Merge pull request #24665 from mattjj:shmap-jep
PiperOrigin-RevId: 692207887
2024-11-01 09:47:41 -07:00
Matthew Johnson
26f70c9c16 remove busted example from shmap jep 2024-11-01 16:37:46 +00:00
jax authors
2a41c04fef Merge pull request #24652 from jakevdp:old-deps
PiperOrigin-RevId: 691995759
2024-10-31 18:10:38 -07:00
jax authors
f4d675ea12 Merge pull request #24649 from jakevdp:array-api-update
PiperOrigin-RevId: 691992127
2024-10-31 17:54:14 -07:00
Ayaka
f60b97cea1 [Pallas TPU] Add lowering for lax.nextafter
Also improved the corresponding test cases to ensure better coverage and accuracy.

This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`.

PiperOrigin-RevId: 691988164
2024-10-31 17:34:38 -07:00
Peter Hawkins
84c8794b30 Add a JaxIrContext that subclasses mlir.ir.Context and avoids calling ir.Context's __init__.
mlir.ir.Context has the unfortunate behavior that it loads all dialects linked into the binary, even those we have no intention of using. This is fairly benign in JAX's usual configuration, but if JAX is linked together with other MLIR-using software it can be problematic.

PiperOrigin-RevId: 691984229
2024-10-31 17:18:08 -07:00
jax authors
5a3ed6c792 Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass
PiperOrigin-RevId: 691984161
2024-10-31 17:16:31 -07:00
jax authors
423cd2ad5e Simplified conditional in flash attention.
PiperOrigin-RevId: 691972341
2024-10-31 16:28:11 -07:00
Emily Fertig
467bd09f03 Add a register_dataclass example to the pytree tutorial. 2024-10-31 16:26:42 -07:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Jake VanderPlas
17ad8a9582 [array api] update test suite to latest commit 2024-10-31 14:53:28 -07:00
Tzu-Wei Sung
7af7a60dcc [Pallas:TPU] Use arith.divui for uint32 div.
PiperOrigin-RevId: 691939453
2024-10-31 14:37:47 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
jax authors
8536eca46e Update XLA dependency to use revision
edf18ce242.

PiperOrigin-RevId: 691908973
2024-10-31 13:06:07 -07:00
jax authors
c758373b9c Remove implicit sharding annotation for tpu custom call.
PiperOrigin-RevId: 691876343
2024-10-31 11:30:13 -07:00
Praveen Batra
8296f6e0ba [Mosaic] Add extension files for infer/apply vector layout.
PiperOrigin-RevId: 691868278
2024-10-31 11:08:37 -07:00
jax authors
7ff5a4eac2 Merge pull request #24190 from dfm:ffi-examples-gpu
PiperOrigin-RevId: 691862136
2024-10-31 10:53:27 -07:00
Vadym Matsishevskyi
a75d94622c Reverts 72f9a493589a1046e6927a5f16d7dc71df530743
PiperOrigin-RevId: 691843537
2024-10-31 10:05:22 -07:00
Praveen Batra
7d9f565647 [Mosaic] Fix some imports.
PiperOrigin-RevId: 691830491
2024-10-31 09:25:34 -07:00
Dan Foreman-Mackey
ce8dba98fb Move the CUDA end-to-end example to FFI examples workflow + hosted
runner.
2024-10-31 12:21:51 -04:00
jax authors
8abedda8a6 Merge pull request #24480 from dfm:dot-algorithm-plugin-enable
PiperOrigin-RevId: 691734684
2024-10-31 03:10:23 -07:00
Dan Foreman-Mackey
52ad60521c Run dot algorithm tests with PJRT plugin. 2024-10-31 06:01:11 -04:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Sergei Lebedev
85662f6dd8 [pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM
Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Dimitar (Mitko) Asenov
7d504cd95a [MOSAIC:GPU] Extend the mosaic mlir dialect with fragmented layouts.
PiperOrigin-RevId: 691712579
2024-10-31 01:29:22 -07:00
jax authors
5aeffde707 [Mosaic] Extend tpu matmulop to have dimension dims. Add support for batching and simple transposition.
PiperOrigin-RevId: 691706218
2024-10-31 00:59:13 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Yash Katariya
7f4a34e12b Remove the variant since sparsecore is only on v5p and it's device kind is TPU v5.
PiperOrigin-RevId: 691586791
2024-10-30 16:18:54 -07:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Naums Mogers
242e6634ff [Mosaic] Add the core type enum
The new attribute allows differentiating compilation by target core.

PiperOrigin-RevId: 691531726
2024-10-30 13:23:34 -07:00
jax authors
af14c43893 Update XLA dependency to use revision
2d9d84487e.

PiperOrigin-RevId: 691516089
2024-10-30 12:36:35 -07:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
jax authors
3904ced255 [Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
PiperOrigin-RevId: 691477248
2024-10-30 10:48:44 -07:00
jax authors
99ea4c1a4a [Fix] Put * packing into reshape no-op condition (Bug in my original CL)
PiperOrigin-RevId: 691476663
2024-10-30 10:47:23 -07:00
Sergei Lebedev
409517fcbc [pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests
PiperOrigin-RevId: 691472782
2024-10-30 10:37:32 -07:00
Sergei Lebedev
6283eab2ff [pallas] Added a flag disabling verbose error reporting
PiperOrigin-RevId: 691463398
2024-10-30 10:13:22 -07:00
Nitin Srinivasan
da994d3552 Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00
Tzu-Wei Sung
d2f5804449 [Pallas] Add test cases for var + constant.
PiperOrigin-RevId: 691450143
2024-10-30 09:37:50 -07:00