6685 Commits

Author SHA1 Message Date
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Bixia Zheng
c81f5cd2fc [xla] Replace debug option xla_use_shardy with execution option
use_shardy_partitioner.

Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.

PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00
jax authors
63303973a2 Merge pull request #22716 from superbobry:pallas
PiperOrigin-RevId: 657333519
2024-07-29 14:50:01 -07:00
Sergei Lebedev
a44265aa73 Added a trivial discharge rule for debug_callback_p
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
Jake VanderPlas
6516a079f8 [array API] add device argument to fftfreq & rfftfreq 2024-07-29 13:23:54 -07:00
Jake VanderPlas
00ba7a6d25 [array API] move api metadata into jax.numpy namespace 2024-07-29 12:43:11 -07:00
jax authors
f070c0658f Merge pull request #22703 from Rifur13:plugin-fix
PiperOrigin-RevId: 657283607
2024-07-29 12:10:42 -07:00
jax authors
85e83b508b Merge pull request #22690 from jakevdp:inplace-doc
PiperOrigin-RevId: 657237218
2024-07-29 10:03:43 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Jake VanderPlas
cfa1e78549 Improve documentation for jnp.put, jnp.place, jnp.fill_diagonal
These are all the APIs that have an inplace parameter
2024-07-29 08:31:54 -07:00
Gleb Pobudzey
0224235edf Skip cuda backend initialization if no nvidia GPUs are visible. 2024-07-29 15:12:58 +00:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
jax authors
aeff5b61a9 Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
jax authors
0df074c285 Merge pull request #22680 from superbobry:maint
PiperOrigin-RevId: 656427681
2024-07-26 09:25:47 -07:00
Yash Katariya
05677694d8 Document copy_to_host_async method of jax.Array
PiperOrigin-RevId: 656408298
2024-07-26 08:21:01 -07:00
jax authors
694c14bbe6 Merge pull request #22556 from cool-RR:log-cache-key
PiperOrigin-RevId: 656364840
2024-07-26 05:32:11 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
jax authors
2db99e03dd Merge pull request #22283 from ayaka14732:ayx/lowering/sign
PiperOrigin-RevId: 656317943
2024-07-26 02:28:33 -07:00
Eugene Zhulenev
15d4389247 Use vmap for random_gamma implementation on CPU backend
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.

Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Ayaka
6cc09173d5 Add lowering for lax.sign 2024-07-26 10:33:42 +08:00
Yash Katariya
2eb1888c98 Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
jax authors
f75e52dbf1 Merge pull request #22674 from mattjj:remove-pdot
PiperOrigin-RevId: 656127142
2024-07-25 16:03:19 -07:00
vfdev-5
76d61f9d8f Added device kwargs to jnp.linspace, jnp.array, jnp.asarray 2024-07-26 00:36:34 +02:00
Peter Hawkins
1ac2085417 Fix "unhashable type" error when passing a jax array as the "repeats" argument to jnp.repeat().
PiperOrigin-RevId: 656112851
2024-07-25 15:22:59 -07:00
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Yash Katariya
7de3c06147 Delete mesh.Loop now that xmap has been deleted
PiperOrigin-RevId: 656084608
2024-07-25 14:08:32 -07:00
jax authors
3ed9acba3a Merge pull request #22669 from hawkinsp:repeat
PiperOrigin-RevId: 656075132
2024-07-25 13:45:25 -07:00
jax authors
92806ee9f8 Merge pull request #22668 from cool-RR:nanoseconds
PiperOrigin-RevId: 656072892
2024-07-25 13:40:05 -07:00
Peter Hawkins
f07e963bf0 Simplify jaxpr for jnp.repeat in scalar repeat case.
Before:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
    b:i32[3,4,1] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 1)] a
    c:i32[1,3,1,4,1,1] = reshape[dimensions=None new_sizes=(1, 3, 1, 4, 1, 1)] b
    d:i32[1,3,1,4,3,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1, 2, 3, 4, 5)
      shape=(1, 3, 1, 4, 3, 1)
    ] c
    e:i32[3,4,3] = reshape[dimensions=None new_sizes=(3, 4, 3)] d
    f:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] e
  in (f,) }
```

After:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
    b:i32[3,4,3] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 3)] a
    c:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] b
  in (c,) }
```
2024-07-25 15:50:23 -04:00
jax authors
d4e08a9805 Merge pull request #22619 from jaro-sevcik:rename-mock-gpus
PiperOrigin-RevId: 656049327
2024-07-25 12:44:49 -07:00
jax authors
5d352a8b0c Merge pull request #22665 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 656047033
2024-07-25 12:37:17 -07:00
Ram Rachum
0d92d31063 Show elapsed time in nanoseconds 2024-07-25 22:20:25 +03:00
rajasekharporeddy
d717135564 Better docs for jnp.triu_indices_from and tril_indices_from 2024-07-26 00:31:44 +05:30
Jake VanderPlas
81b9db6b80 [array api] streamline astype device implementation
When this was first implemented, convert_element_type did not yet
have a sharding argument. Now we can simplify things by using it.
2024-07-25 10:42:05 -07:00
Jake VanderPlas
31abce1c80 register several deprecations in jax.numpy
This is in preparation for finalizing these deprecations. They include:
- complex->real casting in jnp.astype
- complex inputs to jnp.clip
- complex inputs to jnp.hypot

PiperOrigin-RevId: 656005670
2024-07-25 10:40:00 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
jax authors
9ea79c61f4 Merge pull request #22653 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655980588
2024-07-25 09:29:07 -07:00
Sergei Lebedev
99a8b92a7f Fixed Pallas Mosaic GPU tests
* Migrated to the new barrier APIs
* Fixed scratch view casting logic, it previously didn't work for >1 view

PiperOrigin-RevId: 655937541
2024-07-25 06:51:13 -07:00
Bart Chrzaszcz
b00f978f70 #sdy Support with_sharding_constraint lowering through Shardy.
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
rajasekharporeddy
4ff69d3853 Fix jnp.triu_indices_from and jnp.tril_indices_from to emulate NumPy's behavior for arrays other than 2-D 2024-07-25 16:26:14 +05:30
jax authors
e14752c0ab Merge pull request #22642 from ROCm:ci_jax_exp
PiperOrigin-RevId: 655894235
2024-07-25 03:36:36 -07:00
jax authors
2dadbd7eb6 Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention
PiperOrigin-RevId: 655894058
2024-07-25 03:32:45 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
jax authors
76b4c70c23 Merge pull request #22628 from hawkinsp:broadcast2
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Yash Katariya
51e27923e8 Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00
jax authors
086b500da6 Merge pull request #21069 from mattjj:remove-named-shapes
PiperOrigin-RevId: 655766534
2024-07-24 18:20:50 -07:00