jax authors
a2b70e3346
Bump shard_count for shard_map_test to fix timeouts.
...
PiperOrigin-RevId: 571109311
2023-10-05 13:18:10 -07:00
Jake VanderPlas
4fd6bf4e57
Re-add JIT to jax.numpy.bitwise_count
...
PiperOrigin-RevId: 571107971
2023-10-05 13:07:49 -07:00
jax authors
797f577fb8
Expose mlir.ShapePolyLoweringState
...
PiperOrigin-RevId: 571075542
2023-10-05 11:21:09 -07:00
Tomás Longeri
ab4a8e3417
[Mosaic] apply_vector_layout C++ rewrite: various bug fixes
...
PiperOrigin-RevId: 571075082
2023-10-05 11:10:49 -07:00
Peter Hawkins
295cecd505
Update XLA dependency to use revision
...
e73af4223d
.
PiperOrigin-RevId: 571067534
2023-10-05 10:54:50 -07:00
Tomás Longeri
68c84a6c5c
[Mosaic] apply_vector_layout: Use shape in generalizes check (in Python)
...
- The addition to the check in the relayout loop in `apply_layout_op` should result in skipping some no-op relayouts
- The assert in `disassemble` also needs to be updated because it won't hold now that relayout is skipped more (relayout guarantees the defining layout to be equal to the input layout)
PiperOrigin-RevId: 571066259
2023-10-05 10:44:22 -07:00
Adam Paszke
633f68a398
[Mosaic] Fix a buggy vector.broadcast rule in apply_vector_layout
...
The rule did not take tiling into account, assuming that it works with
32-bit data that has native tiling. Now, we should have appropriate checks
in place, as well as some support for lane broadcasts of tiled values.
PiperOrigin-RevId: 570956025
2023-10-05 02:59:30 -07:00
Adam Paszke
d8a81ba45a
[Mosaic] Handle a larger class of broadcasts with 1-sized trailing dimensions
...
PiperOrigin-RevId: 570947498
2023-10-05 02:18:49 -07:00
Adam Paszke
2e3a5d6c96
[Mosaic] Fix a bug when elementwise op can incorrectly propagate a replicated layout
...
PiperOrigin-RevId: 570942842
2023-10-05 01:54:27 -07:00
Tomás Longeri
dd2fcf5166
[Mosaic] apply_vector_layout C++ rewrite (16): vector.multi_reduction
...
PiperOrigin-RevId: 570907928
2023-10-04 22:34:26 -07:00
jax authors
201001bf9a
Merge pull request #17921 from gnecula:harness_random
...
PiperOrigin-RevId: 570900835
2023-10-04 21:50:38 -07:00
Tomás Longeri
07e407703d
[Mosaic] apply_vector_layout C++ rewrite (15): vector.shape_cast
...
PiperOrigin-RevId: 570891989
2023-10-04 20:47:12 -07:00
Jevin Jiang
98aae41f1e
[XLA:Mosaic] Fix layout attribute string parse in python.
...
PiperOrigin-RevId: 570883549
2023-10-04 20:09:50 -07:00
jax authors
1c37f5091c
sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
...
PiperOrigin-RevId: 570882867
2023-10-04 19:59:03 -07:00
Chris Jones
465eb21561
[pallas] Fix allow_tf32
value in Triton dot_general
lowering.
...
`precision` is canonicalized as a tuple or `None`.
PiperOrigin-RevId: 570879987
2023-10-04 19:34:19 -07:00
Tomás Longeri
862d676c7a
[Mosaic] apply_vector_layout C++ rewrite (14): tpu.matmul and vector.contract
...
PiperOrigin-RevId: 570877941
2023-10-04 19:20:00 -07:00
Tomás Longeri
991e6ef719
[Mosaic] apply_vector_layout C++ rewrite (13): scf.if, scf.yield
...
PiperOrigin-RevId: 570845376
2023-10-04 16:30:53 -07:00
jax authors
9e3d64a16b
Merge pull request #17929 from hawkinsp:torchloader
...
PiperOrigin-RevId: 570839721
2023-10-04 16:11:24 -07:00
jax authors
60654644d6
Merge pull request #17938 from jakevdp:bitwise-count
...
PiperOrigin-RevId: 570837601
2023-10-04 16:00:49 -07:00
jax authors
7be8df9b47
Merge pull request #17933 from cgarciae:condition-numpy-version
...
PiperOrigin-RevId: 570827957
2023-10-04 15:27:44 -07:00
Jake VanderPlas
7df29577e6
jnp.bitwise_count: call into lax.population_count
2023-10-04 15:22:33 -07:00
Cristian Garcia
7498ffe843
condition numpy version based on python version
2023-10-04 21:06:01 +00:00
jax authors
be7f210a1b
Allow duration event listeners to take extra keyword arguments.
...
PiperOrigin-RevId: 570783105
2023-10-04 13:15:49 -07:00
jax authors
305efe0501
random_test: reduce num_generated_cases to avoid timeouts
...
PiperOrigin-RevId: 570781641
2023-10-04 13:04:44 -07:00
Tomás Longeri
933d3530d8
[Mosaic] apply_vector_layout C++ rewrite: generalization for relayouts that reduce sublanes
...
Corresponds to cl/563570338.
PiperOrigin-RevId: 570769785
2023-10-04 12:26:53 -07:00
jax authors
389334d358
Merge pull request #17925 from jakevdp:rand-loggamma
...
PiperOrigin-RevId: 570765482
2023-10-04 12:12:40 -07:00
Peter Hawkins
d8a0227e86
Simplify the torch data loader collate function using tree_map.
...
Fixes https://github.com/google/jax/issues/1004
2023-10-04 14:59:06 -04:00
Jake VanderPlas
f739a888f3
jax.random: fix NaN corner-case in loggamma
2023-10-04 11:40:32 -07:00
Tomás Longeri
59d4f4462a
[Mosaic] apply_vector_layout: arith.constant: erase old op after replacing uses (for non-splat values)
...
PiperOrigin-RevId: 570755366
2023-10-04 11:38:02 -07:00
Skye Wanderman-Milne
01372fedca
Clarify that mesh_utils.create_device_mesh's contiguous_submeshes arg isn't necessary with jax.Array
...
PiperOrigin-RevId: 570751299
2023-10-04 11:24:23 -07:00
Sergei Lebedev
923498fb45
_StateContextManager
now preserves the type of the value it stores.
...
This change is a follow-on to google/jax#16866 , which added an ABSL-like API
for flags defined with `DEFINE_...`. Here we add a similar typed API for flags
defined with `define_..._state`.
See 37dad4d356/absl/flags/_flagvalues.py (L1333)
.
PiperOrigin-RevId: 570721827
2023-10-04 09:49:19 -07:00
Tomás Longeri
2fe00f88a0
[Mosaic] Run verifier after infer_memref_layout
...
PiperOrigin-RevId: 570720278
2023-10-04 09:38:54 -07:00
Adam Paszke
3cf822d0de
[Mosaic] Allow arbitrary LHS tensor axis size for non-packed matmuls
...
PiperOrigin-RevId: 570717660
2023-10-04 09:28:10 -07:00
George Necula
c63880ba69
[export] Improve primitive harnesses to use jax.random.key
...
Many tests involving randomness in multi_platform_export_test were
failing because the primitive harnesses uses raw uint32 arrays as
keys. Change them to use jax.random.keys.
2023-10-04 16:26:11 +03:00
George Necula
efa987c019
[export] Add tests for multi-platform and cross-platform export
...
Test for each JAX primitive harness that we can lower it for
multiple platforms and then execute it on multiple platforms
with the same results as the JAX native execution.
This is a large test, covering between 5000-7000 harnesses,
depending on the platform.
Hundreds of harnesses fail this test. In future work we will
address each of the failing harnesses in turn.
PiperOrigin-RevId: 570661115
2023-10-04 05:08:36 -07:00
Sharad Vikram
52500d9711
[Pallas] Add pretty printing for dma_start
...
DMAs now print as `dma_start a[...] -> b[...] c`
PiperOrigin-RevId: 570587172
2023-10-03 22:10:58 -07:00
jax authors
16a2283a97
Merge pull request #17911 from hawkinsp:py312
...
PiperOrigin-RevId: 570558531
2023-10-03 19:01:14 -07:00
Peter Hawkins
58cb0d9ede
Use the released version of Python 3.12 in the Windows wheel builds.
2023-10-03 21:51:37 -04:00
jax authors
6c0924920e
Merge pull request #17910 from hawkinsp:rocm
...
PiperOrigin-RevId: 570555612
2023-10-03 18:45:12 -07:00
Peter Hawkins
efc18e4147
[JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
...
This shrinks the CUDA jaxlib wheel size by around 80MB.
PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Peter Hawkins
578478b24a
Remove references to the ROCm TensorFlow repository in AMD build instructions.
2023-10-03 21:33:52 -04:00
jax authors
816ebf2fae
Merge pull request #17909 from skye:version
...
PiperOrigin-RevId: 570551640
2023-10-03 18:15:17 -07:00
Skye Wanderman-Milne
82b58386b7
Update versions and CHANGELOG after jax 0.4.17 release
2023-10-03 17:54:35 -07:00
jax authors
f319a2b5d5
Merge pull request #17908 from ybaturina:update_test_instructions
...
PiperOrigin-RevId: 570546044
2023-10-03 17:53:35 -07:00
jax authors
2a46654065
Merge pull request #17906 from google:tpu_ci_remove_workaround
...
PiperOrigin-RevId: 570545829
2023-10-03 17:43:33 -07:00
jax authors
6bc184ee6c
Merge pull request #17907 from jakevdp:typo
...
PiperOrigin-RevId: 570502792
2023-10-03 14:46:55 -07:00
jax authors
6cf4ce154d
Merge pull request #17885 from jakevdp:bitwise-count
...
PiperOrigin-RevId: 570491042
2023-10-03 14:09:08 -07:00
Jake VanderPlas
3fd204ca0a
fix typo in deprecation message
2023-10-03 14:04:09 -07:00
Sharad Vikram
a142c59713
[Pallas] Enable integer indexing for memrefs in .at
...
PiperOrigin-RevId: 570488769
2023-10-03 13:58:59 -07:00
Jake VanderPlas
a09fdf6e2f
Add jax.numpy.bitwise_count()
2023-10-03 13:48:16 -07:00