17727 Commits

Author SHA1 Message Date
jax authors
a2b70e3346 Bump shard_count for shard_map_test to fix timeouts.
PiperOrigin-RevId: 571109311
2023-10-05 13:18:10 -07:00
Jake VanderPlas
4fd6bf4e57 Re-add JIT to jax.numpy.bitwise_count
PiperOrigin-RevId: 571107971
2023-10-05 13:07:49 -07:00
jax authors
797f577fb8 Expose mlir.ShapePolyLoweringState
PiperOrigin-RevId: 571075542
2023-10-05 11:21:09 -07:00
Tomás Longeri
ab4a8e3417 [Mosaic] apply_vector_layout C++ rewrite: various bug fixes
PiperOrigin-RevId: 571075082
2023-10-05 11:10:49 -07:00
Peter Hawkins
295cecd505 Update XLA dependency to use revision
e73af4223d.

PiperOrigin-RevId: 571067534
2023-10-05 10:54:50 -07:00
Tomás Longeri
68c84a6c5c [Mosaic] apply_vector_layout: Use shape in generalizes check (in Python)
- The addition to the check in the relayout loop in `apply_layout_op` should result in skipping some no-op relayouts
- The assert in `disassemble` also needs to be updated because it won't hold now that relayout is skipped more (relayout guarantees the defining layout to be equal to the input layout)

PiperOrigin-RevId: 571066259
2023-10-05 10:44:22 -07:00
Adam Paszke
633f68a398 [Mosaic] Fix a buggy vector.broadcast rule in apply_vector_layout
The rule did not take tiling into account, assuming that it works with
32-bit data that has native tiling. Now, we should have appropriate checks
in place, as well as some support for lane broadcasts of tiled values.

PiperOrigin-RevId: 570956025
2023-10-05 02:59:30 -07:00
Adam Paszke
d8a81ba45a [Mosaic] Handle a larger class of broadcasts with 1-sized trailing dimensions
PiperOrigin-RevId: 570947498
2023-10-05 02:18:49 -07:00
Adam Paszke
2e3a5d6c96 [Mosaic] Fix a bug when elementwise op can incorrectly propagate a replicated layout
PiperOrigin-RevId: 570942842
2023-10-05 01:54:27 -07:00
Tomás Longeri
dd2fcf5166 [Mosaic] apply_vector_layout C++ rewrite (16): vector.multi_reduction
PiperOrigin-RevId: 570907928
2023-10-04 22:34:26 -07:00
jax authors
201001bf9a Merge pull request #17921 from gnecula:harness_random
PiperOrigin-RevId: 570900835
2023-10-04 21:50:38 -07:00
Tomás Longeri
07e407703d [Mosaic] apply_vector_layout C++ rewrite (15): vector.shape_cast
PiperOrigin-RevId: 570891989
2023-10-04 20:47:12 -07:00
Jevin Jiang
98aae41f1e [XLA:Mosaic] Fix layout attribute string parse in python.
PiperOrigin-RevId: 570883549
2023-10-04 20:09:50 -07:00
jax authors
1c37f5091c sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
2023-10-04 19:59:03 -07:00
Chris Jones
465eb21561 [pallas] Fix allow_tf32 value in Triton dot_general lowering.
`precision` is canonicalized as a tuple or `None`.

PiperOrigin-RevId: 570879987
2023-10-04 19:34:19 -07:00
Tomás Longeri
862d676c7a [Mosaic] apply_vector_layout C++ rewrite (14): tpu.matmul and vector.contract
PiperOrigin-RevId: 570877941
2023-10-04 19:20:00 -07:00
Tomás Longeri
991e6ef719 [Mosaic] apply_vector_layout C++ rewrite (13): scf.if, scf.yield
PiperOrigin-RevId: 570845376
2023-10-04 16:30:53 -07:00
jax authors
9e3d64a16b Merge pull request #17929 from hawkinsp:torchloader
PiperOrigin-RevId: 570839721
2023-10-04 16:11:24 -07:00
jax authors
60654644d6 Merge pull request #17938 from jakevdp:bitwise-count
PiperOrigin-RevId: 570837601
2023-10-04 16:00:49 -07:00
jax authors
7be8df9b47 Merge pull request #17933 from cgarciae:condition-numpy-version
PiperOrigin-RevId: 570827957
2023-10-04 15:27:44 -07:00
Jake VanderPlas
7df29577e6 jnp.bitwise_count: call into lax.population_count 2023-10-04 15:22:33 -07:00
Cristian Garcia
7498ffe843 condition numpy version based on python version 2023-10-04 21:06:01 +00:00
jax authors
be7f210a1b Allow duration event listeners to take extra keyword arguments.
PiperOrigin-RevId: 570783105
2023-10-04 13:15:49 -07:00
jax authors
305efe0501 random_test: reduce num_generated_cases to avoid timeouts
PiperOrigin-RevId: 570781641
2023-10-04 13:04:44 -07:00
Tomás Longeri
933d3530d8 [Mosaic] apply_vector_layout C++ rewrite: generalization for relayouts that reduce sublanes
Corresponds to cl/563570338.

