George Necula
6d53aaf7d0
[pallas] Improve the error localization
...
* Add the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
to get argument names
* Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
30037547d7
Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
...
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Bixia Zheng
c81f5cd2fc
[xla] Replace debug option xla_use_shardy with execution option
...
use_shardy_partitioner.
Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.
PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00
jax authors
63303973a2
Merge pull request #22716 from superbobry:pallas
...
PiperOrigin-RevId: 657333519
2024-07-29 14:50:01 -07:00
Sergei Lebedev
a44265aa73
Added a trivial discharge rule for debug_callback_p
...
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
Jake VanderPlas
6516a079f8
[array API] add device argument to fftfreq & rfftfreq
2024-07-29 13:23:54 -07:00
Jake VanderPlas
00ba7a6d25
[array API] move api metadata into jax.numpy namespace
2024-07-29 12:43:11 -07:00
jax authors
f070c0658f
Merge pull request #22703 from Rifur13:plugin-fix
...
PiperOrigin-RevId: 657283607
2024-07-29 12:10:42 -07:00
jax authors
85e83b508b
Merge pull request #22690 from jakevdp:inplace-doc
...
PiperOrigin-RevId: 657237218
2024-07-29 10:03:43 -07:00
Peter Hawkins
d1c0d993fc
Bump the minimum CUDNN version to v9.1.
...
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.
Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.
PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Jake VanderPlas
cfa1e78549
Improve documentation for jnp.put, jnp.place, jnp.fill_diagonal
...
These are all the APIs that have an inplace parameter
2024-07-29 08:31:54 -07:00
Gleb Pobudzey
0224235edf
Skip cuda backend initialization if no nvidia GPUs are visible.
2024-07-29 15:12:58 +00:00
George Necula
70a11acbb1
[pallas] More simplification of grid mapping and calling convention
...
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.
I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.
I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021
[pallas] Add lowering errors for block shapes that are not supported.
...
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.
PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Jake VanderPlas
a17c8d945b
Finalize deprecation of jax.random.shuffle
...
This has been raising a DeprecationWarning for longer than anyone can remember.
PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
jax authors
aeff5b61a9
Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
...
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
jax authors
0df074c285
Merge pull request #22680 from superbobry:maint
...
PiperOrigin-RevId: 656427681
2024-07-26 09:25:47 -07:00
Yash Katariya
05677694d8
Document copy_to_host_async
method of jax.Array
...
PiperOrigin-RevId: 656408298
2024-07-26 08:21:01 -07:00
jax authors
694c14bbe6
Merge pull request #22556 from cool-RR:log-cache-key
...
PiperOrigin-RevId: 656364840
2024-07-26 05:32:11 -07:00
Sergei Lebedev
8d33a6c9a6
Bumped jaxlib version mypy uses on the CI
...
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
jax authors
2db99e03dd
Merge pull request #22283 from ayaka14732:ayx/lowering/sign
...
PiperOrigin-RevId: 656317943
2024-07-26 02:28:33 -07:00
Eugene Zhulenev
15d4389247
Use vmap for random_gamma implementation on CPU backend
...
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.
Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Ayaka
6cc09173d5
Add lowering for lax.sign
2024-07-26 10:33:42 +08:00
Yash Katariya
2eb1888c98
Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
...
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
jax authors
f75e52dbf1
Merge pull request #22674 from mattjj:remove-pdot
...
PiperOrigin-RevId: 656127142
2024-07-25 16:03:19 -07:00
vfdev-5
76d61f9d8f
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
2024-07-26 00:36:34 +02:00
Peter Hawkins
1ac2085417
Fix "unhashable type" error when passing a jax array as the "repeats" argument to jnp.repeat().
...
PiperOrigin-RevId: 656112851
2024-07-25 15:22:59 -07:00
Matthew Johnson
88d1cd731d
remove pdot and xeinsum (since xmap is gone)
2024-07-25 21:19:17 +00:00
Yash Katariya
7de3c06147
Delete mesh.Loop now that xmap has been deleted
...
PiperOrigin-RevId: 656084608
2024-07-25 14:08:32 -07:00
jax authors
3ed9acba3a
Merge pull request #22669 from hawkinsp:repeat
...
PiperOrigin-RevId: 656075132
2024-07-25 13:45:25 -07:00
jax authors
92806ee9f8
Merge pull request #22668 from cool-RR:nanoseconds
...
PiperOrigin-RevId: 656072892
2024-07-25 13:40:05 -07:00
Peter Hawkins
f07e963bf0
Simplify jaxpr for jnp.repeat in scalar repeat case.
...
Before:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,1] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 1)] a
c:i32[1,3,1,4,1,1] = reshape[dimensions=None new_sizes=(1, 3, 1, 4, 1, 1)] b
d:i32[1,3,1,4,3,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1, 2, 3, 4, 5)
shape=(1, 3, 1, 4, 3, 1)
] c
e:i32[3,4,3] = reshape[dimensions=None new_sizes=(3, 4, 3)] d
f:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] e
in (f,) }
```
After:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,3] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 3)] a
c:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] b
in (c,) }
```
2024-07-25 15:50:23 -04:00
jax authors
d4e08a9805
Merge pull request #22619 from jaro-sevcik:rename-mock-gpus
...
PiperOrigin-RevId: 656049327
2024-07-25 12:44:49 -07:00
jax authors
5d352a8b0c
Merge pull request #22665 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 656047033
2024-07-25 12:37:17 -07:00
Ram Rachum
0d92d31063
Show elapsed time in nanoseconds
2024-07-25 22:20:25 +03:00
rajasekharporeddy
d717135564
Better docs for jnp.triu_indices_from and tril_indices_from
2024-07-26 00:31:44 +05:30
Jake VanderPlas
81b9db6b80
[array api] streamline astype device implementation
...
When this was first implemented, convert_element_type did not yet
have a sharding argument. Now we can simplify things by using it.
2024-07-25 10:42:05 -07:00
Jake VanderPlas
31abce1c80
register several deprecations in jax.numpy
...
This is in preparation for finalizing these deprecations. They include:
- complex->real casting in jnp.astype
- complex inputs to jnp.clip
- complex inputs to jnp.hypot
PiperOrigin-RevId: 656005670
2024-07-25 10:40:00 -07:00
Paweł Paruzel
ae40c87919
Activate Cholesky Factorization Kernel to XLA's FFI
...
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
jax authors
9ea79c61f4
Merge pull request #22653 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 655980588
2024-07-25 09:29:07 -07:00
Sergei Lebedev
99a8b92a7f
Fixed Pallas Mosaic GPU tests
...
* Migrated to the new barrier APIs
* Fixed scratch view casting logic, it previously didn't work for >1 view
PiperOrigin-RevId: 655937541
2024-07-25 06:51:13 -07:00
Bart Chrzaszcz
b00f978f70
#sdy Support with_sharding_constraint lowering through Shardy.
...
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
rajasekharporeddy
4ff69d3853
Fix jnp.triu_indices_from and jnp.tril_indices_from to emulate NumPy's behavior for arrays other than 2-D
2024-07-25 16:26:14 +05:30
jax authors
e14752c0ab
Merge pull request #22642 from ROCm:ci_jax_exp
...
PiperOrigin-RevId: 655894235
2024-07-25 03:36:36 -07:00
jax authors
2dadbd7eb6
Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention
...
PiperOrigin-RevId: 655894058
2024-07-25 03:32:45 -07:00
George Necula
4063373b22
Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
...
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Matthew Johnson
c8ea86c9c9
remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
...
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069
PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
jax authors
76b4c70c23
Merge pull request #22628 from hawkinsp:broadcast2
...
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Yash Katariya
51e27923e8
Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map
...
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00
jax authors
086b500da6
Merge pull request #21069 from mattjj:remove-named-shapes
...
PiperOrigin-RevId: 655766534
2024-07-24 18:20:50 -07:00