jax authors
9e62994bce
Merge pull request #21135 from hawkinsp:release
...
PiperOrigin-RevId: 632235600
jax-v0.4.28
jaxlib-v0.4.28
jax-v0.4.28-rc
2024-05-09 12:32:51 -07:00
Peter Hawkins
038dfeec15
Prepare 0.4.28 release.
2024-05-09 19:25:33 +00:00
jax authors
f98e707551
Update XLA dependency to use revision
...
d60579f54a
.
PiperOrigin-RevId: 632232639
2024-05-09 12:22:25 -07:00
jax authors
1a7a2aa555
Merge pull request #21106 from jakevdp:linalg-precision
...
PiperOrigin-RevId: 632217396
2024-05-09 11:33:54 -07:00
jax authors
0c4d81c8cd
Merge pull request #21138 from jakevdp:einsum-doc
...
PiperOrigin-RevId: 632198113
2024-05-09 10:38:51 -07:00
Jake VanderPlas
2ddb7ff801
jnp.linalg: add precision & preferred_element_type to dot-like functions
2024-05-09 10:06:51 -07:00
Justin Fu
eb0b1b06e9
Merge pull request #21108 from justinjfu/skip_pallas_test_64
...
Skip float64 test_nextafter on TPU.
2024-05-09 09:20:30 -07:00
Yash Katariya
671fb1265d
Update the multi-process note in pjit's docstring
...
PiperOrigin-RevId: 632160561
2024-05-09 08:38:29 -07:00
jax authors
2be3f6d145
Merge pull request #21146 from jakevdp:fix-multidot
...
PiperOrigin-RevId: 632142647
2024-05-09 07:28:49 -07:00
Peter Hawkins
89d25bb1a3
Reenable examples_test in Bazel build.
...
Fix bitrot.
This test was disabled years ago because it was slow, but it isn't any more.
PiperOrigin-RevId: 632138101
2024-05-09 07:10:07 -07:00
Jake VanderPlas
5edfaa6782
jnp.linalg.multi_dot: use optimize='auto'
2024-05-09 06:47:30 -07:00
jax authors
1e88e2f862
Update XLA dependency to use revision
...
4872030c44
.
PiperOrigin-RevId: 631997979
2024-05-08 19:17:08 -07:00
Peter Hawkins
168f40ee3d
[XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
...
This works around https://github.com/python/cpython/issues/89478 , which was fixed in Python 3.11.
PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Yash Katariya
96f888bcfe
Reverts 1956ff7d7b73794012fece2d8452e097196587fc
...
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
jax authors
f991dd8bf1
Merge pull request #21139 from jakevdp:fix-lpmn-test
...
PiperOrigin-RevId: 631954696
2024-05-08 16:06:47 -07:00
Jake VanderPlas
f556a17033
TST: fix Lpmn test for new scipy
2024-05-08 15:55:20 -07:00
Jake VanderPlas
e8700523d3
jnp.einsum: improve documentation
2024-05-08 14:30:59 -07:00
jax authors
962f084543
Merge pull request #21137 from superbobry:pallas
...
PiperOrigin-RevId: 631923082
2024-05-08 14:20:10 -07:00
jax authors
65d4c688e0
Generic reduce window jvp
...
The problem is that we want to generically jvp and tranpose over any reduction_fn. Jax already handles some of the hard parts for us, namely, ensuring that the user provided fn is jax capturable. All that is left then, is to write a jvp and tranpose fn that utilize the jax utils correctly.
However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`).
For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values.
In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition
PiperOrigin-RevId: 631916764
2024-05-08 14:00:39 -07:00
Sergei Lebedev
4b62425b42
Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least
...
This makes it clear that the predicate is only supposed to be used for NVidia
GPUs at the moment.
2024-05-08 21:41:50 +01:00
jax authors
bfdc87d6d9
Merge pull request #21136 from superbobry:pallas
...
PiperOrigin-RevId: 631908723
2024-05-08 13:34:40 -07:00
Sergei Lebedev
575ba942e0
Removed get_compute_capability from jax.experimental.pallas.gpu
...
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
jax authors
a145109ac2
Update XLA dependency to use revision
...
68b17a8571
.
PiperOrigin-RevId: 631874377
2024-05-08 11:47:24 -07:00
Anselm Levskaya
f768cb74b9
Refactor pipeline emitter API.
...
Building on enrique's work, this CL refactors the emit_pipeline abstraction:
1) factors out the VMEM double-buffering bookkeeping into a helper class.
2) concentrate the intricate copy/wait scheduling logic into one place inside a scheduler helper
while allowing manual overrides, callbacks don't control scheduling anymore, rather we have
explicit loop scheduling.
3) minimize callbacks and simplify the "defaults" for fusing pipelines together.
Examples of fully overlapped versions of latency- and throughput- optimized AG-matmuls and
matmul-RSs are included in new tests.
PiperOrigin-RevId: 631865641
2024-05-08 11:22:47 -07:00
Parker Schuh
e652d62b85
Cleanup second registration of custom_partitioning callbacks now that
...
the jaxlib version has been bumped.
PiperOrigin-RevId: 631852273
2024-05-08 10:45:39 -07:00
jax authors
8baa5d8180
Merge pull request #21128 from hawkinsp:loggingtest
...
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
2967ec9b1b
Merge pull request #21129 from superbobry:pallas
...
PiperOrigin-RevId: 631825682
2024-05-08 09:25:49 -07:00
Sergei Lebedev
0feeaa5999
Removed stale version guards and try/except blocks from Pallas GPU
...
They are unnecessary now that the minimum jaxlib version is 0.4.27.
2024-05-08 17:05:45 +01:00
jax authors
0e57c79b0b
Switch Windows jobs to use Clang.
...
Remove the experimental/trial Clang job.
PiperOrigin-RevId: 631814321
2024-05-08 08:47:34 -07:00
jax authors
11da3df238
Merge pull request #21096 from gspschmid:gschmid/sourcemaps
...
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c
Enable logging_test on all CI platforms.
...
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Georg Stefan Schmid
b0b322d486
Add sourcemap module to generate TC39-compliant source maps
2024-05-08 01:54:25 -07:00
jax authors
335f27b0b6
Update XLA dependency to use revision
...
c6df436f9e
.
PiperOrigin-RevId: 631629197
2024-05-07 19:58:28 -07:00
jax authors
e9e4e5341e
[jax:mosaic-gpu] FragmentedArray can do tiled load.
...
PiperOrigin-RevId: 631611060
2024-05-07 18:13:55 -07:00
Jevin Jiang
79f11d5495
[Pallas] Fix some typos.
...
PiperOrigin-RevId: 631592201
2024-05-07 16:52:38 -07:00
Yash Katariya
395d3cb79e
Bump minimum jaxlib version to 0.4.27
...
xla_extension_version is 261 and mlir_api_version is 56
PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Jieying Luo
5ba56bb075
Recommend the plugin in the CUDA installation instructions.
...
PiperOrigin-RevId: 631555876
2024-05-07 14:47:39 -07:00
jax authors
78e10eea79
Merge pull request #21115 from jakevdp:multi-dot
...
PiperOrigin-RevId: 631545161
2024-05-07 14:14:02 -07:00
jax authors
2b3251ed46
Merge pull request #21092 from jakevdp:dot-doc
...
PiperOrigin-RevId: 631536980
2024-05-07 13:51:13 -07:00
Jake VanderPlas
09810be0cd
Implement jnp.linalg.multi_dot using opt_einsum
2024-05-07 13:40:25 -07:00
jax authors
5f702674f7
Merge pull request #21103 from superbobry:mosaic-gpu-fix
...
PiperOrigin-RevId: 631521771
2024-05-07 13:11:43 -07:00
jax authors
7153738551
Merge pull request #21104 from superbobry:triton-fixes
...
PiperOrigin-RevId: 631521767
2024-05-07 13:06:33 -07:00
jax authors
174405c953
The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
...
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.
PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
b6fea3734a
Merge pull request #21111 from jakevdp:fix-changelog
...
PiperOrigin-RevId: 631493089
2024-05-07 11:34:06 -07:00
jax authors
9524188e45
Merge pull request #21110 from jakevdp:upstream-nightly
...
PiperOrigin-RevId: 631490005
2024-05-07 11:25:19 -07:00
Jake VanderPlas
c18851b65d
CHANGELOG: move change from 0.4.27 to 0.4.28
2024-05-07 11:16:11 -07:00
Jake VanderPlas
496795e95f
CI: fix typo in workflow
2024-05-07 11:14:11 -07:00
Yash Katariya
5031a1ddc4
Finish jax and jaxlib 0.4.27 release
...
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
034d843e8c
jax.numpy: better docs for matmul-like functions
2024-05-07 11:01:54 -07:00
jax authors
3dd8af93b4
Merge pull request #21090 from jakevdp:extract
...
PiperOrigin-RevId: 631480474
2024-05-07 10:59:10 -07:00