1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-18 04:46:06 +00:00
Commit Graph

26892 Commits

Author SHA1 Message Date
Yash Katariya
42751359e6 Fix typo in the error message
PiperOrigin-RevId: 745375892
2025-04-08 19:00:33 -07:00
Yash Katariya
f95f6a8bdb Rename psum2 to psum_invariant and put it in lax_parallel. We shouldn't expose this to public API and have users use psum instead which will dispatch to psum_invariant when check_rep=True.
PiperOrigin-RevId: 745352875
2025-04-08 17:28:59 -07:00
Yash Katariya
84016bc368 Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-08 16:51:27 -07:00
jax authors
373ac2ef7e Merge pull request from vfdev-5:ft-adapt-state-test-2
PiperOrigin-RevId: 745341315
2025-04-08 16:48:54 -07:00
Peter Hawkins
a516988bd5 [JAX] Remove calls to xla_computation_to_mlir_module.
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement.

PiperOrigin-RevId: 745333506
2025-04-08 16:24:16 -07:00
jax authors
b8d9e7f427 Merge pull request from kaixih:enable_doc_scaled_dot
PiperOrigin-RevId: 745322012
2025-04-08 15:50:54 -07:00
Jevin Jiang
7b4555247d [ragged-paged-attn] Unify kv strided load to one.
I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B.

PiperOrigin-RevId: 745294058
2025-04-08 14:33:30 -07:00
vfdev-5
5a340a9781 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-08 21:02:55 +00:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Peter Hawkins
b4629c230c Split weakref_lru_cache into its own extension.
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason weakref_lru_cache is in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
2025-04-08 13:34:47 -07:00
Dan Foreman-Mackey
2d44f985c3 Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
2025-04-08 13:09:42 -07:00
Peter Hawkins
09fed2f643 Remove reexports of ml_dtypes types from xla_client.py.
These should be used directly from ml_dtypes.

PiperOrigin-RevId: 745256523
2025-04-08 12:55:10 -07:00
Sergei Lebedev
62df2e8d86 Added jax.no_tracing to the API docs
PiperOrigin-RevId: 745247778
2025-04-08 12:32:35 -07:00
Peter Hawkins
a43136b34d Simplify handling of type stubs in nanobind extension rules.
Pass pytype_srcs as data to the pybind_extension rule.

PiperOrigin-RevId: 745238783
2025-04-08 12:07:28 -07:00
jax authors
f1bcf3bb94 Merge pull request from j-towns:clarify-make-jaxpr-docstr
PiperOrigin-RevId: 745216259
2025-04-08 11:11:57 -07:00
jax authors
76825a2d45 Merge pull request from jakevdp:eigvalsh-symmetrize
PiperOrigin-RevId: 745216021
2025-04-08 11:09:58 -07:00
jax authors
b073e8db25 Merge pull request from jakevdp:fix-repeat
PiperOrigin-RevId: 745215941
2025-04-08 11:07:57 -07:00
jax authors
ef68063497 Merge pull request from mattjj:26621
PiperOrigin-RevId: 745212009
2025-04-08 10:57:57 -07:00
Jevin Jiang
29cb6cd19b [Mosaic TPU] Add MemRead and MemStore effects to load and store ops.
So duplicated load/store ops can be removed.

PiperOrigin-RevId: 745209849
2025-04-08 10:52:44 -07:00
jax authors
03c1bf96c6 Merge pull request from mattjj:27644
PiperOrigin-RevId: 745201720
2025-04-08 10:34:01 -07:00
Jake VanderPlas
b7d430f96b jnp.repeat: don't cast repeats to array, as they must be static. 2025-04-08 10:32:03 -07:00
Matthew Johnson
4d2808c115 [mutable-arrays] limit implicit ref_swap dtype promotion
fixes 

In b7715e279, specifically this line:

b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in . A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
2025-04-08 10:25:10 -07:00
Matthew Johnson
ae95797708 change tack...
See https://github.com/jax-ml/jax/pull/18711

check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants.

The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
2025-04-08 17:11:35 +00:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Adam Paszke
b8353d1b90 [Mosaic TPU] Add support for non-32bit types in vector.extract
At least for as long as the extracted value is not a scalar.

PiperOrigin-RevId: 745151577
2025-04-08 08:16:40 -07:00
Sergei Lebedev
f5d73b89ca [pallas:mosaic_gpu] Added test for custom pretty-printing rules
PiperOrigin-RevId: 745145207
2025-04-08 07:59:44 -07:00
Adam Paszke
b926fac66e [Mosaic GPU] Simplify load/store methods now that we have fewer layouts
PiperOrigin-RevId: 745139008
2025-04-08 07:39:53 -07:00
jax authors
d6524dc461 Update XLA dependency to use revision
3764aee831.

