15738 Commits

Author SHA1 Message Date
Pearu Peterson
50670bd907 Fix log10 and log2 for large inputs. 2025-01-01 12:45:39 +02:00
Yunlong Liu
97b1faacdd Fixes the random key sharding in shard_map. 2024-12-29 18:43:21 +00:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Sergei Lebedev
76ccb199fd [pallas:mosaic_gpu] Added some runtime type checking to copy_* and barrier_* primitives
PiperOrigin-RevId: 710302436
2024-12-28 09:02:43 -08:00
jax authors
b6aead6f3a [AutoPGLE] Explicitly disable command buffers when profiler is used.
PiperOrigin-RevId: 709475833
2024-12-24 21:31:05 -08:00
Sergei Lebedev
44333e1cfb [pallas:mosaic_gpu] Addressed a todo in broadcasted_iota lowering
PiperOrigin-RevId: 709310152
2024-12-24 04:32:29 -08:00
jax authors
4eff1316db Merge pull request #25672 from jakevdp:finalize-dep
PiperOrigin-RevId: 709284584
2024-12-24 02:08:51 -08:00
jax authors
c57b49c606 Merge pull request #25669 from jakevdp:undep
PiperOrigin-RevId: 709104784
2024-12-23 11:06:24 -08:00
Jake VanderPlas
40fe4b8797 Finalize deprecation of some symbols from jax.lib.xla_client 2024-12-23 10:14:16 -08:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
jax authors
6c85e54aad Merge pull request #25662 from liblaf:main
PiperOrigin-RevId: 709058676
2024-12-23 07:38:25 -08:00
jax authors
704185ea25 Merge pull request #24607 from kaixih:support_head_size_256
PiperOrigin-RevId: 709058104
2024-12-23 07:36:21 -08:00
Jake VanderPlas
cb10710c92 Remove casting from jax.nn.one_hot
This change was made after the most recent release, so is safe
to remove. Casting float to int potentially changes intentional
beavior: e.g. NaN casts to 0. Some downstream users currently
use NaN to mark rows which should have no one-hot entry.
2024-12-23 07:33:49 -08:00
Chris Jones
83e60a9697 [pallas:triton] Add support for lowering int4 load.
PiperOrigin-RevId: 709032308
2024-12-23 05:12:46 -08:00
Sergei Lebedev
a51d627941 [pallas:mosaic_gpu] Reduced duplication between _ensure_fa and _ensure_ir_value
PiperOrigin-RevId: 709030824
2024-12-23 05:04:06 -08:00
Sergei Lebedev
3e7f48114c [pallas:mosaic_gpu] Updated the lowering following the changes in in Mosaic GPU internals
PiperOrigin-RevId: 709009048
2024-12-23 03:14:26 -08:00
liblaf
75b56548e2
Fix a typo in documentation for pinv function. 2024-12-23 17:20:33 +08:00
jax authors
1719986aaa [Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.

PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
jax authors
1c0dee8012 Merge pull request #25650 from jakevdp:view-int4
PiperOrigin-RevId: 708468858
2024-12-20 17:31:31 -08:00
John QiangZhang
e560c6a45c Change the namespace name to avoid using export c++ keyword on namespace.
PiperOrigin-RevId: 708450293
2024-12-20 16:02:15 -08:00
Jake VanderPlas
75f36dc3ea Support int4/uint4 in jnp.ndarray.view 2024-12-20 13:57:40 -08:00
jax authors
44d67e1379 Merge pull request #25648 from hawkinsp:warnings3
PiperOrigin-RevId: 708415848
2024-12-20 13:41:43 -08:00
Jake VanderPlas
beee98ab4a Add int4/uint4 support to bitcast_convert_type 2024-12-20 12:45:24 -08:00
Peter Hawkins
59e5ce22d3 Avoid calls to warnings.catch_warnings in JAX core code.
warnings.catch_warnings is not thread-safe. However it is always used to avoid complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Dimitar (Mitko) Asenov
dad23fed09 [Mosaic GPU] Add a lowering for simple async_load and async_store ops.
Only untransformed and unsliced loads/stores are supported for now. The rest will be a follow up.

PiperOrigin-RevId: 708347442
2024-12-20 09:38:13 -08:00
jax authors
01e8f889c2 Merge pull request #25616 from jakevdp:abstractify
PiperOrigin-RevId: 708340432
2024-12-20 09:09:32 -08:00
Oleg Shyshkov
db464b3f0a Clarify documentation for output_offsets operand of ragged_all_to_all.
PiperOrigin-RevId: 708321802
2024-12-20 07:52:11 -08:00
Adam Paszke
d2f937e241 Make jax.Arrays a necessary part of the cycle in the GC guard test
Otherwise, the cycle can be broken by clearing the references of the helper
objects, at which points the deallocation of arrays proceeds through regular
reference counting (and does not trigger logs!). I have not verified that
this is what happens, but the test has been mysteriously failing under a
number of configurations and this seems to fix it.

I added a note to the garbage collection guard to clarify that it's not
guaranteed to report all cycles.

PiperOrigin-RevId: 708320953
2024-12-20 07:48:04 -08:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Christos Perivolaropoulos
20efbd965f [pallas:mosaic_gpu] Change the fori tests to also take the while_p path and fix the bug.
The bug was that bounds were dropped ctx.avals_in and then they were being
extracted. Extract them before dropping them.

PiperOrigin-RevId: 708266659
2024-12-20 03:50:34 -08:00
jax authors
5031b6f599 Merge pull request #25625 from mattjj:ref-errors-5
PiperOrigin-RevId: 708196762
2024-12-19 23:25:53 -08:00
Jevin Jiang
2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use relayout-insertion pass to insert necessary ops and their layouts for relayout before unrolling in apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Matthew Johnson
b6482f126e add mutable array ref error checks to cond and custom_vjp 2024-12-20 01:44:50 +00:00
jax authors
64d3871e55 Merge pull request #25489 from emilyfertig:distributed-init-input-validation
PiperOrigin-RevId: 708087898
2024-12-19 17:03:26 -08:00
Jake VanderPlas
482a6e7394 Delete unused internal symbols
Followup to https://github.com/jax-ml/jax/pull/25614.

PiperOrigin-RevId: 708077981
2024-12-19 16:32:01 -08:00
Emily Fertig
a24e70320b Add more input validation to jax.distributed.initialize. 2024-12-19 16:26:34 -08:00
Justin Fu
d129438548 [Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU.
PiperOrigin-RevId: 708010809
2024-12-19 13:28:58 -08:00
jax authors
3e2f2aabae Merge pull request #25614 from jakevdp:dep-shaped-abstractify
PiperOrigin-RevId: 707973428
2024-12-19 11:24:51 -08:00
kaixih
307ea87a8d support head size of 256
Test large head size only on hopper+ gpus

Test large head size only on cudnn 9.5+
2024-12-19 18:38:06 +00:00
jax authors
dc8b786d12 Merge pull request #25606 from dfm:fft-nd
PiperOrigin-RevId: 707935849
2024-12-19 09:21:47 -08:00
jax authors
8402a9881e Merge pull request #25590 from jakevdp:fix-one-hot-float
PiperOrigin-RevId: 707922981
2024-12-19 08:31:49 -08:00
Adam Paszke
23000a3842 Always suppress the differing_executors Hypothesis health check
It's only relevant to notify about potential key collisions in the example
database, but we explicitly disable it, so it doesn't matter.

PiperOrigin-RevId: 707914664
2024-12-19 08:00:53 -08:00
Benjamin Chetioui
3915f4a147 [Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Dan Foreman-Mackey
c6131ee527 Add support for N-D FFTs with D>3. 2024-12-19 15:23:30 +00:00
Sergei Lebedev
af7a31f196 [pallas:triton] Fixed a typo in a type annotation
PiperOrigin-RevId: 707905188
2024-12-19 07:20:54 -08:00
Jake VanderPlas
8c3c441ee4 jax.nn.one_hot: deprecate non-integer inputs 2024-12-19 07:11:31 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
jax authors
7680532512 Merge pull request #25595 from jakevdp:mv-shaped-abstractify
PiperOrigin-RevId: 707888615
2024-12-19 06:07:14 -08:00
Adam Paszke
ad00ec1dc9 [Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Adam Paszke
006c65d8d4 [Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA
PiperOrigin-RevId: 707860467
2024-12-19 04:03:24 -08:00