22432 Commits

Author SHA1 Message Date
Roy Frostig
3c223cd253 docs: tidy up titles and headings
This shortens some titles and makes them more consistent. It also
removes "JAX" from several titles ("in JAX", "for JAX", "JAX's",
etc.). Since these are JAX docs, that ought to be clear from context.
2024-08-13 11:53:57 -07:00
Yash Katariya
4533aeaf26 Remove jax_enable_memories conditionals from JAX and remove it from tests too.
PiperOrigin-RevId: 662322241
2024-08-12 19:15:43 -07:00
jax authors
833560deb1 Merge pull request #23023 from froystig:docs3
PiperOrigin-RevId: 662318149
2024-08-12 19:03:12 -07:00
Roy Frostig
b8f8b7b07f docs: sentence case page titles, section headings, some content 2024-08-12 18:12:17 -07:00
Parker Schuh
734ebd5708 Support donating arrays with non-default layouts by setting up XLA donation
directly instead of defining aliasing for arrays with potentially incompatible
layouts.

PiperOrigin-RevId: 662258042
2024-08-12 15:58:52 -07:00
jax authors
7afa90780b Update XLA dependency to use revision
b18bc61250.

PiperOrigin-RevId: 662252977
2024-08-12 15:43:30 -07:00
Justin Fu
aa66fb37c3 [Pallas][XLA:Mosaic] Add python stack traces to Mosaic errors that occur in Pallas.
PiperOrigin-RevId: 662232859
2024-08-12 14:42:48 -07:00
jax authors
b2eb2d4f30 Merge pull request #22600 from ROCm:feat/manylinux_2_28
PiperOrigin-RevId: 662217658
2024-08-12 13:59:42 -07:00
Roy Frostig
2644299f7e docs: sentence case index and sub-index headings
We currently use both forms, so for consistency (and easier reading),
pick this one.
2024-08-12 13:52:43 -07:00
Mathew Odden
fafa03c60f Add missing CPython build deps for pyenv 2024-08-12 15:01:34 -05:00
Mathew Odden
701cda8ebd Fix not finding wheels in bazel output 2024-08-12 15:01:34 -05:00
Mathew Odden
df2d140f51 Fix jenkins notty issue 2024-08-12 15:01:34 -05:00
Mathew Odden
319ebf81c1 Add defaults for ROCm build vars 2024-08-12 15:01:34 -05:00
Mathew Odden
abe44f6d9e Add copyright and license headers to new files 2024-08-12 15:01:34 -05:00
Mathew Odden
a1a0a4ecdd Add support for ROCm development builds
Use get_rocm.py changes in ci_build to pull in
development builds for ROCm.

Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for
activating the development build path.
2024-08-12 15:01:34 -05:00
Mathew Odden
3175f13c59 Add internal release support to get_rocm.py 2024-08-12 15:01:34 -05:00
Mathew Odden
1e58d76772 [ROCm] Change ROCm builds to manylinux wheels 2024-08-12 15:01:34 -05:00
jax authors
e5eaff84bd Replace pjrt_c_api_gpu_plugin.so symlink with XLA dependency.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved.

PiperOrigin-RevId: 662197057
2024-08-12 13:01:18 -07:00
Brian Wieder
ee31e95ecd Register shutdown code at import to hopefully get registered before any other atexit callbacks.
`atexit` callbacks are called in a LIFO order, meaning that since Jax currently registers its callback at runtime rather than import time, it gets called before any `atexit` callbacks registered at import time.

PiperOrigin-RevId: 662164776
2024-08-12 11:29:08 -07:00
jax authors
7a873c0312 Merge pull request #23014 from google:dependabot/github_actions/actions/upload-artifact-4.3.6
PiperOrigin-RevId: 662161436
2024-08-12 11:18:29 -07:00
jax authors
da259f8d9c Merge pull request #22979 from jakevdp:intersect1d-size
PiperOrigin-RevId: 662154746
2024-08-12 11:02:07 -07:00
dependabot[bot]
802abfef92
Bump actions/upload-artifact from 4.3.5 to 4.3.6
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.5 to 4.3.6.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](89ef406dd8...834a144ee9)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-12 17:54:09 +00:00
Yash Katariya
53045380b1 Make custom partitioning work without a mesh context manager. If the arguments have NamedSharding on them, then inside partition function, we should get NamedSharding without the existence of the mesh context manager
PiperOrigin-RevId: 662146686
2024-08-12 10:40:31 -07:00
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
Zhuo Peng
ad74e55dbc Support None leaves in arguments to gradient of a call_tf wrapped function.
PiperOrigin-RevId: 662115139
2024-08-12 09:24:25 -07:00
Dan Foreman-Mackey
4eb5ef28ef Update shape polymorphism tests to skip lu_pivots_to_permutations tests when jaxlib version is too old.
PiperOrigin-RevId: 662088901
2024-08-12 08:13:27 -07:00
jax authors
112cae1dad Merge pull request #22997 from superbobry:maint-2
PiperOrigin-RevId: 662079641
2024-08-12 07:41:24 -07:00
Christos Perivolaropoulos
cd4e91b2b0 [mosaic_gpu] Store untiled splat layout
PiperOrigin-RevId: 662077826
2024-08-12 07:34:07 -07:00
George Necula
7f680aaab8 [pallas] Move ops_test.py from jax_triton to jax/pallas
The `jax_triton/ops_test.py` has over time accumulated many tests that are in fact platform-independent tests.
Furthermore, those tests were only Google-internal, and they can be external as well.

