Peter Hawkins
fd23b8733d
Bump minimum SciPy version to 1.10.
...
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago
PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
jax authors
e78e643b5f
Merge pull request #22593 from gnecula:pallas_more_simplification
...
PiperOrigin-RevId: 657198330
2024-07-29 07:48:52 -07:00
Dan Foreman-Mackey
ff4e0b1214
Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
...
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.
PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Vladimir Belitskiy
fef91fb201
Skip tests/mock_gpu_test.py on pytest.
...
PiperOrigin-RevId: 657185249
2024-07-29 06:55:43 -07:00
George Necula
70a11acbb1
[pallas] More simplification of grid mapping and calling convention
...
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.
I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.
I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021
[pallas] Add lowering errors for block shapes that are not supported.
...
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.
PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Sergei Lebedev
ccc4c42ec9
Reduced the input size in PallasCallInputOutputAliasingTest
...
This ensures the test doesn't OOM when running on A100 on the CI.
PiperOrigin-RevId: 657165032
2024-07-29 05:29:45 -07:00
Adam Paszke
a00b659b03
[Mosaic GPU] Fix two subtle issues with kernel lowering
...
1. The MLIR context is created by the user and its lifetime is not
in our control. To avoid depending on it, we serialize the module.
2. The operand and result layout requirements were missing from the custom call.
PiperOrigin-RevId: 657164985
2024-07-29 05:25:50 -07:00
jax authors
6a7822a73b
Update XLA dependency to use revision
...
95e3eea8d2
.
PiperOrigin-RevId: 657003194
2024-07-28 15:32:56 -07:00
jax authors
74649be7ed
Update XLA dependency to use revision
...
89089aa569
.
PiperOrigin-RevId: 656797625
2024-07-27 15:23:41 -07:00
Jake VanderPlas
a17c8d945b
Finalize deprecation of jax.random.shuffle
...
This has been raising a DeprecationWarning for longer than anyone can remember.
PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
jax authors
dab15d6fdd
Merge pull request #22684 from froystig:rngdoc
...
PiperOrigin-RevId: 656600958
2024-07-26 19:12:36 -07:00
jax authors
40d569b22e
Update XLA dependency to use revision
...
cf139009c9
.
PiperOrigin-RevId: 656531286
2024-07-26 14:34:09 -07:00
Roy Frostig
f30ebd8586
document vmap peculiarity of experimental RNG implementations
2024-07-26 13:40:16 -07:00
Roy Frostig
6ddd488df0
improve RNG doc around implementation configuration
2024-07-26 13:40:16 -07:00
jax authors
aeff5b61a9
Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
...
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
Vladimir Belitskiy
7f96b263d4
Un-skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on ASAN.
...
It should have been fixed via
https://github.com/pytorch/pytorch/issues/117058#issuecomment-1973020150
PiperOrigin-RevId: 656464550
2024-07-26 11:10:41 -07:00
Vladimir Belitskiy
282ebf4882
Skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on MSAN.
...
MSAN has issues with using `-c opt` in some cases, which prevents this
test from running properly.
PiperOrigin-RevId: 656454585
2024-07-26 10:44:19 -07:00
jax authors
0df074c285
Merge pull request #22680 from superbobry:maint
...
PiperOrigin-RevId: 656427681
2024-07-26 09:25:47 -07:00
Adam Paszke
d862f78dcc
[Mosaic GPU] Skip matmul tests with large clusters
...
I'm still investigating but they sometimes hang for an unclear reason.
PiperOrigin-RevId: 656426326
2024-07-26 09:21:13 -07:00
Yash Katariya
05677694d8
Document copy_to_host_async
method of jax.Array
...
PiperOrigin-RevId: 656408298
2024-07-26 08:21:01 -07:00
jax authors
694c14bbe6
Merge pull request #22556 from cool-RR:log-cache-key
...
PiperOrigin-RevId: 656364840
2024-07-26 05:32:11 -07:00
Ayaka
bb160cf54e
Move TPU ops test to ops_test.py
...
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.
PiperOrigin-RevId: 656347969
2024-07-26 04:24:13 -07:00
Sergei Lebedev
8d33a6c9a6
Bumped jaxlib version mypy uses on the CI
...
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
jax authors
2db99e03dd
Merge pull request #22283 from ayaka14732:ayx/lowering/sign
...
PiperOrigin-RevId: 656317943
2024-07-26 02:28:33 -07:00
jax authors
8ed94bcfb6
[shard_map docs]: Fix doc typos
...
PiperOrigin-RevId: 656265100
2024-07-25 23:29:55 -07:00
Tomás Longeri
0f834cdf24
[Mosaic TPU] Enable lane broadcast for packed types and offsets outside of first tile, and fix some broadcast infer logic
...
PiperOrigin-RevId: 656201666
2024-07-25 19:48:20 -07:00
Eugene Zhulenev
15d4389247
Use vmap for random_gamma implementation on CPU backend
...
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.
Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Ayaka
6cc09173d5
Add lowering for lax.sign
2024-07-26 10:33:42 +08:00
Yash Katariya
2eb1888c98
Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
...
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
jax authors
6f68887e0d
Merge pull request #22673 from mattjj:cond-flakey-test
...
PiperOrigin-RevId: 656133068
2024-07-25 16:19:19 -07:00
Matthew Johnson
f72a3f8ef4
deflake cond memory leak regression test
2024-07-25 23:12:21 +00:00
jax authors
f75e52dbf1
Merge pull request #22674 from mattjj:remove-pdot
...
PiperOrigin-RevId: 656127142
2024-07-25 16:03:19 -07:00
Jevin Jiang
b1b7d0465e
[XLA:Mosaic] Support any int type upcast.
...
Also fixed the int4 unpacking.
PiperOrigin-RevId: 656119043
2024-07-25 15:39:38 -07:00
vfdev-5
76d61f9d8f
Added device kwargs to jnp.linspace, jnp.array, jnp.asarray
2024-07-26 00:36:34 +02:00
Peter Hawkins
1ac2085417
Fix "unhashable type" error when passing a jax array as the "repeats" argument to jnp.repeat().
...
PiperOrigin-RevId: 656112851
2024-07-25 15:22:59 -07:00
Matthew Johnson
88d1cd731d
remove pdot and xeinsum (since xmap is gone)
2024-07-25 21:19:17 +00:00
Yash Katariya
7de3c06147
Delete mesh.Loop now that xmap has been deleted
...
PiperOrigin-RevId: 656084608
2024-07-25 14:08:32 -07:00
jax authors
c4d3c8ddc7
Update XLA dependency to use revision
...
6b0495fc43
.
PiperOrigin-RevId: 656078864
2024-07-25 13:55:11 -07:00
jax authors
3ed9acba3a
Merge pull request #22669 from hawkinsp:repeat
...
PiperOrigin-RevId: 656075132
2024-07-25 13:45:25 -07:00
jax authors
92806ee9f8
Merge pull request #22668 from cool-RR:nanoseconds
...
PiperOrigin-RevId: 656072892
2024-07-25 13:40:05 -07:00
Sergei Lebedev
5e418f5ab2
Added argument validation to mosaic_gpu_init_tma_desc
...
This should help with understanding cuTensorMapEncodeTiled failures, since
CUDA doesn't provide any details beyond the error return code.
Note that this change also ensures that TMA descriptors are 64-byte aligned.
PiperOrigin-RevId: 656062820
2024-07-25 13:16:34 -07:00
Peter Hawkins
f07e963bf0
Simplify jaxpr for jnp.repeat in scalar repeat case.
...
Before:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,1] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 1)] a
c:i32[1,3,1,4,1,1] = reshape[dimensions=None new_sizes=(1, 3, 1, 4, 1, 1)] b
d:i32[1,3,1,4,3,1] = broadcast_in_dim[
broadcast_dimensions=(0, 1, 2, 3, 4, 5)
shape=(1, 3, 1, 4, 3, 1)
] c
e:i32[3,4,3] = reshape[dimensions=None new_sizes=(3, 4, 3)] d
f:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] e
in (f,) }
```
After:
```
In [2]: jax.make_jaxpr(lambda x: jnp.repeat(x, 3, axis=-1))(jnp.arange(12).reshape(3, 4))
Out[2]:
{ lambda ; a:i32[3,4]. let
b:i32[3,4,3] = broadcast_in_dim[broadcast_dimensions=(0, 1) shape=(3, 4, 3)] a
c:i32[3,12] = reshape[dimensions=None new_sizes=(3, 12)] b
in (c,) }
```
2024-07-25 15:50:23 -04:00
jax authors
d4e08a9805
Merge pull request #22619 from jaro-sevcik:rename-mock-gpus
...
PiperOrigin-RevId: 656049327
2024-07-25 12:44:49 -07:00
jax authors
5d352a8b0c
Merge pull request #22665 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 656047033
2024-07-25 12:37:17 -07:00
Ram Rachum
0d92d31063
Show elapsed time in nanoseconds
2024-07-25 22:20:25 +03:00
rajasekharporeddy
d717135564
Better docs for jnp.triu_indices_from and tril_indices_from
2024-07-26 00:31:44 +05:30
jax authors
f17d0f382a
Merge pull request #22664 from jakevdp:astype-device
...
PiperOrigin-RevId: 656016734
2024-07-25 11:11:49 -07:00
jax authors
593afa6757
Merge pull request #22663 from jakevdp:array-api-cleanup
...
PiperOrigin-RevId: 656016083
2024-07-25 11:08:01 -07:00
Jake VanderPlas
81b9db6b80
[array api] streamline astype device implementation
...
When this was first implemented, convert_element_type did not yet
have a sharding argument. Now we can simplify things by using it.
2024-07-25 10:42:05 -07:00