PiperOrigin-RevId: 570769785
2023-10-04 12:26:53 -07:00
jax authors
389334d358 Merge pull request #17925 from jakevdp:rand-loggamma
PiperOrigin-RevId: 570765482
2023-10-04 12:12:40 -07:00
Peter Hawkins
d8a0227e86 Simplify the torch data loader collate function using tree_map.
Fixes https://github.com/google/jax/issues/1004
2023-10-04 14:59:06 -04:00
Jake VanderPlas
f739a888f3 jax.random: fix NaN corner-case in loggamma 2023-10-04 11:40:32 -07:00
Tomás Longeri
59d4f4462a [Mosaic] apply_vector_layout: arith.constant: erase old op after replacing uses (for non-splat values)
PiperOrigin-RevId: 570755366
2023-10-04 11:38:02 -07:00
Skye Wanderman-Milne
01372fedca Clarify that mesh_utils.create_device_mesh's contiguous_submeshes arg isn't necessary with jax.Array
PiperOrigin-RevId: 570751299
2023-10-04 11:24:23 -07:00
Sergei Lebedev
923498fb45 _StateContextManager now preserves the type of the value it stores.
This change is a follow-on to google/jax#16866, which added an ABSL-like API
for flags defined with `DEFINE_...`. Here we add a similar typed API for flags
defined with `define_..._state`.

See 37dad4d356/absl/flags/_flagvalues.py (L1333).

PiperOrigin-RevId: 570721827
2023-10-04 09:49:19 -07:00
Tomás Longeri
2fe00f88a0 [Mosaic] Run verifier after infer_memref_layout
PiperOrigin-RevId: 570720278
2023-10-04 09:38:54 -07:00
Adam Paszke
3cf822d0de [Mosaic] Allow arbitrary LHS tensor axis size for non-packed matmuls
PiperOrigin-RevId: 570717660
2023-10-04 09:28:10 -07:00
George Necula
c63880ba69 [export] Improve primitive harnesses to use jax.random.key
Many tests involving randomness in multi_platform_export_test were
failing because the primitive harnesses uses raw uint32 arrays as
keys. Change them to use jax.random.keys.
2023-10-04 16:26:11 +03:00
George Necula
efa987c019 [export] Add tests for multi-platform and cross-platform export
Test for each JAX primitive harness that we can lower it for
multiple platforms and then execute it on multiple platforms
with the same results as the JAX native execution.

This is a large test, covering between 5000-7000 harnesses,
depending on the platform.

Hundreds of harnesses fail this test. In future work we will
address each of the failing harnesses in turn.

PiperOrigin-RevId: 570661115
2023-10-04 05:08:36 -07:00
Sharad Vikram
52500d9711 [Pallas] Add pretty printing for dma_start
DMAs now print as `dma_start a[...] -> b[...] c`

PiperOrigin-RevId: 570587172
2023-10-03 22:10:58 -07:00
jax authors
16a2283a97 Merge pull request #17911 from hawkinsp:py312
PiperOrigin-RevId: 570558531
2023-10-03 19:01:14 -07:00
Peter Hawkins
58cb0d9ede Use the released version of Python 3.12 in the Windows wheel builds. 2023-10-03 21:51:37 -04:00
jax authors
6c0924920e Merge pull request #17910 from hawkinsp:rocm
PiperOrigin-RevId: 570555612
2023-10-03 18:45:12 -07:00
Peter Hawkins
efc18e4147 [JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
This shrinks the CUDA jaxlib wheel size by around 80MB.

PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Peter Hawkins
578478b24a Remove references to the ROCm TensorFlow repository in AMD build instructions. 2023-10-03 21:33:52 -04:00
jax authors
816ebf2fae Merge pull request #17909 from skye:version
PiperOrigin-RevId: 570551640
2023-10-03 18:15:17 -07:00
Skye Wanderman-Milne
82b58386b7 Update versions and CHANGELOG after jax 0.4.17 release 2023-10-03 17:54:35 -07:00
jax authors
f319a2b5d5 Merge pull request #17908 from ybaturina:update_test_instructions
PiperOrigin-RevId: 570546044
2023-10-03 17:53:35 -07:00
jax authors
2a46654065 Merge pull request #17906 from google:tpu_ci_remove_workaround
PiperOrigin-RevId: 570545829
2023-10-03 17:43:33 -07:00
jax authors
6bc184ee6c Merge pull request #17907 from jakevdp:typo
PiperOrigin-RevId: 570502792
2023-10-03 14:46:55 -07:00
jax authors
6cf4ce154d Merge pull request #17885 from jakevdp:bitwise-count
PiperOrigin-RevId: 570491042
2023-10-03 14:09:08 -07:00
Jake VanderPlas
3fd204ca0a fix typo in deprecation message 2023-10-03 14:04:09 -07:00
Sharad Vikram
a142c59713 [Pallas] Enable integer indexing for memrefs in .at
PiperOrigin-RevId: 570488769
2023-10-03 13:58:59 -07:00
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00