20631 Commits

Author SHA1 Message Date
jax authors
9e62994bce Merge pull request #21135 from hawkinsp:release
PiperOrigin-RevId: 632235600
jax-v0.4.28 jaxlib-v0.4.28 jax-v0.4.28-rc
2024-05-09 12:32:51 -07:00
Peter Hawkins
038dfeec15 Prepare 0.4.28 release. 2024-05-09 19:25:33 +00:00
jax authors
f98e707551 Update XLA dependency to use revision
d60579f54a.

PiperOrigin-RevId: 632232639
2024-05-09 12:22:25 -07:00
jax authors
1a7a2aa555 Merge pull request #21106 from jakevdp:linalg-precision
PiperOrigin-RevId: 632217396
2024-05-09 11:33:54 -07:00
jax authors
0c4d81c8cd Merge pull request #21138 from jakevdp:einsum-doc
PiperOrigin-RevId: 632198113
2024-05-09 10:38:51 -07:00
Jake VanderPlas
2ddb7ff801 jnp.linalg: add precision & preferred_element_type to dot-like functions 2024-05-09 10:06:51 -07:00
Justin Fu
eb0b1b06e9
Merge pull request #21108 from justinjfu/skip_pallas_test_64
Skip float64 test_nextafter on TPU.
2024-05-09 09:20:30 -07:00
Yash Katariya
671fb1265d Update the multi-process note in pjit's docstring
PiperOrigin-RevId: 632160561
2024-05-09 08:38:29 -07:00
jax authors
2be3f6d145 Merge pull request #21146 from jakevdp:fix-multidot
PiperOrigin-RevId: 632142647
2024-05-09 07:28:49 -07:00
Peter Hawkins
89d25bb1a3 Reenable examples_test in Bazel build.
Fix bitrot.

This test was disabled years ago because it was slow, but it isn't any more.

PiperOrigin-RevId: 632138101
2024-05-09 07:10:07 -07:00
Jake VanderPlas
5edfaa6782 jnp.linalg.multi_dot: use optimize='auto' 2024-05-09 06:47:30 -07:00
jax authors
1e88e2f862 Update XLA dependency to use revision
4872030c44.

