jax authors
3de7ecf6da
Merge pull request #27092 from pearu:pearu/gammainc-bug-fix
...
PiperOrigin-RevId: 736177398
2025-03-12 10:20:39 -07:00
jax authors
e7d10a2310
Merge pull request #27041 from carlosgmartin:fix_binomial_value_error
...
PiperOrigin-RevId: 736171463
2025-03-12 10:05:18 -07:00
Pearu Peterson
f608a8c502
Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan.
2025-03-12 17:48:45 +02:00
Yash Katariya
abcc7fdf4c
[sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName]
to ShapedArray. Also add jax_varying_axes_in_types
config to hide this option under while we develop it.
...
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Sergei Lebedev
e33f3fc48b
[pallas:mosaic_gpu] Added support for reductions to the WG lowering
...
Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
the dialect lowering AFAICT.
PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Matthew Johnson
66a6eb299e
add autodiff rules for jax.lax.ragged_all_to_all collective
...
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.
PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68
Rename get_ty
to typeof
which is an alias of get_aval
...
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Sharad Vikram
c6b164dc09
[Pallas/Fuser] Add custom evaluate to allow/disallow transposes
...
PiperOrigin-RevId: 735931978
2025-03-11 16:35:49 -07:00
Yash Katariya
f45cbf3342
Fix a bug where full
and use_mesh
outside jit did not work because the shard
passed to make_array_from_callback
was sharded on all devices instead of just 1 device.
...
This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context.
PiperOrigin-RevId: 735909962
2025-03-11 15:25:46 -07:00
Jevin Jiang
29bfd00f9c
[Pallas TPU] Fix preferred_element_type propagation in dot_general with const
...
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
shuw
f9aef8a189
Support nvfp4
2025-03-11 19:33:25 +00:00
Pearu Peterson
82b2591b21
Fix scipy.special.gammainc/gammaincc evaluation at boundary points
2025-03-11 21:18:47 +02:00
jax authors
c2c68c018f
Merge pull request #27059 from jakevdp:fix-while-loop
...
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Gunhyun Park
d191927b24
Fix syntax error and typos for composite primitive docstring.
...
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
Jake VanderPlas
4ae3211ea2
jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version
2025-03-11 09:53:34 -07:00
Yash Katariya
76dec38286
Under pjit the with mesh:
context will use use_mesh(mesh): jit
instead of tracking separately using resource_env
.
...
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.
PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
jax authors
02505fa757
[Pallas TPU] Remove next_slot
SMEM tensor from pipeline emitter
...
PiperOrigin-RevId: 735564365
2025-03-10 17:19:39 -07:00
Ayaka
988a1208a9
Better error message when raise_if_error()
is called within a traced context
...
PiperOrigin-RevId: 735557928
2025-03-10 16:55:06 -07:00
jax authors
aceae84fab
[Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
...
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Sharad Vikram
81dde225b0
[Pallas/Fuser] Add select_n push rule
...
PiperOrigin-RevId: 735510713
2025-03-10 14:23:01 -07:00
jax authors
261e6e5fdc
Merge pull request #27038 from jakevdp:vmap-sentinel
...
PiperOrigin-RevId: 735510065
2025-03-10 14:21:11 -07:00
jax authors
c942b0fef0
Merge pull request #26977 from jakevdp:fix-expn
...
PiperOrigin-RevId: 735506133
2025-03-10 14:09:32 -07:00
Sharad Vikram
87272fbe93
[Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
...
PiperOrigin-RevId: 735505460
2025-03-10 14:07:26 -07:00
carlosgmartin
8b6ca56417
Fix the ValueError message for random.binomial (forgot to use string formatting).
2025-03-10 16:38:03 -04:00
jax authors
affe2e734e
Rename dot_with_no_batch_dims_saveable
to dots_with_no_batch_dims_saveable
for internal consistency
...
PiperOrigin-RevId: 735484326
2025-03-10 13:04:49 -07:00
Praveen Narayanan
b6d4fe5387
Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
...
PiperOrigin-RevId: 735471245
2025-03-10 12:25:22 -07:00
jax authors
18f2f19c1a
Merge pull request #26525 from wenscarl:e2m1fn
...
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Jacob Burnim
73d20cd62a
[Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
...
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Jake VanderPlas
8ecadfdf9d
Internal: make it easier to detect the vmap sentinel
2025-03-10 11:37:50 -07:00
Michael Whittaker
5cb29949d4
Warn the user if transparent huge pages aren't enabled.
...
PiperOrigin-RevId: 735431881
2025-03-10 10:37:58 -07:00
jax authors
14b215fe76
Merge pull request #27032 from dfm:lax-dtype
...
PiperOrigin-RevId: 735424674
2025-03-10 10:18:58 -07:00
jax authors
ab0ce8a448
Merge pull request #26811 from dfm:direct-lin
...
PiperOrigin-RevId: 735388827
2025-03-10 08:39:49 -07:00
Dan Foreman-Mackey
21884d4a14
Move (most) jaxlib linalg custom call registration into JAX.
...
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.
It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.
This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.
PiperOrigin-RevId: 735381736
2025-03-10 08:17:44 -07:00
Dan Foreman-Mackey
4eada56027
Avoid using array operations within lax.py operations.
2025-03-10 11:04:32 -04:00
Sergei Lebedev
91340ea0a7
[pallas:mosaic_gpu] Added support for math functions to the WG lowering
...
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023
[Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
...
Enable some of the pre-existing Pallas `ops_test`s for testing.
PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Dan Foreman-Mackey
36d515ed2c
A few more fixes for debug_info tests with direct_linearize.
2025-03-08 07:47:24 -05:00
Jevin Jiang
0f0636afab
[Mosaic TPU][Pallas] Add pl.reciprocal
...
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
jax authors
4988adccf1
Merge pull request #27010 from mattjj:direct-linearize-fixes-3
...
PiperOrigin-RevId: 734747001
2025-03-07 18:15:02 -08:00
Matthew Johnson
fe26c19b92
[direct-linearize] fix name_stack bugs
...
Surprisingly, the bug was tracked down to #26111 aka cl/730939406, specifically
the new implementation of reset_name_stack in source_info_util.py.
To repro, use the before-this-commit implementation of reset_name_stack (left
commented-out in the file), and run
```
JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py NameStackTransformationTest.test_nested_jit_stack
```
2025-03-08 01:51:19 +00:00
Matthew Johnson
251b93ebd7
fixups that we meant to include in #26427
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-08 00:03:26 +00:00
jax authors
6095af050f
Merge pull request #26427 from mattjj:direct-linearize-fixes
...
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
jax authors
d849779689
Merge pull request #27001 from mattjj:yash-scan
...
PiperOrigin-RevId: 734685031
2025-03-07 14:14:30 -08:00
jax authors
1870176eb3
Merge pull request #26979 from mattjj:26936
...
PiperOrigin-RevId: 734674945
2025-03-07 13:43:55 -08:00
Matthew Johnson
f4f31f89ae
[scan] when num_trips==0, don't generate weird size-zero reshapes
2025-03-07 21:35:40 +00:00
Matthew Johnson
7c2f842353
shard_map and other fixes to direct-linearize
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9
[mutable-arrays] read values should have the same explicit sharding as ref
...
fixes #26936
2025-03-07 20:53:29 +00:00
jax authors
ccf7278292
Add the len(arg) to the error message for static_argnums
...
Helps reduce the confusion on what is considered an argnum.
Ideally there should be static_argkwg
PiperOrigin-RevId: 734591856
2025-03-07 09:49:49 -08:00
Yash Katariya
9f37b5197f
[sharding_in_types] Fix a bug where empty_array
in scan was created with the wrong spec when unroll > 1
.
...
PiperOrigin-RevId: 734591110
2025-03-07 09:47:32 -08:00
Christos Perivolaropoulos
eeccc67c0b
[mgpu] Debug print arrays.
...
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00