Ruturaj4
1c75f5f9b1
[ROCm] jaxlib linalg fix
rocm-jax-v0.4.33
2024-10-13 18:40:52 -05:00
Ruturaj4
241152c74d
apt udpdate
2024-10-12 00:27:28 -05:00
Kiran Thumma
9ca504a310
Fix DOCKER_BUILDKIT issue for jax CI
...
Resolves: SWDEV-482397
---------
Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
(cherry picked from commit 2a0e2125f6a322b2ef25233176e8c1a607e07590)
2024-10-11 23:12:46 -05:00
Mathew Odden
8b11832d6a
Fix formatting issues
2024-10-11 23:09:13 -05:00
Mathew Odden
8ff967925e
[ROCm] Fix invalid repo url for EL path
...
(cherry picked from commit 83f88b3ca2010c13b129143ab1c059381fb62d10)
2024-10-11 23:07:58 -05:00
Mathew Odden
83e3d0a840
Change run_multi_gpu set opts
2024-10-11 22:53:48 -05:00
Ruturaj4
9be19a7ee4
BUILD file fixes and linter run
2024-10-11 22:53:34 -05:00
Ruturaj4
dfb7db0e75
[ROCm] Bring up clang support for JAX+XLA
...
* Add clang path
* bazelrc env fixes
* Fix wheelhouse installation and preserve wheels
* dockerfile changes
* Add target.lst
* Change target architectures
* Install bzip2 and sqlite packages
2024-10-10 15:30:14 -07:00
Mathew Odden
9dbbb3a391
[ROCm] Fix formatting on python files
...
Reformatted with black
2024-09-24 18:13:37 -05:00
Mathew Odden
b2c1b4973d
[ROCm] Use specific amdgpu version for EL8 systems
...
We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.
2024-09-24 18:13:30 -05:00
Mathew Odden
f27e9d6125
[ROCm] Remove broken legacy env vars
...
These env vars are no longer used or need and were
being set incorrectly.
2024-09-24 18:13:24 -05:00
Ruturaj4
aff7f1e4aa
[ROCM] fix typename
2024-09-23 17:12:16 -05:00
Dan Foreman-Mackey
df62e8db49
Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and fix gcc build.
...
In https://github.com/google/jax/issues/23687 , it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic.
PiperOrigin-RevId: 675641259
2024-09-23 17:00:12 -05:00
Peter Hawkins
80e1c94de6
Prepare for v0.4.33 release.
...
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Peter Hawkins
1594d2f30f
Prepare for v0.4.32 release.
2024-09-11 14:29:55 -04:00
Peter Hawkins
ed849ff9e0
Make sure to call the superclass' __init__() on a newly created instance in PositionalSharding._remake().
...
If we don't do this, the C++ base class is left in an uninitialized state, leading to failures elsewhere in the test suite.
PiperOrigin-RevId: 673411282
2024-09-11 08:54:50 -07:00
Peter Hawkins
2bd1fdead8
Relax test tolerance in pinv test to fix a CI failure on Windows CPU.
...
https://github.com/google/jax/actions/runs/10812364182/job/29993831201
PiperOrigin-RevId: 673409820
2024-09-11 08:50:57 -07:00
jax authors
e869a9d65e
Merge pull request #23415 from kaixih:key_value_seq_lengths
...
PiperOrigin-RevId: 673409724
2024-09-11 08:49:50 -07:00
Sergei Lebedev
ea68f4569c
Internal change
...
PiperOrigin-RevId: 673409076
2024-09-11 08:47:58 -07:00
Peter Hawkins
49dd6ed8d8
Disable a pallas export compatibility test that fails on TPU v6e.
...
PiperOrigin-RevId: 673295487
2024-09-11 02:00:42 -07:00
Peter Hawkins
808003b4e2
Update users of jax.tree.map() to be more careful about how they handle Nones.
...
Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.
Fix code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.
PiperOrigin-RevId: 673258116
2024-09-10 23:54:11 -07:00
Justin Fu
e3c4b20fa0
[Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec"
...
PiperOrigin-RevId: 673165201
2024-09-10 17:21:20 -07:00
Justin Fu
c659dc9a01
[Pallas] Disable win32 gpu_ops_test.
...
PiperOrigin-RevId: 673149107
2024-09-10 16:23:17 -07:00
jax authors
14b86259d5
Merge pull request #23549 from pschuh:docs-update
...
PiperOrigin-RevId: 673148286
2024-09-10 16:20:42 -07:00
Yash Katariya
0b04dd022a
Reverts 5e4250e64bb415be94ddc8a80dba6083a6a4123a
...
PiperOrigin-RevId: 673141373
2024-09-10 16:01:00 -07:00
Sergei Lebedev
e7b261c386
Removed a sneaky comma in Pallas Mosaic GPU lowering
...
PiperOrigin-RevId: 673109846
2024-09-10 14:37:07 -07:00
Yash Katariya
5e4250e64b
Prepare for jax 0.4.32 release
...
PiperOrigin-RevId: 673105544
2024-09-10 14:26:17 -07:00
jax authors
4957ab9a5e
Clean up JAX backend for all backends to avoid dangling PyClient references.
...
PiperOrigin-RevId: 673102539
2024-09-10 14:19:00 -07:00
Parker Schuh
030b6c655d
Update the docs for conv_general_dilated to clarify 'W' 'H'.
2024-09-10 14:02:06 -07:00
Ayaka
46bcb1e057
[Pallas] Simplify lowering and fix the test for lax.erf_inv_p
...
This PR is a follow-up of https://github.com/google/jax/pull/23192 , which implements the lowering rule for `lax.erf_inv_p`. However, I've realised that the lowering rule can be simplified, and the test for it was moved to the wrong place. This PR resolves the above 2 issues.
After merging this PR, I will continue with https://github.com/google/jax/pull/22310 , which adds 64-bit lowering support for `lax.erf_inv_p`.
PiperOrigin-RevId: 673095319
2024-09-10 14:00:28 -07:00
jax authors
8681bf6dc2
Update XLA dependency to use revision
...
720b2c5334
.
PiperOrigin-RevId: 673090332
2024-09-10 13:45:51 -07:00
jax authors
deda649138
Merge pull request #23443 from selamw1:kron_outer_docstring
...
PiperOrigin-RevId: 673071112
2024-09-10 12:54:38 -07:00
Yash Katariya
90892f533a
Check for jax.Sharding
's number of devices instead of py_array.num_shards
which looks at IFRT sharding's num_devices to check against global_devices
and deciding whether to fall back to python shard_arg.
...
This is because IFRT sharding's `num_shards` method is busted. It doesn't return the global shards (in some cases) which leads to JAX program unnecessarily falling back to python.
PiperOrigin-RevId: 673067095
2024-09-10 12:43:52 -07:00
jax authors
02ab741155
Merge pull request #23478 from dfm:ffi-release-notes
...
PiperOrigin-RevId: 673065825
2024-09-10 12:39:49 -07:00
selamw1
bacda603fc
kron_and_outer_docstring_added
...
description_fixed_and_kron_desc_added
description_text_and_return_fixed
description_text_and_return_fixed
return_fixed
2024-09-10 11:13:08 -07:00
jax authors
6037dba98b
Merge pull request #23536 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 673011373
2024-09-10 10:18:09 -07:00
jax authors
a8b68c26b0
Merge pull request #23540 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 673007161
2024-09-10 10:07:46 -07:00
Sergei Lebedev
9fa0164ad2
Estimate the amount of required scratch SMEM automatically in Pallas Mosaic GPU lowering
...
No estimation is done if `smem_scratch_bytes` was explicitly specified via
`compiler_params=`.
PiperOrigin-RevId: 672998660
2024-09-10 09:43:04 -07:00
rajasekharporeddy
ee04646f33
Improve docs for jax.numpy: float_power and nextafter
2024-09-10 22:08:44 +05:30
rajasekharporeddy
c5bc2412a7
Improve doc for jnp.trim_zeros
2024-09-10 20:56:49 +05:30
Peter Hawkins
1b2ba9d1c2
Disable two lax_scipy_test testcases that fail on TPU v6e.
...
PiperOrigin-RevId: 672973757
2024-09-10 08:26:27 -07:00
Bart Chrzaszcz
062a69a97e
Make JAX extract the mesh from an AUTO
in/out sharding.
...
Automatic partitioners using JAX+Shardy want to partition models which are fully marked as `AUTO` - so no in/out sharding with a `NamedSharding`. In such a case they weren't seeing the mesh on the MLIR module. This makes sure we extract it from the `AUTO` sharding.
PiperOrigin-RevId: 672881018
2024-09-10 03:07:02 -07:00
Ayaka
7d2f0a75c1
[Pallas GPU] Fix the behavior of jnp.sign(jnp.nan)
and move the TPU test case for jnp.sign
into the general test
...
This PR is similar to https://github.com/google/jax/pull/23192 , which moves TPU test case for `lax.erf_inv` into the general test
Fixes https://github.com/google/jax/issues/23504
PiperOrigin-RevId: 672682048
2024-09-09 14:49:40 -07:00
Peter Hawkins
72c095261f
Improve the docstring for jax.Array.copy_to_host_async
.
...
PiperOrigin-RevId: 672666190
2024-09-09 14:04:17 -07:00
Yash Katariya
d6c36255e8
Create optimal order for v5e:8 devices which is [0, 1, 2, 3, 7, 6, 5, 4]
...
PiperOrigin-RevId: 672652104
2024-09-09 13:24:18 -07:00
jax authors
623cbb8ce7
Update XLA dependency to use revision
...
32004c2727
.
PiperOrigin-RevId: 672647417
2024-09-09 13:11:01 -07:00
jax authors
cdd68b5f94
Merge pull request #23526 from hawkinsp:nightly
...
PiperOrigin-RevId: 672634290
2024-09-09 12:31:56 -07:00
Peter Hawkins
5cc5ed2c5c
Disable a shard_map test case that fails on TPU v5e.
...
PiperOrigin-RevId: 672618556
2024-09-09 11:45:41 -07:00
Peter Hawkins
b975592478
Change nightly install commands to include all packages.
...
pip doesn't update transitive dependencies, and we probably want the latest versions of everything when installing a nightly.
2024-09-09 14:38:41 -04:00
Justin Fu
4bdfe09241
[Pallas] Fully skip GPU attention tests on win32.
...
PiperOrigin-RevId: 672588009
2024-09-09 10:21:33 -07:00