PiperOrigin-RevId: 631997979
2024-05-08 19:17:08 -07:00
Peter Hawkins
168f40ee3d [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
jax authors
f991dd8bf1 Merge pull request #21139 from jakevdp:fix-lpmn-test
PiperOrigin-RevId: 631954696
2024-05-08 16:06:47 -07:00
Jake VanderPlas
f556a17033 TST: fix Lpmn test for new scipy 2024-05-08 15:55:20 -07:00
Jake VanderPlas
e8700523d3 jnp.einsum: improve documentation 2024-05-08 14:30:59 -07:00
jax authors
962f084543 Merge pull request #21137 from superbobry:pallas
PiperOrigin-RevId: 631923082
2024-05-08 14:20:10 -07:00
jax authors
65d4c688e0 Generic reduce window jvp
The problem is that we want to generically jvp and tranpose over any reduction_fn. Jax already handles some of the hard parts for us, namely, ensuring that the user provided fn is jax capturable. All that is left then, is to write a jvp and tranpose fn that utilize the jax utils correctly.

However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`).

For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values.

In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition

PiperOrigin-RevId: 631916764
2024-05-08 14:00:39 -07:00
Sergei Lebedev
4b62425b42 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least
This makes it clear that the predicate is only supposed to be used for NVidia
GPUs at the moment.
2024-05-08 21:41:50 +01:00
jax authors
bfdc87d6d9 Merge pull request #21136 from superbobry:pallas
PiperOrigin-RevId: 631908723
2024-05-08 13:34:40 -07:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
jax authors
a145109ac2 Update XLA dependency to use revision
68b17a8571.

PiperOrigin-RevId: 631874377
2024-05-08 11:47:24 -07:00
Anselm Levskaya
f768cb74b9 Refactor pipeline emitter API.
Building on enrique's work, this CL refactors the emit_pipeline abstraction:
1) factors out the VMEM double-buffering bookkeeping into a helper class.
2) concentrate the intricate copy/wait scheduling logic into one place inside a scheduler helper
while allowing manual overrides, callbacks don't control scheduling anymore, rather we have
explicit loop scheduling.
3) minimize callbacks and simplify the "defaults" for fusing pipelines together.

Examples of fully overlapped versions of latency- and throughput- optimized AG-matmuls and
matmul-RSs are included in new tests.

PiperOrigin-RevId: 631865641
2024-05-08 11:22:47 -07:00
Parker Schuh
e652d62b85 Cleanup second registration of custom_partitioning callbacks now that
the jaxlib version has been bumped.

PiperOrigin-RevId: 631852273
2024-05-08 10:45:39 -07:00
jax authors
8baa5d8180 Merge pull request #21128 from hawkinsp:loggingtest
PiperOrigin-RevId: 631839432
2024-05-08 10:07:42 -07:00
jax authors
2967ec9b1b Merge pull request #21129 from superbobry:pallas
PiperOrigin-RevId: 631825682
2024-05-08 09:25:49 -07:00
Sergei Lebedev
0feeaa5999 Removed stale version guards and try/except blocks from Pallas GPU
They are unnecessary now that the minimum jaxlib version is 0.4.27.
2024-05-08 17:05:45 +01:00
jax authors
0e57c79b0b Switch Windows jobs to use Clang.
Remove the experimental/trial Clang job.

PiperOrigin-RevId: 631814321
2024-05-08 08:47:34 -07:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Peter Hawkins
919832a63c Enable logging_test on all CI platforms.
Should catch issues like https://github.com/google/jax/issues/21121
2024-05-08 12:43:52 +00:00
Georg Stefan Schmid
b0b322d486 Add sourcemap module to generate TC39-compliant source maps 2024-05-08 01:54:25 -07:00
jax authors
335f27b0b6 Update XLA dependency to use revision
c6df436f9e.

PiperOrigin-RevId: 631629197
2024-05-07 19:58:28 -07:00
jax authors
e9e4e5341e [jax:mosaic-gpu] FragmentedArray can do tiled load.
PiperOrigin-RevId: 631611060
2024-05-07 18:13:55 -07:00
Jevin Jiang
79f11d5495 [Pallas] Fix some typos.
PiperOrigin-RevId: 631592201
2024-05-07 16:52:38 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Jieying Luo
5ba56bb075 Recommend the plugin in the CUDA installation instructions.
PiperOrigin-RevId: 631555876
2024-05-07 14:47:39 -07:00
jax authors
78e10eea79 Merge pull request #21115 from jakevdp:multi-dot
PiperOrigin-RevId: 631545161
2024-05-07 14:14:02 -07:00
jax authors
2b3251ed46 Merge pull request #21092 from jakevdp:dot-doc
PiperOrigin-RevId: 631536980
2024-05-07 13:51:13 -07:00
Jake VanderPlas
09810be0cd Implement jnp.linalg.multi_dot using opt_einsum 2024-05-07 13:40:25 -07:00
jax authors
5f702674f7 Merge pull request #21103 from superbobry:mosaic-gpu-fix
PiperOrigin-RevId: 631521771
2024-05-07 13:11:43 -07:00
jax authors
7153738551 Merge pull request #21104 from superbobry:triton-fixes
PiperOrigin-RevId: 631521767
2024-05-07 13:06:33 -07:00
jax authors
174405c953 The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0.
The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation.

PiperOrigin-RevId: 631519200
2024-05-07 12:58:37 -07:00
jax authors
b6fea3734a Merge pull request #21111 from jakevdp:fix-changelog
PiperOrigin-RevId: 631493089
2024-05-07 11:34:06 -07:00
jax authors
9524188e45 Merge pull request #21110 from jakevdp:upstream-nightly
PiperOrigin-RevId: 631490005
2024-05-07 11:25:19 -07:00
Jake VanderPlas
c18851b65d CHANGELOG: move change from 0.4.27 to 0.4.28 2024-05-07 11:16:11 -07:00
Jake VanderPlas
496795e95f CI: fix typo in workflow 2024-05-07 11:14:11 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
034d843e8c jax.numpy: better docs for matmul-like functions 2024-05-07 11:01:54 -07:00
jax authors
3dd8af93b4 Merge pull request #21090 from jakevdp:extract
PiperOrigin-RevId: 631480474
2024-05-07 10:59:10 -07:00