21135 Commits

Author SHA1 Message Date
Peter Hawkins
09448384e5 Update release notes for 0.4.29 release. 2024-06-06 11:13:14 -04:00
jax authors
fe9c1606fc Merge pull request #21655 from gnecula:exp_lower
PiperOrigin-RevId: 640898362
2024-06-06 07:58:41 -07:00
Peter Hawkins
dfe6128509 Reverts da816d34eaad6a1c6536959ccb4bfee4466c037d
PiperOrigin-RevId: 640886105
2024-06-06 07:10:09 -07:00
Chris Jones
d700a0842b [mosaic:gpu] Fix matmul example for mixed precision inputs.
The wgmma instruction group was never committed.

PiperOrigin-RevId: 640863541
2024-06-06 05:29:23 -07:00
Tomás Longeri
20d9aac6be [Mosaic] Remove some restrictions for vector.shape_cast in infer-vector-layout and apply-vector-layout
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled don't change while removing some unnecessary restrictions.
  - Don't require native tiling or natural topology for this.

PiperOrigin-RevId: 640837740
2024-06-06 03:26:43 -07:00
Christos Perivolaropoulos
24e4bf2265 [Mosaic GPU] Add f32 benchmarks for matmul.
PiperOrigin-RevId: 640826101
2024-06-06 02:31:40 -07:00
George Necula
079eea5669 [export] Add a LoweringParameters.for_export boolean context for exporting
This boolean context field is set only when we are lowering for
exporting. It can be used, e.g., to adapt the lowering rules
for the export case.
2024-06-06 06:28:58 +01:00
Yash Katariya
55d0f5ef8f Add lower to specialize making it a true Stage.
So now users can do:

```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)

lowered = specialized.lower()

compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396
2024-06-05 19:54:41 -07:00
jax authors
d117305dba Merge pull request #21685 from pkgoogle:better_bitwise_count_doc
PiperOrigin-RevId: 640717333
2024-06-05 18:29:20 -07:00
Yash Katariya
228adb4a4a Add specialize on jax.jit and make it a Stage.
Eventually, we should use this in jax.make_jaxpr and delete all the duplicated code.

PiperOrigin-RevId: 640707223
2024-06-05 17:46:14 -07:00
jax authors
a14a5b6e06 Update XLA dependency to use revision
3618e421da.

PiperOrigin-RevId: 640705050
2024-06-05 17:37:28 -07:00
jax authors
2a08b96306 Merge pull request #21629 from ayaka14732:typo1
PiperOrigin-RevId: 640703754
2024-06-05 17:31:57 -07:00
Mark Sandler
da816d34ea Makes global_shape optional for jax.make_array_from_process_local_data.
PiperOrigin-RevId: 640695090
2024-06-05 16:58:08 -07:00
Piseth Ky
61044cee49 better docs for bitwise_count 2024-06-05 16:16:54 -07:00
Parker Schuh
20c2a45bea PallasOpsInterpretTest.test_debug_print still flaky, add effects
barrier to block until the output has been known to have been emitted.

PiperOrigin-RevId: 640652710
2024-06-05 14:38:38 -07:00
jax authors
bf59a67bf0 Merge pull request #21657 from rajasekharporeddy:test_branch7
PiperOrigin-RevId: 640652642
2024-06-05 14:35:35 -07:00
jax authors
444a4c9110 Merge pull request #21125 from rajasekharporeddy:test_branch1
PiperOrigin-RevId: 640648819
2024-06-05 14:24:38 -07:00
jax authors
88dbeca297 Merge pull request #21679 from jakevdp:with-config
PiperOrigin-RevId: 640637421
2024-06-05 13:48:46 -07:00
Jake VanderPlas
e6e4acb7c3 tests: set configs with jtu.with_config rather than manually 2024-06-05 13:34:32 -07:00
rajasekharporeddy
be41f309ed Add code examples to jax.scipy.fft.idct and jax.scipy.fft.idctn docs 2024-06-06 01:45:07 +05:30
jax authors
913ff50000 Merge pull request #21418 from rajasekharporeddy:test_branch6
PiperOrigin-RevId: 640617948
2024-06-05 12:53:17 -07:00
Parker Schuh
efeb25bb87 Skip other memories tests that fail on older libtpus.
PiperOrigin-RevId: 640617633
2024-06-05 12:49:47 -07:00
jax authors
224954a60e Merge pull request #21676 from jakevdp:global-config-context
PiperOrigin-RevId: 640614980
2024-06-05 12:40:10 -07:00
jax authors
0cbd0a023d Merge pull request #21494 from dfm:mac-arm-x86
PiperOrigin-RevId: 640613602
2024-06-05 12:36:00 -07:00
jax authors
45182da66d Merge pull request #21571 from pkgoogle:update_abs_doc
PiperOrigin-RevId: 640599874
2024-06-05 11:55:46 -07:00
Piseth Ky
0c22c63c2e Updating jnp.abs/absolute docs
fix documentation output

fix doc code

updating doc string to reference abs as an alias

CI/CD requires args

Update alias formatting

allowing us to skip alias docstrings

Update jax/_src/numpy/ufuncs.py

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Update jax/_src/numpy/ufuncs.py

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Update jax/_src/numpy/ufuncs.py

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

edit example to use integer output
2024-06-05 11:27:24 -07:00
Yash Katariya
ebc9de3dbc Expose shape and dtype on ArgInfo and mark aval as private. Aval is an internal property of JAX and shouldn't have been exposed to users. Users can create their own SDS with shape and dtype until we expose ArrayDuck.
PiperOrigin-RevId: 640577261
2024-06-05 10:50:35 -07:00
Jake VanderPlas
f04a2279a5 shape_poly_test: adjust configs via jtu.global_config_context 2024-06-05 10:45:56 -07:00
jax authors
da87e4470a Merge pull request #21651 from jakevdp:fix-jax2tf-tests
PiperOrigin-RevId: 640571776
2024-06-05 10:33:59 -07:00
jax authors
7771cd25b1 Merge pull request #21646 from dfm:gh21643
PiperOrigin-RevId: 640559336
2024-06-05 09:57:35 -07:00
George Necula
dbad518d2b [shape_poly] Add limited support for lax.approx_top_k.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.

We also add a backwards compatibility test.

PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
jax authors
3042031860 Merge pull request #21668 from hawkinsp:win
PiperOrigin-RevId: 640551082
2024-06-05 09:29:40 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Sergei Lebedev
fc4d343c83 Added missing jax.block_until_ready to PallasOpsTest.test_debug_print*
PiperOrigin-RevId: 640541103
2024-06-05 08:53:35 -07:00
Sergei Lebedev
69dd14e58a Updated remaining `mgpu.once usages to use mgpu.single_thread`
PiperOrigin-RevId: 640534664
2024-06-05 08:31:36 -07:00
jax authors
621814bd7d Add loop-based vmap lowering for pallas calls
Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible:
* Dynamic grid dimensions
* Batched scalar prefetch arguments

PiperOrigin-RevId: 640530524
2024-06-05 08:15:27 -07:00
Yash Katariya
9e3f290de3 Delete XLACompatibleSharding and replace with jax.sharding.Sharding.
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.

Why do this?

* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!

* Having these 2 types makes things very confusing. One example is:
  * `jax.jit` only accepts XLACompatibleShardings.
  * `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.

PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Peter Hawkins
ab6d361620 [windows] Add --no-progress to LLVM install command to avoid excessive logs in Github actions. 2024-06-05 10:52:43 -04:00
Malcolm Reynolds
1669b99505 Reverts c2a3c0bb80434d89053b43648f88eba22b9bf1fa
PiperOrigin-RevId: 640524004
2024-06-05 07:50:58 -07:00
Sergei Lebedev
e09cda8fa9 Removed unnecessary jaxlib version guards from xla_bridge
The minimum jaxlib version is 0.4.27.

PiperOrigin-RevId: 640513280
2024-06-05 07:05:08 -07:00
Chris Jones
557cae65d1 [mosaic:gpu] Rename once -> single_thread.
PiperOrigin-RevId: 640475714
2024-06-05 04:13:09 -07:00
Chris Jones
485fad5679 [mosaic:gpu] Document behaviour of warp_idx and warpgroup_idx.
Extracted common warp broadcasting code.

PiperOrigin-RevId: 640475128
2024-06-05 04:09:38 -07:00
Sharad Vikram
c2a3c0bb80 Add support to pipeline emitter for shapes that don't perfectly divide the block shapes
PiperOrigin-RevId: 640471328
2024-06-05 03:54:19 -07:00
Sergei Lebedev
d5e43dd1e9 Test pl.debug_print() on GPU/Triton via jtu.capture_stdout()
This approach does not currently work on TPU, because (I think) the printing
is done asynchronosly in C++, and stdout is empty by the time CPython leaves
the with block.

PiperOrigin-RevId: 640456288
2024-06-05 02:45:20 -07:00
Sergei Lebedev
40f107e5a5 Moved Pallas GPU ops into pallas/ops/gpu
PiperOrigin-RevId: 640439838
2024-06-05 01:34:46 -07:00
jax authors
f3d2c4fd63 Merge pull request #21500 from froystig:slab-heap
PiperOrigin-RevId: 640415259
2024-06-04 23:53:02 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
Roy Frostig
6b275a1875 configure dynamic shape mode only locally
... under a dynamic scope when staging out djaxprs. This avoids
enabling dynamic shape mode globally (in turn also making it a side
effect of `djax` import).
2024-06-04 21:59:25 -07:00
Roy Frostig
1aba5c2c82 remove usage of eval_shape when defining allocating ops
This was unnecessarily involved and introduced an unenforced
requirement that shape-dynamism be enabled during the evaluation.
2024-06-04 21:56:49 -07:00
Roy Frostig
971afab587 add helpers for internal static shape/dimension assertions
Also some typing fixes.
2024-06-04 20:25:33 -07:00