mirror of https://github.com/ROCm/jax.git synced 2025-04-18 12:56:07 +00:00
Commit Graph

26929 Commits

Author SHA1 Message Date
Charles Hofer
6985f0d91c Remove 6.2.4 build 2025-04-15 16:50:06 +00:00
Charles Hofer
9de75a25df Trivial change for CI 2025-04-14 17:42:21 +00:00
Charles Hofer
ea55e59d70 Use .cfg file 2025-04-14 17:30:08 +00:00
Charles Hofer
98baf09f7d Add clang config file 2025-04-14 17:26:16 +00:00
Charles Hofer
248e638674 Move clang gcc path options to config file 2025-04-14 17:06:47 +00:00
Charles Hofer
3b4a7b029b Make Clang use manylinux C++ standard library 2025-04-11 19:18:23 +00:00
Charles Hofer
16d737b088 Account for versioned clang binaries 2025-04-10 15:46:54 +00:00
Yash Katariya
9af8e04fcd Fix typo in the error message
PiperOrigin-RevId: 745375892
2025-04-09 18:05:17 +00:00
Yash Katariya
48fcf02881 Rename psum2 to psum_invariant and put it in lax_parallel. We shouldn't expose this in the public API; users should call psum instead, which dispatches to psum_invariant when check_rep=True.
PiperOrigin-RevId: 745352875
2025-04-09 18:05:16 +00:00
Yash Katariya
ef84e9d33b Rename pbroadcast to pvary and expose it as jax.lax.pvary.
PiperOrigin-RevId: 745342103
2025-04-09 18:05:16 +00:00
vfdev-5
ba0879a311 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-09 18:05:16 +00:00
Peter Hawkins
8cd2843d1a [JAX] Remove calls to xla_computation_to_mlir_module.
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement.

PiperOrigin-RevId: 745333506
2025-04-09 18:05:16 +00:00
kaixih
8993c0e00c format 2025-04-09 18:05:16 +00:00
kaixih
56d13e0d9d Remove asserts 2025-04-09 18:05:16 +00:00
kaixih
bfe79232e9 Enable public doc for scaled dot 2025-04-09 18:05:16 +00:00
Jevin Jiang
185d65fb67 [ragged-paged-attn] Unify kv strided load to one.
I expected Mosaic to canonicalize two identical strided loads into one, but it did not. (We will fix this in Mosaic.) For now, manually merging them into one strided load gives a 20-35% speedup on both v6e and v5e single chip for Meta-Llama-3-8B.

PiperOrigin-RevId: 745294058
2025-04-09 18:05:16 +00:00
Yash Katariya
8a6bfd6d7d Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-09 18:05:16 +00:00
Peter Hawkins
d094562bd6 Split weakref_lru_cache into its own extension.
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason for weakref_lru_cache to be in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
2025-04-09 18:05:16 +00:00
Dan Foreman-Mackey
686144b099 Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
2025-04-09 18:05:16 +00:00
Peter Hawkins
879b72a603 Remove reexports of ml_dtypes types from xla_client.py.
These should be used directly from ml_dtypes.

PiperOrigin-RevId: 745256523
2025-04-09 18:05:16 +00:00
Sergei Lebedev
8d7110636c Added jax.no_tracing to the API docs
PiperOrigin-RevId: 745247778
2025-04-09 18:05:16 +00:00
Peter Hawkins
7fd3a0751c Simplify handling of type stubs in nanobind extension rules.
Pass pytype_srcs as data to the pybind_extension rule.

PiperOrigin-RevId: 745238783
2025-04-09 18:05:16 +00:00
Jamie Townsend
872a43d7f9 Clarify jax.make_jaxpr docstring 2025-04-09 18:05:16 +00:00
Jake VanderPlas
6965601dea jnp.linalg: add symmetrize_input argument & docs 2025-04-09 18:05:16 +00:00
Jake VanderPlas
84210a32ac jnp.repeat: don't cast repeats to array, as they must be static. 2025-04-09 18:05:16 +00:00
Matthew Johnson
689f766122 change tack...
See https://github.com/jax-ml/jax/pull/18711

check_rep uses rep=None to indicate when an argument is a constant, and that's useful specifically when checking the backward pass for integer_pow, which has a multiplication by a constant that didn't get a pbroadcast applied to it. That is, we use rep=None as a special carve-out for constants.

The standard rules were compatible with rep=None, but the rules for higher-order primitives like scan and cond were not. So we had to upgrade them.
2025-04-09 18:05:16 +00:00
Matthew Johnson
67df04c688 [shard-map] canonicalize rep=None to be rep={all possible axes}
None is meant to represent the same thing as {replicated over all possible axes}. But without this canonicalization, we could compare None as not equal to {all possible axes}.
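
The canonicalization described above can be sketched in plain Python. This is a minimal illustration, not the shard_map implementation: the function name `canonicalize_rep` and the `all_axes` argument are hypothetical, and replication sets are modelled as ordinary frozensets of axis names.

```python
from typing import FrozenSet, Optional

def canonicalize_rep(rep: Optional[FrozenSet[str]],
                     all_axes: FrozenSet[str]) -> FrozenSet[str]:
    """Hypothetical helper: map rep=None to the set of all possible axes."""
    return all_axes if rep is None else rep

all_axes = frozenset({"x", "y"})

# Without canonicalization, None and the full axis set compare as different
# even though they are meant to represent the same thing.
assert None != all_axes

# After canonicalization the two representations compare equal.
assert canonicalize_rep(None, all_axes) == canonicalize_rep(all_axes, all_axes)
```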

fixes 

Unrelated: in several places, including the _check_rep path, we don't handle partial auto correctly, since we treat {all possible axes} as {all mesh axes}, but actually it should be more like {all mesh axes} - auto. We'll leave that fix for a follow-up...
2025-04-09 18:05:16 +00:00
Jevin Jiang
4e094523ee [Mosaic TPU] Add MemRead and MemStore effects to load and store ops.
So duplicated load/store ops can be removed.

PiperOrigin-RevId: 745209849
2025-04-09 18:05:16 +00:00
Matthew Johnson
54e8df41f8 [shard-map] add while_map rep rule
fixes 
2025-04-09 18:05:16 +00:00
Matthew Johnson
4b4f8285d0 [mutable-arrays] limit implicit ref_swap dtype promotion
fixes 

In b7715e279, specifically this line:

b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in . A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```
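
The value 1065353216 in that output is the IEEE 754 single-precision bit pattern of 1.0 (0x3F800000) read back as a 32-bit integer. A minimal standard-library sketch of the difference between a bitcast and an ordinary cast, independent of JAX (the helper name `bitcast_f32_to_i32` is made up for this illustration):

```python
import struct

def bitcast_f32_to_i32(x: float) -> int:
    """Reinterpret the 4 bytes of a float32 as a signed 32-bit integer."""
    return struct.unpack("<i", struct.pack("<f", x))[0]

print(bitcast_f32_to_i32(1.0))  # 1065353216, i.e. 0x3F800000
print(int(1.0))                 # 1, what an ordinary value cast gives
```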

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.
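
One possible reading of the three points above, sketched in plain Python rather than JAX. The function `implicit_ref_cast_ok`, its arguments, and the string dtype kinds are all hypothetical; the actual check in the change may differ.

```python
def implicit_ref_cast_ok(value_kind: str, value_is_weak: bool, ref_kind: str) -> bool:
    """Hypothetical rule: allow an implicit cast only for weakly typed values
    whose kind ("float" or "int") matches the kind of the ref's dtype."""
    if not value_is_weak:
        return False  # strongly typed values must already match the ref dtype
    return value_kind == ref_kind and value_kind in ("float", "int")

# A weak Python float assigned into a float ref (e.g. bfloat16): allowed.
assert implicit_ref_cast_ok("float", True, "float")
# A weak Python float assigned into an int32 ref: rejected rather than bitcast.
assert not implicit_ref_cast_ok("float", True, "int")
```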

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numeric types. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
2025-04-09 18:05:16 +00:00
Peter Hawkins
c4340d966e Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-09 18:05:16 +00:00
Adam Paszke
2c17538838 [Mosaic TPU] Add support for non-32bit types in vector.extract
At least for as long as the extracted value is not a scalar.

PiperOrigin-RevId: 745151577
2025-04-09 18:05:16 +00:00
Sergei Lebedev
cb8d42cd11 [pallas:mosaic_gpu] Added test for custom pretty-printing rules
PiperOrigin-RevId: 745145207
2025-04-09 18:05:16 +00:00
Adam Paszke
72e420c2ed [Mosaic GPU] Simplify load/store methods now that we have fewer layouts
PiperOrigin-RevId: 745139008
2025-04-09 18:05:16 +00:00
jax authors
aaef00f029 Update XLA dependency to use revision 3764aee831.

PiperOrigin-RevId: 745130406
2025-04-09 18:05:16 +00:00
Adam Paszke
dad3e7d7a0 Add a skeleton for Pallas:Mosaic GPU documentation 2025-04-09 18:05:16 +00:00
jax authors
f12a4aff23 Remove unused return wrapper in annotate_function that creates a self-reference cycle in Python.
PiperOrigin-RevId: 745099538
2025-04-09 18:05:16 +00:00
Sergei Lebedev
199fa7de84 [pallas:mosaic_gpu] emit_pipeline* now allows the grid to be dynamic
PiperOrigin-RevId: 745091128
2025-04-09 18:05:16 +00:00
Sergei Lebedev
1264f07e97 Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 745067361
2025-04-09 18:05:16 +00:00
Dimitar (Mitko) Asenov
e8dcd3799f [Mosaic GPU] Add warpgroup lowering for RunState in Pallas.
After this change we no longer skip tests that required `RunState`. This necessitated a small fix in the Pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.

PiperOrigin-RevId: 745065173
2025-04-09 18:05:16 +00:00
Dimitar (Mitko) Asenov
cfd9e513aa [Mosaic GPU] Refactor and generalize code in optimization_barrier.
The change in `utils.py` is to enable the use of `bitwidth` when the mlir dialect is not registered.

PiperOrigin-RevId: 745060221
2025-04-09 18:05:16 +00:00
Sergei Lebedev
2ade7da06e Removed redundant passes
If a function or class has a docstring, it does not need a `pass`.
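
For illustration (the names below are made up), a docstring is already a complete body, so both definitions are valid Python without `pass`:

```python
def placeholder():
    """Do nothing; the docstring alone is the function body."""

class Marker:
    """An empty class whose body is just its docstring."""

print(placeholder())   # None
print(Marker.__doc__)  # An empty class whose body is just its docstring.
```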

PiperOrigin-RevId: 745052107
2025-04-09 18:05:16 +00:00
Sergei Lebedev
82d5d6e6e0 Removed jax._src.raise_to_shaped
It is just an identity after the "stackless" rewrite.

PiperOrigin-RevId: 745042532
2025-04-09 18:05:16 +00:00
George Necula
3aacfce2d3 [export] Add support for serializing functions with PRNG keys as inputs/outputs
This introduces version 4 of serialization, fully backwards compatible
with versions 2 and 3.

Fixes: 
2025-04-09 18:05:16 +00:00
Dimitar (Mitko) Asenov
5ddff5f6db [Mosaic GPU] Add support for replicated warp_dim parsing and a dedicated test for parsing all canonical layouts.
PiperOrigin-RevId: 745015431
2025-04-09 18:05:16 +00:00
George Necula
1b918dcd70 [export] Add backwards compatibility test for annotate_device_placement.
This enables exporting functions that use memory kinds to place
data in different memories.

jax-fixit

PiperOrigin-RevId: 745008959
2025-04-09 18:05:16 +00:00
Alex Pivovarov
03e7cb878c Address previous FP8-related TODOs in jaxlib/XLA.
The ml_dtypes requirement in JAX was updated to version 0.5.0+ on Mar 20, 2025 (commit 4b7ead4).

This update allows us to address previous FP8-related TODOs in jaxlib/XLA.

PiperOrigin-RevId: 744943824
2025-04-09 18:05:16 +00:00
Peter Hawkins
d3f27dbcf3 Remove unused function jax._src.interpreters.mlir.xla_computation_to_mlir_module.
PiperOrigin-RevId: 744934776
2025-04-09 18:05:16 +00:00
Dan Foreman-Mackey
322632fd4c Migrate custom_call filecheck to use internal custom_call since the external one is deprecated.
PiperOrigin-RevId: 744908555
2025-04-09 18:05:16 +00:00
Zac Cranko
4c6b0361bf harden cache against jaxlib ver 2025-04-09 18:05:16 +00:00