7320 Commits

Author SHA1 Message Date
Jake VanderPlas
31e2358887 test: work around issue with kstest in scipy>1.12 2024-04-03 11:17:56 -07:00
jax authors
57ee6b7550 Merge pull request #20560 from pearu:pearu/log1p-fixes
PiperOrigin-RevId: 621562841
2024-04-03 10:17:11 -07:00
Pearu Peterson
9a7fb898d4 Workaround mpmath bug (mpmath/mpmath#774) in log1p at complex infinities
Temporarily disable arctanh success tests that depend on log1p fixes
2024-04-03 18:48:26 +03:00
George Necula
35b1cb799a [callback] Allow external callbacks to return 64-bit values in 32-bit mode
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  #20433 to add a canonicalization
step instead of an error.
2024-04-03 11:15:11 +01:00
Peter Hawkins
b7401872d5 Disable a random distribution test that appears to fail under scipy 1.13.0rc1.
PiperOrigin-RevId: 621352047
2024-04-02 18:11:30 -07:00
jax authors
6b582f5977 Merge pull request #20552 from jakevdp:geomspace-complex
PiperOrigin-RevId: 621348396
2024-04-02 17:54:51 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Sergei Lebedev
f74f4ed48b Removed unnecessary BUILD dependencies from :ops_test
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
jax authors
e282bf57db Merge pull request #20536 from jakevdp:broadcast-to
PiperOrigin-RevId: 621287464
2024-04-02 13:59:12 -07:00
Jake VanderPlas
37a34bf6a3 test: remove solve test case that is invalid under NumPy 2.0 2024-04-02 13:10:26 -07:00
jax authors
87b869ddd8 Merge pull request #20543 from jakevdp:skip-geomspace
PiperOrigin-RevId: 621267031
2024-04-02 12:53:28 -07:00
Jake VanderPlas
c79f54f7ff test: skip complex geomspace test under numpy 2.0 2024-04-02 12:32:02 -07:00
Peter Hawkins
60458bb36a Fail gracefully in lobpcg_test if matplotlib isn't installed.
There isn't yet a matplotlib that supports NumPy 2.0, so we need to support running tests without it.

PiperOrigin-RevId: 621228727
2024-04-02 10:48:46 -07:00
Jake VanderPlas
6de6983d59 jnp.broadcast_to: better error for invalid shape 2024-04-02 08:38:51 -07:00
George Necula
bff24c6d6f [callback] Improve caching effectiveness in presence of callbacks.
Previously, the user-provided Python callback function was first
flattened and then the result passed as a primitive parameter to
the callback primitives. This means that two separate io_callback
invocations with the same Python callable will generate different
Jaxprs. To prevent this we defer the flattening to lowering time.
2024-04-02 15:33:24 +02:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
Peter Hawkins
011ced4431 Guard test that requires two devices with device_count() check.
PiperOrigin-RevId: 620921563
2024-04-01 12:32:54 -07:00
jax authors
29cb89554e Merge pull request #20454 from pearu:pearu/log1p-disable-tests
PiperOrigin-RevId: 620856592
2024-04-01 08:35:12 -07:00
Peter Hawkins
77ff8a2339 [PJRT:CPU] Fix thread-pool stack sizes to 2MiB.
The default thread pool size is too small on Mac OS.

An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite.

Fixes https://github.com/google/jax/issues/20428

PiperOrigin-RevId: 620853544
2024-04-01 08:20:36 -07:00
Dinghua Li
7f9ff82e8f Disable asan for paged_attention_kernel_test.
PiperOrigin-RevId: 620841292
2024-04-01 07:18:13 -07:00
Jieying Luo
68c674d106 [PJRT C API] Add a PJRT extension to register custom partitioner.
- This extension has one C API which registers a custom partitioner with callbacks from the input.
- Update xla_client.register_custom_call_partitioner to take an optional PJRT_Api* input.
- Add xla_bridge.register_plugin_initialization_callbacks to register callbacks to be called with PJRT_Api* after plugins are discovered.

PiperOrigin-RevId: 620357554
2024-03-29 15:40:26 -07:00
Son Tuan Vu
3d6af5ada6 Breakpoints can be reordered inside of vmap on GPU too.
There is no data dependence between these breakpoints (the breakpoints are lowered into custom call that returns nothing, so there is no way to enforce their relative order)

Thus we are relaxing this ordering constraint in debugger test for all backends.

PiperOrigin-RevId: 620355448
2024-03-29 15:30:46 -07:00
jax authors
750487f2cf Adjusts error tolerance for lax_control_flow_test
PiperOrigin-RevId: 620343970
2024-03-29 14:40:39 -07:00
Trevor Gale
80c305da7b Add MegaBlox grouped matrix multiplication kernels for TPU.
PiperOrigin-RevId: 620331883
2024-03-29 13:50:49 -07:00
Yunlong Liu
f625306ec4 Enable JAX memories test with the new pinned host memory space.
PiperOrigin-RevId: 620303609
2024-03-29 11:56:04 -07:00
Peter Hawkins
8815b236b6 Disable collective broadcast pmap test on older jaxlibs.
Collective broadcast was only recently added to xla.

PiperOrigin-RevId: 620287470
2024-03-29 10:59:10 -07:00
George Necula
43cf559f57 Replace internal usage of stop_outfeed_receiver with _deprecated_stop_outfeed_receiver.
The jax.experimental.host_callback module is deprecated and will be removed.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 620237081
2024-03-29 07:16:08 -07:00
Yash Katariya
84156f359f Add identity jit tests to go from pinned_host -> device and vice versa
PiperOrigin-RevId: 620114420
2024-03-28 18:20:32 -07:00
Yash Katariya
c846233089 Use jax.random.key instead of the old jax.random.PRNGKey
PiperOrigin-RevId: 620088823
2024-03-28 16:23:39 -07:00
Dinghua Li
8bf3f47f02 Open source PagedAttention TPU kernel.
PiperOrigin-RevId: 620042536
2024-03-28 13:36:02 -07:00
Peter Hawkins
67ea800361 Skip Pallas tests on GPUs earlier than Ampere (SM 8.0).
Upstream Triton has dropped support.

PiperOrigin-RevId: 619993037
2024-03-28 11:04:46 -07:00
jax authors
e03f1d4fd1 Allows for splitting the transpose of a scan into a scan and a map.
This is an experimental feature exposed as an extra parameter: `scan(..., _split_transpose:bool)`.

If the parameter is true then the transpose of scan generates not just 2 scans
(forward and transpose of the linearized forward), but rather 3 scans: (i)
forward (as before), (ii) transposed scan that only computes loop-carried state
required for back-propagation, but saves other intermediate gradients; (iii) a
scan (actually a map) that uses any saved activation gradients and original
residuals to compute any other gradients.

Warning: this feature is somewhat experimental and may evolve or be rolled back.
PiperOrigin-RevId: 619991098
2024-03-28 10:54:50 -07:00
Benjamin Chetioui
1e1906bd68 [XLA:GPU] Deprecate Triton codegen before Ampere.
Unfortunately, upstream Triton has decided to drop support for NVIDIA GPUs
below Ampere, so we bump the GPU version requirements for using Triton.

PiperOrigin-RevId: 619899728
2024-03-28 05:56:03 -07:00
Yash Katariya
9e86aa5329 Add custom call on output along with S(5) because XLA requires the custom call to show the transfer.
Enable paramater streaming and weight offloading

PiperOrigin-RevId: 619711649
2024-03-27 17:07:36 -07:00
Sergei Lebedev
ec73c4031a Do not deadlock the GPU if a pure_callback dispatches a GPU kernel
PiperOrigin-RevId: 619656442
2024-03-27 14:26:03 -07:00
Yunlong Liu
7f30ab5822 Supports PjRt CPU array converting py array when the CPU PjRt arrays have non-default layouts.
PiperOrigin-RevId: 619608909
2024-03-27 12:14:39 -07:00
jax authors
66877c9987 Allow allow_spmd_propagation_to_output to be generated for outputs annotated with pjit.AUTO
PiperOrigin-RevId: 619608022
2024-03-27 12:04:03 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
Pearu Peterson
be5812daa2 Temporarily disable log1p success unit tests 2024-03-27 16:39:10 +02:00
Parker Schuh
0b09762efd Guard host transfers inside pure_callbacks from deadlocking the TPU.
Also fix python/callback.cc to not swallow errors in numpy conversions.

PiperOrigin-RevId: 619375128
2024-03-26 18:36:39 -07:00
Yash Katariya
6e0c95585a Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation.
The canonicalization doesn't provide any value anymore and only makes the internals more complicated.

The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that.

PiperOrigin-RevId: 619292757
2024-03-26 13:28:45 -07:00
Anselm Levskaya
f8d6669291 Add commentary and clean the pallas bidirectional collective allgather matmul test.
"test_pipeline_all_gather_matmul" is the best demo example of nested pallas pipelines, but it's hard to follow the logic in the existing test.

A few changes were made there:

 - rename things to avoid confusion between outer and inner loop prologues / epilogues.
 - give clear names for the outer iteration space: (step, phase) to help clarify sequencing of compute and DMAs.
 - simplify and lift out all async copy definitions and add commentary on their function
 - remove some incorrect comments about the rDMA schedule, and generally add a ton of commentary about when things happen in the outer pipeline.
 - lift all the outer prologue work into an integrated prologue function
 - various other small things.

PiperOrigin-RevId: 619254981
2024-03-26 11:31:38 -07:00
Pearu Peterson
c82f0619b7 Update complex function accurancy tests for expm1 2024-03-26 19:10:38 +02:00
Yash Katariya
c78054d8ae Fix the pjit test failing on v5e
PiperOrigin-RevId: 619207394
2024-03-26 09:02:13 -07:00
jax authors
6afa523ae0 Skip fused_attention_stablehlo_test.py.
https://github.com/google/jax/issues/20438

PiperOrigin-RevId: 619204356
2024-03-26 08:50:50 -07:00
George Necula
75db481299 [callback] Fix io_callback for callbacks that return Python literals.
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

Reverts 19e6156ccec0df7a900471df7840bc421da2898b

PiperOrigin-RevId: 619156176
2024-03-26 05:32:41 -07:00
jax authors
69980a27bb Use the information in allow_spmd_sharding_propagation_to_output and allow_spmd_sharding_propagation_to_parameters to determine what input and output tuple elements we are allowed to modfy the shardings of.
PiperOrigin-RevId: 619013275
2024-03-25 17:46:52 -07:00
jax authors
c724eab240 Merge pull request #20257 from Cjkkkk:sdpa_training
PiperOrigin-RevId: 618972494
2024-03-25 15:09:52 -07:00
Cjkkkk
f51d80ed1e move checks to setup 2024-03-25 14:11:11 -07:00