Pearu Peterson
50670bd907
Fix log10 and log2 for large inputs.
2025-01-01 12:45:39 +02:00
Yunlong Liu
97b1faacdd
Fixes the random key sharding in shard_map.
2024-12-29 18:43:21 +00:00
Dan Foreman-Mackey
cb4d97aa1f
Move jex.ffi to jax.ffi.
2024-12-29 13:06:19 +00:00
Sergei Lebedev
76ccb199fd
[pallas:mosaic_gpu] Added some runtime type checking to copy_*
and barrier_*
primitives
...
PiperOrigin-RevId: 710302436
2024-12-28 09:02:43 -08:00
jax authors
b6aead6f3a
[AutoPGLE] Explicitly disable command buffers when profiler is used.
...
PiperOrigin-RevId: 709475833
2024-12-24 21:31:05 -08:00
Sergei Lebedev
44333e1cfb
[pallas:mosaic_gpu] Addressed a todo in broadcasted_iota
lowering
...
PiperOrigin-RevId: 709310152
2024-12-24 04:32:29 -08:00
jax authors
4eff1316db
Merge pull request #25672 from jakevdp:finalize-dep
...
PiperOrigin-RevId: 709284584
2024-12-24 02:08:51 -08:00
jax authors
c57b49c606
Merge pull request #25669 from jakevdp:undep
...
PiperOrigin-RevId: 709104784
2024-12-23 11:06:24 -08:00
Jake VanderPlas
40fe4b8797
Finalize deprecation of some symbols from jax.lib.xla_client
2024-12-23 10:14:16 -08:00
Jake VanderPlas
ccc3a29537
Internal: use a single registry for abstractify APIs
2024-12-23 08:44:35 -08:00
jax authors
6c85e54aad
Merge pull request #25662 from liblaf:main
...
PiperOrigin-RevId: 709058676
2024-12-23 07:38:25 -08:00
jax authors
704185ea25
Merge pull request #24607 from kaixih:support_head_size_256
...
PiperOrigin-RevId: 709058104
2024-12-23 07:36:21 -08:00
Jake VanderPlas
cb10710c92
Remove casting from jax.nn.one_hot
...
This change was made after the most recent release, so is safe
to remove. Casting float to int potentially changes intentional
beavior: e.g. NaN casts to 0. Some downstream users currently
use NaN to mark rows which should have no one-hot entry.
2024-12-23 07:33:49 -08:00
Chris Jones
83e60a9697
[pallas:triton] Add support for lowering int4
load.
...
PiperOrigin-RevId: 709032308
2024-12-23 05:12:46 -08:00
Sergei Lebedev
a51d627941
[pallas:mosaic_gpu] Reduced duplication between _ensure_fa
and _ensure_ir_value
...
PiperOrigin-RevId: 709030824
2024-12-23 05:04:06 -08:00
Sergei Lebedev
3e7f48114c
[pallas:mosaic_gpu] Updated the lowering following the changes in in Mosaic GPU internals
...
PiperOrigin-RevId: 709009048
2024-12-23 03:14:26 -08:00
liblaf
75b56548e2
Fix a typo in documentation for pinv
function.
2024-12-23 17:20:33 +08:00
jax authors
1719986aaa
[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
...
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.
This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.
And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.
PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
jax authors
1c0dee8012
Merge pull request #25650 from jakevdp:view-int4
...
PiperOrigin-RevId: 708468858
2024-12-20 17:31:31 -08:00
John QiangZhang
e560c6a45c
Change the namespace name to avoid using export
c++ keyword on namespace.
...
PiperOrigin-RevId: 708450293
2024-12-20 16:02:15 -08:00
Jake VanderPlas
75f36dc3ea
Support int4/uint4 in jnp.ndarray.view
2024-12-20 13:57:40 -08:00
jax authors
44d67e1379
Merge pull request #25648 from hawkinsp:warnings3
...
PiperOrigin-RevId: 708415848
2024-12-20 13:41:43 -08:00
Jake VanderPlas
beee98ab4a
Add int4/uint4 support to bitcast_convert_type
2024-12-20 12:45:24 -08:00
Peter Hawkins
59e5ce22d3
Avoid calls to warnings.catch_warnings in JAX core code.
...
warnings.catch_warnings is not thread-safe. However it is always used to avoid complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Dimitar (Mitko) Asenov
dad23fed09
[Mosaic GPU] Add a lowering for simple async_load
and async_store
ops.
...
Only untransformed and unsliced loads/stores are supported for now. The rest will be a follow up.
PiperOrigin-RevId: 708347442
2024-12-20 09:38:13 -08:00
jax authors
01e8f889c2
Merge pull request #25616 from jakevdp:abstractify
...
PiperOrigin-RevId: 708340432
2024-12-20 09:09:32 -08:00
Oleg Shyshkov
db464b3f0a
Clarify documentation for output_offsets operand of ragged_all_to_all.
...
PiperOrigin-RevId: 708321802
2024-12-20 07:52:11 -08:00
Adam Paszke
d2f937e241
Make jax.Arrays a necessary part of the cycle in the GC guard test
...
Otherwise, the cycle can be broken by clearing the references of the helper
objects, at which points the deallocation of arrays proceeds through regular
reference counting (and does not trigger logs!). I have not verified that
this is what happens, but the test has been mysteriously failing under a
number of configurations and this seems to fix it.
I added a note to the garbage collection guard to clarify that it's not
guaranteed to report all cycles.
PiperOrigin-RevId: 708320953
2024-12-20 07:48:04 -08:00
Jake VanderPlas
c560f8e06c
Unify abstractify & shaped_abstractify rules
2024-12-20 04:28:19 -08:00
Christos Perivolaropoulos
20efbd965f
[pallas:mosaic_gpu] Change the fori tests to also take the while_p path and fix the bug.
...
The bug was that bounds were dropped ctx.avals_in and then they were being
extracted. Extract them before dropping them.
PiperOrigin-RevId: 708266659
2024-12-20 03:50:34 -08:00
jax authors
5031b6f599
Merge pull request #25625 from mattjj:ref-errors-5
...
PiperOrigin-RevId: 708196762
2024-12-19 23:25:53 -08:00
Jevin Jiang
2faf540203
[Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
...
We can use relayout-insertion pass to insert necessary ops and their layouts for relayout before unrolling in apply-vector-layout pass.
PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Matthew Johnson
b6482f126e
add mutable array ref error checks to cond and custom_vjp
2024-12-20 01:44:50 +00:00
jax authors
64d3871e55
Merge pull request #25489 from emilyfertig:distributed-init-input-validation
...
PiperOrigin-RevId: 708087898
2024-12-19 17:03:26 -08:00
Jake VanderPlas
482a6e7394
Delete unused internal symbols
...
Followup to https://github.com/jax-ml/jax/pull/25614 .
PiperOrigin-RevId: 708077981
2024-12-19 16:32:01 -08:00
Emily Fertig
a24e70320b
Add more input validation to jax.distributed.initialize.
2024-12-19 16:26:34 -08:00
Justin Fu
d129438548
[Mosaic GPU] Prototype of a warp-specialized pipeline emitter for Mosaic GPU.
...
PiperOrigin-RevId: 708010809
2024-12-19 13:28:58 -08:00
jax authors
3e2f2aabae
Merge pull request #25614 from jakevdp:dep-shaped-abstractify
...
PiperOrigin-RevId: 707973428
2024-12-19 11:24:51 -08:00
kaixih
307ea87a8d
support head size of 256
...
Test large head size only on hopper+ gpus
Test large head size only on cudnn 9.5+
2024-12-19 18:38:06 +00:00
jax authors
dc8b786d12
Merge pull request #25606 from dfm:fft-nd
...
PiperOrigin-RevId: 707935849
2024-12-19 09:21:47 -08:00
jax authors
8402a9881e
Merge pull request #25590 from jakevdp:fix-one-hot-float
...
PiperOrigin-RevId: 707922981
2024-12-19 08:31:49 -08:00
Adam Paszke
23000a3842
Always suppress the differing_executors Hypothesis health check
...
It's only relevant to notify about potential key collisions in the example
database, but we explicitly disable it, so it doesn't matter.
PiperOrigin-RevId: 707914664
2024-12-19 08:00:53 -08:00
Benjamin Chetioui
3915f4a147
[Mosaic GPU] Commit to using Vector
s everywhere (and no Tensor
s).
...
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Dan Foreman-Mackey
c6131ee527
Add support for N-D FFTs with D>3.
2024-12-19 15:23:30 +00:00
Sergei Lebedev
af7a31f196
[pallas:triton] Fixed a typo in a type annotation
...
PiperOrigin-RevId: 707905188
2024-12-19 07:20:54 -08:00
Jake VanderPlas
8c3c441ee4
jax.nn.one_hot: deprecate non-integer inputs
2024-12-19 07:11:31 -08:00
Jake VanderPlas
5dc37d3f70
Remove internal uses of api_util.shaped_abstractify
2024-12-19 07:06:36 -08:00
jax authors
7680532512
Merge pull request #25595 from jakevdp:mv-shaped-abstractify
...
PiperOrigin-RevId: 707888615
2024-12-19 06:07:14 -08:00
Adam Paszke
ad00ec1dc9
[Mosaic TPU] Guard tests for new features by the libtpu version
...
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Adam Paszke
006c65d8d4
[Mosaic GPU] Add a new tiled layout, optimized for upcasting before WGMMA
...
PiperOrigin-RevId: 707860467
2024-12-19 04:03:24 -08:00