1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-22 10:26:06 +00:00

22922 Commits

Author SHA1 Message Date
Ruturaj4
1c75f5f9b1 [ROCm] jaxlib linalg fix rocm-jax-v0.4.33 2024-10-13 18:40:52 -05:00
Ruturaj4
241152c74d apt udpdate 2024-10-12 00:27:28 -05:00
Kiran Thumma
9ca504a310 Fix DOCKER_BUILDKIT issue for jax CI
Resolves: SWDEV-482397

---------

Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
(cherry picked from commit 2a0e2125f6a322b2ef25233176e8c1a607e07590)
2024-10-11 23:12:46 -05:00
Mathew Odden
8b11832d6a Fix formatting issues 2024-10-11 23:09:13 -05:00
Mathew Odden
8ff967925e [ROCm] Fix invalid repo url for EL path
(cherry picked from commit 83f88b3ca2010c13b129143ab1c059381fb62d10)
2024-10-11 23:07:58 -05:00
Mathew Odden
83e3d0a840 Change run_multi_gpu set opts 2024-10-11 22:53:48 -05:00
Ruturaj4
9be19a7ee4 BUILD file fixes and linter run 2024-10-11 22:53:34 -05:00
Ruturaj4
dfb7db0e75 [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 15:30:14 -07:00
Mathew Odden
9dbbb3a391 [ROCm] Fix formatting on python files
Reformatted with black
2024-09-24 18:13:37 -05:00
Mathew Odden
b2c1b4973d [ROCm] Use specific amdgpu version for EL8 systems
We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.
2024-09-24 18:13:30 -05:00
Mathew Odden
f27e9d6125 [ROCm] Remove broken legacy env vars
These env vars are no longer used or need and were
being set incorrectly.
2024-09-24 18:13:24 -05:00
Ruturaj4
aff7f1e4aa [ROCM] fix typename 2024-09-23 17:12:16 -05:00
Dan Foreman-Mackey
df62e8db49 Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and fix gcc build.
In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic.

PiperOrigin-RevId: 675641259
2024-09-23 17:00:12 -05:00
Peter Hawkins
80e1c94de6 Prepare for v0.4.33 release.
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Peter Hawkins
1594d2f30f Prepare for v0.4.32 release. 2024-09-11 14:29:55 -04:00
Peter Hawkins
ed849ff9e0 Make sure to call the superclass' __init__() on a newly created instance in PositionalSharding._remake().
If we don't do this, the C++ base class is left in an uninitialized state, leading to failures elsewhere in the test suite.

PiperOrigin-RevId: 673411282
2024-09-11 08:54:50 -07:00
Peter Hawkins
2bd1fdead8 Relax test tolerance in pinv test to fix a CI failure on Windows CPU.
https://github.com/google/jax/actions/runs/10812364182/job/29993831201

PiperOrigin-RevId: 673409820
2024-09-11 08:50:57 -07:00
jax authors
e869a9d65e Merge pull request from kaixih:key_value_seq_lengths
PiperOrigin-RevId: 673409724
2024-09-11 08:49:50 -07:00
Sergei Lebedev
ea68f4569c Internal change
PiperOrigin-RevId: 673409076
2024-09-11 08:47:58 -07:00
Peter Hawkins
49dd6ed8d8 Disable a pallas export compatibility test that fails on TPU v6e.
PiperOrigin-RevId: 673295487
2024-09-11 02:00:42 -07:00
Peter Hawkins
808003b4e2 Update users of jax.tree.map() to be more careful about how they handle Nones.
Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 673258116
2024-09-10 23:54:11 -07:00
Justin Fu
e3c4b20fa0 [Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec"
PiperOrigin-RevId: 673165201
2024-09-10 17:21:20 -07:00
Justin Fu
c659dc9a01 [Pallas] Disable win32 gpu_ops_test.
PiperOrigin-RevId: 673149107
2024-09-10 16:23:17 -07:00
jax authors
14b86259d5 Merge pull request from pschuh:docs-update
PiperOrigin-RevId: 673148286
2024-09-10 16:20:42 -07:00
Yash Katariya
0b04dd022a Reverts 5e4250e64bb415be94ddc8a80dba6083a6a4123a
PiperOrigin-RevId: 673141373
2024-09-10 16:01:00 -07:00
Sergei Lebedev
e7b261c386 Removed a sneaky comma in Pallas Mosaic GPU lowering
PiperOrigin-RevId: 673109846
2024-09-10 14:37:07 -07:00
Yash Katariya
5e4250e64b Prepare for jax 0.4.32 release
PiperOrigin-RevId: 673105544
2024-09-10 14:26:17 -07:00
jax authors
4957ab9a5e Clean up JAX backend for all backends to avoid dangling PyClient references.
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Parker Schuh
030b6c655d Update the docs for conv_general_dilated to clarify 'W' 'H'. 2024-09-10 14:02:06 -07:00
Ayaka
46bcb1e057 [Pallas] Simplify lowering and fix the test for lax.erf_inv_p
This PR is a follow-up of https://github.com/google/jax/pull/23192, which implements the lowering rule for `lax.erf_inv_p`. However, I've realised that the lowering rule can be simplified, and the test for it was moved to the wrong place. This PR resolves the above 2 issues.

After merging this PR, I will continue with https://github.com/google/jax/pull/22310, which adds 64-bit lowering support for `lax.erf_inv_p`.

PiperOrigin-RevId: 673095319
2024-09-10 14:00:28 -07:00
jax authors
8681bf6dc2 Update XLA dependency to use revision
720b2c5334.

PiperOrigin-RevId: 673090332
2024-09-10 13:45:51 -07:00
jax authors
deda649138 Merge pull request from selamw1:kron_outer_docstring
PiperOrigin-RevId: 673071112
2024-09-10 12:54:38 -07:00
Yash Katariya
90892f533a Check for jax.Sharding's number of devices instead of py_array.num_shards which looks at IFRT sharding's num_devices to check against global_devices and deciding whether to fall back to python shard_arg.
This is because IFRT sharding's `num_shards` method is busted. It doesn't return the global shards (in some cases) which leads to JAX program unnecessarily falling back to python.

PiperOrigin-RevId: 673067095
2024-09-10 12:43:52 -07:00
jax authors
02ab741155 Merge pull request from dfm:ffi-release-notes
PiperOrigin-RevId: 673065825
2024-09-10 12:39:49 -07:00
selamw1
bacda603fc kron_and_outer_docstring_added
description_fixed_and_kron_desc_added

description_text_and_return_fixed

description_text_and_return_fixed

return_fixed
2024-09-10 11:13:08 -07:00
jax authors
6037dba98b Merge pull request from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 673011373
2024-09-10 10:18:09 -07:00
jax authors
a8b68c26b0 Merge pull request from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 673007161
2024-09-10 10:07:46 -07:00
Sergei Lebedev
9fa0164ad2 Estimate the amount of required scratch SMEM automatically in Pallas Mosaic GPU lowering
No estimation is done if `smem_scratch_bytes` was explicitly specified via
`compiler_params=`.

PiperOrigin-RevId: 672998660
2024-09-10 09:43:04 -07:00
rajasekharporeddy
ee04646f33 Improve docs for jax.numpy: float_power and nextafter 2024-09-10 22:08:44 +05:30
rajasekharporeddy
c5bc2412a7 Improve doc for jnp.trim_zeros 2024-09-10 20:56:49 +05:30
Peter Hawkins
1b2ba9d1c2 Disable two lax_scipy_test testcases that fail on TPU v6e.
PiperOrigin-RevId: 672973757
2024-09-10 08:26:27 -07:00
Bart Chrzaszcz
062a69a97e Make JAX extract the mesh from an AUTO in/out sharding.
Automatic partitioners using JAX+Shardy want to partition models which are fully marked as `AUTO` - so no in/out sharding with a `NamedSharding`. In such a case they weren't seeing the mesh on the MLIR module. This makes sure we extract it from the `AUTO` sharding.

PiperOrigin-RevId: 672881018
2024-09-10 03:07:02 -07:00
Ayaka
7d2f0a75c1 [Pallas GPU] Fix the behavior of jnp.sign(jnp.nan) and move the TPU test case for jnp.sign into the general test
This PR is similar to https://github.com/google/jax/pull/23192, which moves TPU test case for `lax.erf_inv` into the general test

Fixes https://github.com/google/jax/issues/23504

PiperOrigin-RevId: 672682048
2024-09-09 14:49:40 -07:00
Peter Hawkins
72c095261f Improve the docstring for jax.Array.copy_to_host_async.
PiperOrigin-RevId: 672666190
2024-09-09 14:04:17 -07:00
Yash Katariya
d6c36255e8 Create optimal order for v5e:8 devices which is [0, 1, 2, 3, 7, 6, 5, 4]
PiperOrigin-RevId: 672652104
2024-09-09 13:24:18 -07:00
jax authors
623cbb8ce7 Update XLA dependency to use revision
32004c2727.

PiperOrigin-RevId: 672647417
2024-09-09 13:11:01 -07:00
jax authors
cdd68b5f94 Merge pull request from hawkinsp:nightly
PiperOrigin-RevId: 672634290
2024-09-09 12:31:56 -07:00
Peter Hawkins
5cc5ed2c5c Disable a shard_map test case that fails on TPU v5e.
PiperOrigin-RevId: 672618556
2024-09-09 11:45:41 -07:00
Peter Hawkins
b975592478 Change nightly install commands to include all packages.
pip doesn't update transitive dependencies, and we probably want the latest versions of everything when installing a nightly.
2024-09-09 14:38:41 -04:00
Justin Fu
4bdfe09241 [Pallas] Fully skip GPU attention tests on win32.
PiperOrigin-RevId: 672588009
2024-09-09 10:21:33 -07:00