Peter Hawkins
09448384e5
Update release notes for 0.4.29 release.
2024-06-06 11:13:14 -04:00
jax authors
fe9c1606fc
Merge pull request #21655 from gnecula:exp_lower
...
PiperOrigin-RevId: 640898362
2024-06-06 07:58:41 -07:00
Peter Hawkins
dfe6128509
Reverts da816d34eaad6a1c6536959ccb4bfee4466c037d
...
PiperOrigin-RevId: 640886105
2024-06-06 07:10:09 -07:00
Chris Jones
d700a0842b
[mosaic:gpu] Fix matmul example for mixed precision inputs.
...
The wgmma instruction group was never committed.
PiperOrigin-RevId: 640863541
2024-06-06 05:29:23 -07:00
Tomás Longeri
20d9aac6be
[Mosaic] Remove some restrictions for vector.shape_cast in infer-vector-layout and apply-vector-layout
...
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled don't change while removing some unnecessary restrictions.
- Don't require native tiling or natural topology for this.
PiperOrigin-RevId: 640837740
2024-06-06 03:26:43 -07:00
Christos Perivolaropoulos
24e4bf2265
[Mosaic GPU] Add f32 benchmarks for matmul.
...
PiperOrigin-RevId: 640826101
2024-06-06 02:31:40 -07:00
George Necula
079eea5669
[export] Add a LoweringParameters.for_export boolean context for exporting
...
This boolean context field is set only when we are lowering for
exporting. It can be used, e.g., to adapt the lowering rules
for the export case.
2024-06-06 06:28:58 +01:00
Yash Katariya
55d0f5ef8f
Add lower
to specialize
making it a true Stage
.
...
So now users can do:
```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)
lowered = specialized.lower()
compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396
2024-06-05 19:54:41 -07:00
jax authors
d117305dba
Merge pull request #21685 from pkgoogle:better_bitwise_count_doc
...
PiperOrigin-RevId: 640717333
2024-06-05 18:29:20 -07:00
Yash Katariya
228adb4a4a
Add specialize
on jax.jit and make it a Stage
.
...
Eventually, we should use this in jax.make_jaxpr and delete all the duplicated code.
PiperOrigin-RevId: 640707223
2024-06-05 17:46:14 -07:00
jax authors
a14a5b6e06
Update XLA dependency to use revision
...
3618e421da
.
PiperOrigin-RevId: 640705050
2024-06-05 17:37:28 -07:00
jax authors
2a08b96306
Merge pull request #21629 from ayaka14732:typo1
...
PiperOrigin-RevId: 640703754
2024-06-05 17:31:57 -07:00
Mark Sandler
da816d34ea
Makes global_shape optional for jax.make_array_from_process_local_data.
...
PiperOrigin-RevId: 640695090
2024-06-05 16:58:08 -07:00
Piseth Ky
61044cee49
better docs for bitwise_count
2024-06-05 16:16:54 -07:00
Parker Schuh
20c2a45bea
PallasOpsInterpretTest.test_debug_print still flaky, add effects
...
barrier to block until the output has been known to have been emitted.
PiperOrigin-RevId: 640652710
2024-06-05 14:38:38 -07:00
jax authors
bf59a67bf0
Merge pull request #21657 from rajasekharporeddy:test_branch7
...
PiperOrigin-RevId: 640652642
2024-06-05 14:35:35 -07:00
jax authors
444a4c9110
Merge pull request #21125 from rajasekharporeddy:test_branch1
...
PiperOrigin-RevId: 640648819
2024-06-05 14:24:38 -07:00
jax authors
88dbeca297
Merge pull request #21679 from jakevdp:with-config
...
PiperOrigin-RevId: 640637421
2024-06-05 13:48:46 -07:00
Jake VanderPlas
e6e4acb7c3
tests: set configs with jtu.with_config rather than manually
2024-06-05 13:34:32 -07:00
rajasekharporeddy
be41f309ed
Add code examples to jax.scipy.fft.idct and jax.scipy.fft.idctn docs
2024-06-06 01:45:07 +05:30
jax authors
913ff50000
Merge pull request #21418 from rajasekharporeddy:test_branch6
...
PiperOrigin-RevId: 640617948
2024-06-05 12:53:17 -07:00
Parker Schuh
efeb25bb87
Skip other memories tests that fail on older libtpus.
...
PiperOrigin-RevId: 640617633
2024-06-05 12:49:47 -07:00
jax authors
224954a60e
Merge pull request #21676 from jakevdp:global-config-context
...
PiperOrigin-RevId: 640614980
2024-06-05 12:40:10 -07:00
jax authors
0cbd0a023d
Merge pull request #21494 from dfm:mac-arm-x86
...
PiperOrigin-RevId: 640613602
2024-06-05 12:36:00 -07:00
jax authors
45182da66d
Merge pull request #21571 from pkgoogle:update_abs_doc
...
PiperOrigin-RevId: 640599874
2024-06-05 11:55:46 -07:00
Piseth Ky
0c22c63c2e
Updating jnp.abs/absolute docs
...
fix documentation output
fix doc code
updating doc string to reference abs as an alias
CI/CD requires args
Update alias formatting
allowing us to skip alias docstrings
Update jax/_src/numpy/ufuncs.py
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
Update jax/_src/numpy/ufuncs.py
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
Update jax/_src/numpy/ufuncs.py
Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>
edit example to use integer output
2024-06-05 11:27:24 -07:00
Yash Katariya
ebc9de3dbc
Expose shape
and dtype
on ArgInfo and mark aval
as private. Aval is an internal property of JAX and shouldn't have been exposed to users. Users can create their own SDS with shape and dtype until we expose ArrayDuck
.
...
PiperOrigin-RevId: 640577261
2024-06-05 10:50:35 -07:00
Jake VanderPlas
f04a2279a5
shape_poly_test: adjust configs via jtu.global_config_context
2024-06-05 10:45:56 -07:00
jax authors
da87e4470a
Merge pull request #21651 from jakevdp:fix-jax2tf-tests
...
PiperOrigin-RevId: 640571776
2024-06-05 10:33:59 -07:00
jax authors
7771cd25b1
Merge pull request #21646 from dfm:gh21643
...
PiperOrigin-RevId: 640559336
2024-06-05 09:57:35 -07:00
George Necula
dbad518d2b
[shape_poly] Add limited support for lax.approx_top_k.
...
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.
We also add a backwards compatibility test.
PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
jax authors
3042031860
Merge pull request #21668 from hawkinsp:win
...
PiperOrigin-RevId: 640551082
2024-06-05 09:29:40 -07:00
Yash Katariya
1edd649de4
Deprecate XLACompatibleSharding
in favor of jax.sharding.Sharding
.
...
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Sergei Lebedev
fc4d343c83
Added missing jax.block_until_ready to PallasOpsTest.test_debug_print*
...
PiperOrigin-RevId: 640541103
2024-06-05 08:53:35 -07:00
Sergei Lebedev
69dd14e58a
Updated remaining `mgpu.once
usages to use
mgpu.single_thread
`
...
PiperOrigin-RevId: 640534664
2024-06-05 08:31:36 -07:00
jax authors
621814bd7d
Add loop-based vmap lowering for pallas calls
...
Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible:
* Dynamic grid dimensions
* Batched scalar prefetch arguments
PiperOrigin-RevId: 640530524
2024-06-05 08:15:27 -07:00
Yash Katariya
9e3f290de3
Delete XLACompatibleSharding and replace with jax.sharding.Sharding
.
...
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.
Why do this?
* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!
* Having these 2 types makes things very confusing. One example is:
* `jax.jit` only accepts XLACompatibleShardings.
* `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.
PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Peter Hawkins
ab6d361620
[windows] Add --no-progress to LLVM install command to avoid excessive logs in Github actions.
2024-06-05 10:52:43 -04:00
Malcolm Reynolds
1669b99505
Reverts c2a3c0bb80434d89053b43648f88eba22b9bf1fa
...
PiperOrigin-RevId: 640524004
2024-06-05 07:50:58 -07:00
Sergei Lebedev
e09cda8fa9
Removed unnecessary jaxlib version guards from xla_bridge
...
The minimum jaxlib version is 0.4.27.
PiperOrigin-RevId: 640513280
2024-06-05 07:05:08 -07:00
Chris Jones
557cae65d1
[mosaic:gpu] Rename once
-> single_thread
.
...
PiperOrigin-RevId: 640475714
2024-06-05 04:13:09 -07:00
Chris Jones
485fad5679
[mosaic:gpu] Document behaviour of warp_idx
and warpgroup_idx
.
...
Extracted common warp broadcasting code.
PiperOrigin-RevId: 640475128
2024-06-05 04:09:38 -07:00
Sharad Vikram
c2a3c0bb80
Add support to pipeline emitter for shapes that don't perfectly divide the block shapes
...
PiperOrigin-RevId: 640471328
2024-06-05 03:54:19 -07:00
Sergei Lebedev
d5e43dd1e9
Test pl.debug_print() on GPU/Triton via jtu.capture_stdout()
...
This approach does not currently work on TPU, because (I think) the printing
is done asynchronosly in C++, and stdout is empty by the time CPython leaves
the with block.
PiperOrigin-RevId: 640456288
2024-06-05 02:45:20 -07:00
Sergei Lebedev
40f107e5a5
Moved Pallas GPU ops into pallas/ops/gpu
...
PiperOrigin-RevId: 640439838
2024-06-05 01:34:46 -07:00
jax authors
f3d2c4fd63
Merge pull request #21500 from froystig:slab-heap
...
PiperOrigin-RevId: 640415259
2024-06-04 23:53:02 -07:00
George Necula
39ac584729
[shape_poly] Move to jax._src in preparation for adding to AOT APIs.
...
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.
PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
Roy Frostig
6b275a1875
configure dynamic shape mode only locally
...
... under a dynamic scope when staging out djaxprs. This avoids
enabling dynamic shape mode globally (in turn also making it a side
effect of `djax` import).
2024-06-04 21:59:25 -07:00
Roy Frostig
1aba5c2c82
remove usage of eval_shape
when defining allocating ops
...
This was unnecessarily involved and introduced an unenforced
requirement that shape-dynamism be enabled during the evaluation.
2024-06-04 21:56:49 -07:00
Roy Frostig
971afab587
add helpers for internal static shape/dimension assertions
...
Also some typing fixes.
2024-06-04 20:25:33 -07:00