Ayaka
968bbd2bf2
Add a small atol bump to betainc
test in LaxVmapOpTest
...
PiperOrigin-RevId: 741529177
2025-03-28 08:09:51 -07:00
Jake VanderPlas
431c2c0807
cleanup now that we depend on ml_dtypes>=0.5
2025-03-28 07:44:38 -07:00
Dimitar (Mitko) Asenov
e679811c4a
[Mosaic GPU] Add warpgroup lowering for Exp2
in Pallas.
...
This change also enables tests for supported elementwise ops.
PiperOrigin-RevId: 741516852
2025-03-28 07:22:24 -07:00
Yash Katariya
563c3e2244
Add standard pbroadcast rules to more primitives. This should cover all primitives from which shard_map registered standard_rewrite rules
...
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Adam Paszke
39fb2a00a6
[Mosaic GPU] Add support for allocation and lowering of scratch semaphores
...
The semaphore arrays are allocated in GMEM and zeroed by XLA before the kernel begins.
PiperOrigin-RevId: 741494241
2025-03-28 05:43:53 -07:00
Adam Paszke
30451478c0
[Pallas][NFC] Move the remainder of Semaphore-related extended dtypes to Pallas core
...
This completes the move started in https://github.com/jax-ml/jax/pull/26673 .
PiperOrigin-RevId: 741487331
2025-03-28 05:10:10 -07:00
Rachel Han
a52f7b26e7
Add accuracy field to unary ops
...
* Cbrt
* Cos
* Exp, Exp2
* Expm1
* Log
* Logistic
* Log1p
* Rsqrt
* Sin
* Sqrt
* Tan
* Tanh
which allows users to select implementation that will satisfy the requested accuracy.
PiperOrigin-RevId: 741331787
2025-03-27 17:12:59 -07:00
Yash Katariya
25c106d132
Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in following files: (Don't add standard_insert_broadcast
for unary ops though)
...
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py
PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Yash Katariya
71b36dca84
Sort the replicated_axes wrt mesh names in Shardy
...
PiperOrigin-RevId: 741287495
2025-03-27 14:44:02 -07:00
jax authors
22719dd445
Merge pull request #27445 from jburnim:jburnim_pallas_interpret_mode
...
PiperOrigin-RevId: 741279760
2025-03-27 14:20:21 -07:00
Bixia Zheng
b290c132dd
[jax:custom_partitioning] Raise an error when Shardy is used but the old sharding propagation callbacks instead of sharding rule are provided.
...
PiperOrigin-RevId: 741253832
2025-03-27 13:04:24 -07:00
Matthew Johnson
d8fc40f121
allow saved_input_vjp functions to be jit inputs/outputs
2025-03-27 18:53:03 +00:00
jax authors
aafbb01966
Merge pull request #27501 from jakevdp:shape-size-ndim-jax-array
...
PiperOrigin-RevId: 741222785
2025-03-27 11:30:07 -07:00
Parker Schuh
1719fa0d5b
Make sure array is copied under this situation:
...
```
x = np.arange(1000)
y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)
z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False)
```
This condition will be true after this change `z.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()`
Also lift the restrictions that CopyToMemorySpace doesn't work sometimes for
matching src+dest memory spaces. We can always bounce through the host if there
is no more efficient copy.
PiperOrigin-RevId: 741200853
2025-03-27 10:27:26 -07:00
Peter Hawkins
083bdfc9cc
Add license headers to files that were missing them.
...
PiperOrigin-RevId: 741167870
2025-03-27 08:45:15 -07:00
Ayaka
875e4795c4
Update test_util.get_tpu_version()
...
PiperOrigin-RevId: 741139032
2025-03-27 07:03:23 -07:00
jax authors
8bd956d96a
[Pallas] Skip reads/writes from/to slices of kernel input/output buffers when the slices do not change between iterations of the grid loop that interprets kernels on CPU.
...
PiperOrigin-RevId: 741082349
2025-03-27 03:03:25 -07:00
Gunhyun Park
e1762b0af6
Assert unused variable in lax.all_to_all batching rule
...
P.S. minor improvement to code readability
PiperOrigin-RevId: 741051082
2025-03-27 00:47:13 -07:00
shuw
c7d04cc75a
Improve based on review 2
2025-03-27 05:09:25 +00:00
Parker Schuh
be1f649b51
Expose jax._src.lib.ifrt_version which tracks the version of
...
third_party/tensorflow code inside jax.
PiperOrigin-RevId: 740957982
2025-03-26 17:31:08 -07:00
kaixih
f949b8b8f6
Enable public doc for scaled dot
2025-03-27 00:05:28 +00:00
Parker Schuh
6033592a95
Rename xla_extension_version to jaxlib_extension_version to reflect its new
...
scope.
PiperOrigin-RevId: 740944270
2025-03-26 16:36:34 -07:00
Yash Katariya
e8038501d0
Fix a bug where jit was forwarding inputs to outputs even when donation was True for that inputs. This caused the output to be marked as deleted since the input was being forwarded to the output.
...
Since this functionality was added for a dynamic shapes experiment, only enable it when dynamic_shapes config is True.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
2025-03-26 16:31:11 -07:00
Jake VanderPlas
667c4a0ee0
Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim
2025-03-26 15:27:25 -07:00
jax authors
79ece131dc
Merge pull request #27404 from mattbahr:add-pascal-matrix
...
PiperOrigin-RevId: 740913011
2025-03-26 14:54:20 -07:00
Ayaka
ce3941c635
Add division-by-zero checks to jax.numpy functions
...
PiperOrigin-RevId: 740906595
2025-03-26 14:35:56 -07:00
jax authors
5c81d02769
Merge pull request #27494 from jakevdp:tri-indices-jax-array
...
PiperOrigin-RevId: 740904760
2025-03-26 14:31:26 -07:00
Ayaka
c450b69dd7
Add missing __len__
to MutableArray
...
Fixes https://github.com/jax-ml/jax/issues/27476
PiperOrigin-RevId: 740903637
2025-03-26 14:27:50 -07:00
Jake VanderPlas
66908372af
jnp.tri*_indices: support __jax_array__ inputs
2025-03-26 14:06:26 -07:00
Peter Hawkins
d9a6cd1a5e
Remove xla_client.make_gpu_client.
...
Cleanup; this code is not used any more because we use C API plugins instead.
PiperOrigin-RevId: 740887556
2025-03-26 13:41:32 -07:00
Yash Katariya
b92b9b0e26
Raise an informative error when the length of device_assignment doesn't match the mesh.size of out_avals. This happens when (1) we can't extract the device_assignment from the arguments and (2) there is no concrete mesh in context.
...
For example:
```
def test_random_normal_wo_mesh_context_error(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types=(AxisType.Explicit,) * 2)
s = NamedSharding(mesh, P('x', 'y'))
@jax.jit
def f(key):
out = jax.random.normal(key, shape=(8, 12), out_sharding=s)
self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh)
return out
key = jax.random.key(1)
with self.assertRaisesRegex(
ValueError,
'Length of device assignment.*is not equal to the size of the mesh'):
f(key)
```
PiperOrigin-RevId: 740886114
2025-03-26 13:37:15 -07:00
Jake VanderPlas
096810a721
[array API] make capabilities more accurate
2025-03-26 12:11:47 -07:00
Yash Katariya
ec2f0f5913
[sharding_in_types] Enable auto_axes to work without any mesh context manager. We extract the mesh from out_shardings
given. This allows APIs like random.uniform
to accept NamedSharding in out_sharding
argument and continue to work without a mesh context.
...
PiperOrigin-RevId: 740852542
2025-03-26 11:56:56 -07:00
Daniel Suo
e364abe961
Prune passthrough outputs in lax.switch.
2025-03-26 18:53:14 +00:00
Ayaka
feed69c561
Add nan checking to jax.numpy functions
...
PiperOrigin-RevId: 740838221
2025-03-26 11:19:22 -07:00
Gleb Pobudzey
2518e187f3
[Mosaic GPU] Support more layouts in the swap
lowering.
...
PiperOrigin-RevId: 740835345
2025-03-26 11:11:33 -07:00
Ayaka
b1b281a427
Prototype of adding error checking to jax.numpy functions
...
PiperOrigin-RevId: 740822504
2025-03-26 10:37:34 -07:00
jax authors
41fe8d9c6d
Merge pull request #27421 from jakevdp:finalize-deps
...
PiperOrigin-RevId: 740821740
2025-03-26 10:35:26 -07:00
jax authors
a04d14f589
Merge pull request #27448 from vfdev-5:fix-py314-do-not-return-from-finally
...
PiperOrigin-RevId: 740813434
2025-03-26 10:14:48 -07:00
Jake VanderPlas
91a07ea2e8
Clean up a number of finalized deprecations
2025-03-26 09:57:19 -07:00
jax authors
2b86f38585
[AutoPGLE] Prevent an AutoPGLE to run if user launched an external profiler.
...
Reverts d4745b9bd81b49e2a7a8938ea98516296d54635f
PiperOrigin-RevId: 740804528
2025-03-26 09:52:29 -07:00
Benjamin Chetioui
2057df13ba
[Pallas/Mosaic GPU] Fix copy_smem_to_gmem
lowering to not use a single_thread_predicate
when using warpgroup semantics.
...
Also avoid generating the predicate at all when using warpgroup semantics.
PiperOrigin-RevId: 740803927
2025-03-26 09:50:25 -07:00
Sergei Lebedev
6386efe369
[pallas:mosaic_gpu] plgpu.kernel
now accepts scratch shapes
...
This frees the caller from another level of indirection via `pl.run_scoped`.
PiperOrigin-RevId: 740802977
2025-03-26 09:47:09 -07:00
Christos Perivolaropoulos
9d768c4754
[pallas:mgpu] Use the ExitStack context to manage smem allocations.
...
PiperOrigin-RevId: 740790684
2025-03-26 09:10:01 -07:00
Benjamin Chetioui
dfa2f46968
[Pallas/Mosaic GPU] Delete mesh_cast_p
lowering rules. They don't seem to be used.
...
PiperOrigin-RevId: 740785108
2025-03-26 08:52:28 -07:00
vfdev-5
c159212439
Some codebase fixes required for python 3.14
...
- Fix for "SyntaxWarning: 'return' in a 'finally' block"
- Fix for "AttributeError: 'typing.Union' object attribute '__doc__' is read-only"
2025-03-26 14:16:56 +00:00
Sergei Lebedev
7a42e3d39d
[pallas:mosaic_gpu] thread_semantics=
should still default to lane-level
...
PiperOrigin-RevId: 740753009
2025-03-26 07:07:18 -07:00
Benjamin Chetioui
3f3081d46e
[Pallas/Mosaic GPU] Add a lowering rule for pjit.mesh_cast_p
for warpgroup semantics.
...
PiperOrigin-RevId: 740719326
2025-03-26 04:46:23 -07:00
Benjamin Chetioui
660f536300
[Pallas/Mosaic GPU] Add a lowering rule for lax.optimization_barrier_p
with warpgroup semantics.
...
PiperOrigin-RevId: 740684030
2025-03-26 02:22:41 -07:00
jax authors
89faa209e2
Merge pull request #27017 from mattjj:input-saved-vjp
...
PiperOrigin-RevId: 740617998
2025-03-25 22:03:56 -07:00