This moves test coverage for Pallas from the jax_triton package to the Pallas core package.

A small number of the tests were deleted, because they were already present in Pallas, e.g., tests in `jax_triton/ops_test.py:ControlFlowTest`, and tests for unary and binary ops in `jax_triton/ops_test.py:OpsTest`.

The other tests were distributed to different files in the Pallas repo, according to their purpose:

  * tests in `jax_triton/ops_test.py:PrettyPrintingTest` are moved to `tpu_pallas_test.py::PrettyPrintingTest`
  * tests in `jax_triton/ops_test.py::IndexingTest` are appended to `indexing_test.py::IndexingTest`; some other indexing tests from `jax_triton/ops_test.py::LoadStoreTest` are also moved there.
   * some tests in `jax_triton/ops_test.py:OpsTest` are moved to `ops_test.py::OpsTest`.
   * some tests for TPU specific ops in `jax_triton/ops_test.py:OpsTest` are moved to a new test file `tpu_ops_tests.py`

Some of this required adding sharding and hypothesis support to `ops_test.py`,
and adding TPU versions of `indexing_test.py`.

PiperOrigin-RevId: 662045774
2024-08-12 05:09:37 -07:00
Sergei Lebedev
c9142cbe75 Collapsed a few unnecessary `if TYPE_CHECKING` blocks 2024-08-12 13:08:55 +01:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
Dan Foreman-Mackey
ae5b4284d5 Make ffi_call tests backwards compatible with the released jaxlib.
PiperOrigin-RevId: 662017095
2024-08-12 03:08:49 -07:00
jax authors
be4d52b814 Merge pull request #22667 from ROCm:rocm-jax-triton-add-get_arch_detail
PiperOrigin-RevId: 662007143
2024-08-12 02:30:49 -07:00
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
jax authors
d2fa88496d Merge pull request #22991 from froystig:keys2
PiperOrigin-RevId: 661891112
2024-08-11 16:08:44 -07:00
jax authors
312ebdb8ba Merge pull request #22989 from froystig:keys
PiperOrigin-RevId: 661890876
2024-08-11 16:08:25 -07:00
jax authors
b9e80df624 Merge pull request #22992 from froystig:docs2
PiperOrigin-RevId: 661890694
2024-08-11 16:04:32 -07:00
jax authors
3c7bd54c54 Update XLA dependency to use revision
a6fc99fadc.

PiperOrigin-RevId: 661883548
2024-08-11 15:16:44 -07:00
Roy Frostig
c54ffd41bc in dot docstring, format and link to dot_general 2024-08-11 12:44:50 -07:00
Roy Frostig
dd535d88a7 emphasize typed over legacy RNG keys in random module docs
Update both docstrings and move the `PRNGKey` function listing lower
in the API reference.
2024-08-11 12:41:50 -07:00
Dan Foreman-Mackey
4f8f66f10b Add more complete tests for attribute serialization when lowering an FFI call.
PiperOrigin-RevId: 661849681
2024-08-11 12:34:02 -07:00
jax authors
6da8fdff74 Merge pull request #22988 from froystig:docs
PiperOrigin-RevId: 661816430
2024-08-11 08:36:12 -07:00
Roy Frostig
371935cc10 update README and several docs to typed RNG keys 2024-08-11 08:09:47 -07:00
Roy Frostig
ded5b5366b indent consistently in auto-parallelization and shard_map tutorials 2024-08-11 08:04:14 -07:00
Dan Foreman-Mackey
96045043a4 Move ir_attribute builder from extend.ffi to interpreters.mlir.
While this function is currently only used for lowering FFI calls, it could be used most places where `ir.*Attr` objects are directly constructed.

PiperOrigin-RevId: 661761712
2024-08-11 01:47:49 -07:00
jax authors
9e86416a32 Update XLA dependency to use revision
dfb2a8b498.

PiperOrigin-RevId: 661668331
2024-08-10 14:20:14 -07:00
Jake VanderPlas
c2c116dc5c jnp.intersect1d: add support for static size argument. 2024-08-10 05:22:05 -07:00
Yash Katariya
c08656c61d [Rollback] We still want to allow multiple meshes in the user program
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2

PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
jax authors
7a75c96aa9 Update XLA dependency to use revision
46e205a0b6.

PiperOrigin-RevId: 661412627
2024-08-09 14:53:42 -07:00