9119 Commits

Author SHA1 Message Date
Ayaka
968bbd2bf2 Add a small atol bump to betainc test in LaxVmapOpTest
PiperOrigin-RevId: 741529177
2025-03-28 08:09:51 -07:00
Jake VanderPlas
431c2c0807 cleanup now that we depend on ml_dtypes>=0.5 2025-03-28 07:44:38 -07:00
Dimitar (Mitko) Asenov
e679811c4a [Mosaic GPU] Add warpgroup lowering for Exp2 in Pallas.
This change also enables tests for supported elementwise ops.

PiperOrigin-RevId: 741516852
2025-03-28 07:22:24 -07:00
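For illustration, a minimal Pallas kernel that exercises exp2; selecting the Mosaic GPU warpgroup lowering requires Mosaic-GPU-specific compiler parameters that are not shown in this sketch.

```
# Minimal sketch: an elementwise exp2 Pallas kernel. The Mosaic GPU warpgroup
# lowering path is selected via Mosaic-GPU-specific compiler parameters that
# are omitted here; this only shows the op being lowered.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def exp2_kernel(x_ref, o_ref):
    o_ref[...] = jnp.exp2(x_ref[...])

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    exp2_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
```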
Yash Katariya
563c3e2244 Add standard pbroadcast rules to more primitives. This should cover all primitives for which shard_map registered standard_rewrite rules.
PiperOrigin-RevId: 741516445
2025-03-28 07:20:12 -07:00
Adam Paszke
39fb2a00a6 [Mosaic GPU] Add support for allocation and lowering of scratch semaphores
The semaphore arrays are allocated in GMEM and zeroed by XLA before the kernel begins.

PiperOrigin-RevId: 741494241
2025-03-28 05:43:53 -07:00
Adam Paszke
30451478c0 [Pallas][NFC] Move the remainder of Semaphore-related extended dtypes to Pallas core
This completes the move started in https://github.com/jax-ml/jax/pull/26673.

PiperOrigin-RevId: 741487331
2025-03-28 05:10:10 -07:00
Rachel Han
a52f7b26e7 Add an accuracy field to unary ops:
* Cbrt
* Cos
* Exp, Exp2
* Expm1
* Log
* Logistic
* Log1p
* Rsqrt
* Sin
* Sqrt
* Tan
* Tanh

This allows users to select an implementation that satisfies the requested accuracy.

PiperOrigin-RevId: 741331787
2025-03-27 17:12:59 -07:00
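A hedged sketch of how such an accuracy request might be spelled; the `accuracy=` keyword and `Tolerance` type below are assumptions based on this message, not a confirmed API.

```
# Hypothetical spelling only: assumes the lax unary ops take an `accuracy=`
# keyword carrying a tolerance-like spec. The real keyword/type may differ.
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(0.0, 4.0, 16)
y = lax.exp(x, accuracy=lax.Tolerance(atol=0.0, rtol=1e-5))  # hypothetical
```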
Yash Katariya
25c106d132 Add standard_insert_pbroadcasts and standard_vma_rule to all primitives in the following files (don't add standard_insert_pbroadcasts for unary ops, though):
* slicing.py
* windowed_reductions.py
* special.py
* convolution.py
* fft.py
* linalg.py
* ann.py

PiperOrigin-RevId: 741327361
2025-03-27 16:56:39 -07:00
Yash Katariya
71b36dca84 Sort the replicated_axes wrt mesh names in Shardy
PiperOrigin-RevId: 741287495
2025-03-27 14:44:02 -07:00
jax authors
22719dd445 Merge pull request #27445 from jburnim:jburnim_pallas_interpret_mode
PiperOrigin-RevId: 741279760
2025-03-27 14:20:21 -07:00
Bixia Zheng
b290c132dd [jax:custom_partitioning] Raise an error when Shardy is used but the old sharding propagation callbacks are provided instead of a sharding rule.
PiperOrigin-RevId: 741253832
2025-03-27 13:04:24 -07:00
Matthew Johnson
d8fc40f121 allow saved_input_vjp functions to be jit inputs/outputs 2025-03-27 18:53:03 +00:00
jax authors
aafbb01966 Merge pull request #27501 from jakevdp:shape-size-ndim-jax-array
PiperOrigin-RevId: 741222785
2025-03-27 11:30:07 -07:00
Parker Schuh
1719fa0d5b Make sure the array is copied in this situation:
```
x = np.arange(1000)
y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False)
z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False)
```

After this change, `z.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()` will hold.

Also lift the restriction that CopyToMemorySpace sometimes fails when the source and destination memory spaces match. We can always bounce through the host if there is no more efficient copy.

PiperOrigin-RevId: 741200853
2025-03-27 10:27:26 -07:00
Peter Hawkins
083bdfc9cc Add license headers to files that were missing them.
PiperOrigin-RevId: 741167870
2025-03-27 08:45:15 -07:00
Ayaka
875e4795c4 Update test_util.get_tpu_version()
PiperOrigin-RevId: 741139032
2025-03-27 07:03:23 -07:00
jax authors
8bd956d96a [Pallas] Skip reads/writes from/to slices of kernel input/output buffers when the slices do not change between iterations of the grid loop that interprets kernels on CPU.
PiperOrigin-RevId: 741082349
2025-03-27 03:03:25 -07:00
Gunhyun Park
e1762b0af6 Assert unused variable in lax.all_to_all batching rule
P.S. minor improvement to code readability

PiperOrigin-RevId: 741051082
2025-03-27 00:47:13 -07:00
shuw
c7d04cc75a Improve based on review 2 2025-03-27 05:09:25 +00:00
Parker Schuh
be1f649b51 Expose jax._src.lib.ifrt_version, which tracks the version of the third_party/tensorflow code inside jax.

PiperOrigin-RevId: 740957982
2025-03-26 17:31:08 -07:00
kaixih
f949b8b8f6 Enable public doc for scaled dot 2025-03-27 00:05:28 +00:00
Parker Schuh
6033592a95 Rename xla_extension_version to jaxlib_extension_version to reflect its new
scope.

PiperOrigin-RevId: 740944270
2025-03-26 16:36:34 -07:00
Yash Katariya
e8038501d0 Fix a bug where jit was forwarding inputs to outputs even when donation was True for those inputs. This caused the output to be marked as deleted, since the input was being forwarded to the output.
Since this functionality was added for a dynamic shapes experiment, only enable it when the dynamic_shapes config is True.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 740942785
2025-03-26 16:31:11 -07:00
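A minimal sketch of the failure mode described (the repro below is illustrative, not taken from the change itself).

```
# Sketch: a jitted function that returns a donated input unchanged. Before
# this fix, the forwarded output could refer to the deleted (donated) buffer;
# forwarding is now only enabled when the dynamic_shapes config is True.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def identity(x):
    return x  # input is forwarded straight to the output

y = identity(jnp.arange(4.0))
```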
Jake VanderPlas
667c4a0ee0 Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim 2025-03-26 15:27:25 -07:00
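For illustration, an object exposing its data through the `__jax_array__` protocol, which these functions now accept.

```
import jax.numpy as jnp

class WrapsArray:
    """Toy wrapper exposing its data via the __jax_array__ protocol."""
    def __init__(self, data):
        self.data = data

    def __jax_array__(self):
        return jnp.asarray(self.data)

obj = WrapsArray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(jnp.shape(obj))  # (2, 3)
print(jnp.size(obj))   # 6
print(jnp.ndim(obj))   # 2
```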
jax authors
79ece131dc Merge pull request #27404 from mattbahr:add-pascal-matrix
PiperOrigin-RevId: 740913011
2025-03-26 14:54:20 -07:00
Ayaka
ce3941c635 Add division-by-zero checks to jax.numpy functions
PiperOrigin-RevId: 740906595
2025-03-26 14:35:56 -07:00
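The checks added here are a new mechanism; for comparison only, the existing `jax.experimental.checkify` API already surfaces division by zero under `jit`:

```
# Existing checkify mechanism, shown for comparison; the checks added in this
# commit are a separate mechanism.
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def divide(x, y):
    return x / y

checked = checkify.checkify(jax.jit(divide), errors=checkify.float_checks)
err, out = checked(jnp.float32(1.0), jnp.float32(0.0))
err.throw()  # raises because y == 0 triggered the division-by-zero check
```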
jax authors
5c81d02769 Merge pull request #27494 from jakevdp:tri-indices-jax-array
PiperOrigin-RevId: 740904760
2025-03-26 14:31:26 -07:00
Ayaka
c450b69dd7 Add missing __len__ to MutableArray
Fixes https://github.com/jax-ml/jax/issues/27476

PiperOrigin-RevId: 740903637
2025-03-26 14:27:50 -07:00
Jake VanderPlas
66908372af jnp.tri*_indices: support __jax_array__ inputs 2025-03-26 14:06:26 -07:00
Peter Hawkins
d9a6cd1a5e Remove xla_client.make_gpu_client.
Cleanup; this code is not used any more because we use C API plugins instead.

PiperOrigin-RevId: 740887556
2025-03-26 13:41:32 -07:00
Yash Katariya
b92b9b0e26 Raise an informative error when the length of device_assignment doesn't match the mesh.size of out_avals. This happens when (1) we can't extract the device_assignment from the arguments and (2) there is no concrete mesh in context.
For example:

```
def test_random_normal_wo_mesh_context_error(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
                           axis_types=(AxisType.Explicit,) * 2)
    s = NamedSharding(mesh, P('x', 'y'))

    @jax.jit
    def f(key):
      out = jax.random.normal(key, shape=(8, 12), out_sharding=s)
      self.assertEqual(out.aval.sharding.spec, P('x', 'y'))
      self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh)
      return out

    key = jax.random.key(1)
    with self.assertRaisesRegex(
        ValueError,
        'Length of device assignment.*is not equal to the size of the mesh'):
      f(key)
```

PiperOrigin-RevId: 740886114
2025-03-26 13:37:15 -07:00
Jake VanderPlas
096810a721 [array API] make capabilities more accurate 2025-03-26 12:11:47 -07:00
Yash Katariya
ec2f0f5913 [sharding_in_types] Enable auto_axes to work without any mesh context manager. We extract the mesh from the given out_shardings. This allows APIs like random.uniform to accept a NamedSharding via the out_sharding argument and continue to work without a mesh context.
PiperOrigin-RevId: 740852542
2025-03-26 11:56:56 -07:00
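A hedged sketch of the usage this enables; the `out_sharding=` keyword comes from the message above, while the mesh construction details (including whether explicit axis types are required, as in the test shown earlier in this log) are assumptions.

```
# Sketch: pass a NamedSharding via out_sharding= with no `with mesh:` context.
# Mesh construction details here are assumptions for illustration; the mesh
# may need explicit axis types for this path.
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ('x',))
s = NamedSharding(mesh, P('x'))

key = jax.random.key(0)
x = jax.random.uniform(key, (8, 4), out_sharding=s)  # no mesh context manager
```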
Daniel Suo
e364abe961 Prune passthrough outputs in lax.switch. 2025-03-26 18:53:14 +00:00
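An illustration (not from the change itself) of what a passthrough output looks like: every branch returns `y` untouched, so it can be forwarded around the switch instead of being threaded through each branch.

```
import jax
import jax.numpy as jnp
from jax import lax

def f(i, x, y):
    branches = [
        lambda x, y: (x + 1.0, y),  # `y` is returned unchanged ...
        lambda x, y: (x - 1.0, y),  # ... by every branch (a passthrough output)
    ]
    return lax.switch(i, branches, x, y)

print(jax.jit(f)(0, jnp.float32(1.0), jnp.float32(2.0)))
```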
Ayaka
feed69c561 Add nan checking to jax.numpy functions
PiperOrigin-RevId: 740838221
2025-03-26 11:19:22 -07:00
Gleb Pobudzey
2518e187f3 [Mosaic GPU] Support more layouts in the swap lowering.
PiperOrigin-RevId: 740835345
2025-03-26 11:11:33 -07:00
Ayaka
b1b281a427 Prototype of adding error checking to jax.numpy functions
PiperOrigin-RevId: 740822504
2025-03-26 10:37:34 -07:00
jax authors
41fe8d9c6d Merge pull request #27421 from jakevdp:finalize-deps
PiperOrigin-RevId: 740821740
2025-03-26 10:35:26 -07:00
jax authors
a04d14f589 Merge pull request #27448 from vfdev-5:fix-py314-do-not-return-from-finally
PiperOrigin-RevId: 740813434
2025-03-26 10:14:48 -07:00
Jake VanderPlas
91a07ea2e8 Clean up a number of finalized deprecations 2025-03-26 09:57:19 -07:00
jax authors
2b86f38585 [AutoPGLE] Prevent AutoPGLE from running if the user launched an external profiler.
Reverts d4745b9bd81b49e2a7a8938ea98516296d54635f

PiperOrigin-RevId: 740804528
2025-03-26 09:52:29 -07:00
Benjamin Chetioui
2057df13ba [Pallas/Mosaic GPU] Fix copy_smem_to_gmem lowering to not use a single_thread_predicate when using warpgroup semantics.
Also avoid generating the predicate at all when using warpgroup semantics.

PiperOrigin-RevId: 740803927
2025-03-26 09:50:25 -07:00
Sergei Lebedev
6386efe369 [pallas:mosaic_gpu] plgpu.kernel now accepts scratch shapes
This frees the caller from another level of indirection via `pl.run_scoped`.

PiperOrigin-RevId: 740802977
2025-03-26 09:47:09 -07:00
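For context, the `pl.run_scoped` indirection the message refers to looks roughly like this (a sketch; the `plgpu.SMEM` spelling follows Pallas Mosaic GPU conventions, and the new up-front allocation keyword on `plgpu.kernel` is not shown since its exact signature isn't given here).

```
# Sketch of the indirection being removed: SMEM scratch allocated inside the
# kernel body via pl.run_scoped. With this change, plgpu.kernel can allocate
# such scratch up front instead (exact keyword not shown).
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(x_ref, o_ref):
    def body(scratch_ref):
        scratch_ref[...] = x_ref[...] * 2
        o_ref[...] = scratch_ref[...]
    pl.run_scoped(body, plgpu.SMEM((128,), jnp.float32))
```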
Christos Perivolaropoulos
9d768c4754 [pallas:mgpu] Use the ExitStack context to manage smem allocations.
PiperOrigin-RevId: 740790684
2025-03-26 09:10:01 -07:00
Benjamin Chetioui
dfa2f46968 [Pallas/Mosaic GPU] Delete mesh_cast_p lowering rules. They don't seem to be used.
PiperOrigin-RevId: 740785108
2025-03-26 08:52:28 -07:00
vfdev-5
c159212439 Some codebase fixes required for Python 3.14
- Fix for "SyntaxWarning: 'return' in a 'finally' block"
- Fix for "AttributeError: 'typing.Union' object attribute '__doc__' is read-only"
2025-03-26 14:16:56 +00:00
Sergei Lebedev
7a42e3d39d [pallas:mosaic_gpu] thread_semantics= should still default to lane-level
PiperOrigin-RevId: 740753009
2025-03-26 07:07:18 -07:00
Benjamin Chetioui
3f3081d46e [Pallas/Mosaic GPU] Add a lowering rule for pjit.mesh_cast_p for warpgroup semantics.
PiperOrigin-RevId: 740719326
2025-03-26 04:46:23 -07:00
Benjamin Chetioui
660f536300 [Pallas/Mosaic GPU] Add a lowering rule for lax.optimization_barrier_p with warpgroup semantics.
PiperOrigin-RevId: 740684030
2025-03-26 02:22:41 -07:00
jax authors
89faa209e2 Merge pull request #27017 from mattjj:input-saved-vjp
PiperOrigin-RevId: 740617998
2025-03-25 22:03:56 -07:00