PiperOrigin-RevId: 745130406
2025-04-08 07:11:47 -07:00
jax authors
aa6e701648 Merge pull request from apaszke:mgpu-docs
PiperOrigin-RevId: 745119028
2025-04-08 06:31:33 -07:00
Adam Paszke
511f78202f Add a skeleton for Pallas:Mosaic GPU documentation 2025-04-08 13:13:51 +00:00
jax authors
73ecf0bb48 Remove unused return wrapper in annotate_function that creates a self reference cycle loop in python.
PiperOrigin-RevId: 745099538
2025-04-08 05:25:23 -07:00
Sergei Lebedev
5f33280ded [pallas:mosaic_gpu] emit_pipeline* now allows the grid to be dynamic
PiperOrigin-RevId: 745091128
2025-04-08 04:55:15 -07:00
Sergei Lebedev
12811f08a8 Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 745067361
2025-04-08 03:30:36 -07:00
Dimitar (Mitko) Asenov
c4cc94a10c [Mosaic GPU] Add warpgroup lowering for RunState in Pallas.
After this change we no longer skip tests that required 'RunState`. This necessitated a small fix in the pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.

PiperOrigin-RevId: 745065173
2025-04-08 03:23:20 -07:00
Dimitar (Mitko) Asenov
d12cbffd49 [Mosaic GPU] Refactor and generalize code in optimization_barrier.
The change in `utils.py` is to enable the use of `bitwidth` when the mlir dialect is not registered.

PiperOrigin-RevId: 745060221
2025-04-08 03:05:43 -07:00
Sergei Lebedev
af072feb5a Removed redundant passes
If a function or class has a docstring, it does not need a `pass`.

PiperOrigin-RevId: 745052107
2025-04-08 02:38:21 -07:00
Sergei Lebedev
8ed59d8b5d Removed jax._src.raise_to_shaped
It is just an identity after the "stackless" rewrite.

PiperOrigin-RevId: 745042532
2025-04-08 02:06:40 -07:00
jax authors
c2eaedfe94 Merge pull request from gnecula:export_keys
PiperOrigin-RevId: 745038060
2025-04-08 01:53:11 -07:00
Jamie Townsend
bc11a63113 Clarify jax.make_jaxpr docstring 2025-04-08 09:50:31 +02:00
Dimitar (Mitko) Asenov
19fcae1207 [Mosaic GPU] Add support for replicated warp_dim parsing and a dedicated test for parsing all canonical layouts.
PiperOrigin-RevId: 745015431
2025-04-08 00:34:11 -07:00
George Necula
51dbcd4dad [export] Add backwards compatibility test for annotate_device_placement.
This enables exporting functions that use memory kinds to place
data in different memories.

jax-fixit

PiperOrigin-RevId: 745008959
2025-04-08 00:10:23 -07:00
Alex Pivovarov
bb515aa74f Address previous FP8-related TODOs in jaxlib/XLA.
The ml_dtype requirement in JAX was updated to version 0.5.0+ (on Mar 20, 2025) - commit 4b7ead4

This update allows us to address previous FP8-related TODOs in jaxlib/XLA.

PiperOrigin-RevId: 744943824
2025-04-07 20:01:53 -07:00
Peter Hawkins
86de4783bb Remove unused function jax._src.interpreters.mlir.xla_computation_to_mlir_module.
PiperOrigin-RevId: 744934776
2025-04-07 19:26:20 -07:00
Dan Foreman-Mackey
31589960ff Migrate custom_call filecheck to use internal custom_call since the external one is deprecated.
PiperOrigin-RevId: 744908555
2025-04-07 17:32:02 -07:00
jax authors
4bae9cdaf3 Merge pull request from ZacCranko:harden-cache
PiperOrigin-RevId: 744905319
2025-04-07 17:19:11 -07:00
jax authors
9e0368653c Merge pull request from dfm:lin-out-fwd
PiperOrigin-RevId: 744901130
2025-04-07 17:03:43 -07:00
Zac Cranko
ca6e470d2f harden cache against jaxlib ver 2025-04-07 23:30:31 +00:00
Rachel Han
84e04fe608 Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None.
PiperOrigin-RevId: 744889524
2025-04-07 16:25:01 -07:00
Yash Katariya
0a72e856cf Add **experimental** with_dll_constraint API. This is for cases when the users wants to let SPMD decide the sharding.
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.

PiperOrigin-RevId: 744888557
2025-04-07 16:21:58 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00