8695 Commits

Author SHA1 Message Date
jax authors
31cb3fd36e Merge pull request #23923 from carlosgmartin:ldexp_custom_jvp
PiperOrigin-RevId: 680757259
2024-09-30 16:21:57 -07:00
Ayaka
a24420e76b [Pallas TPU] Add lowering for lax.cos_p
Fixes https://github.com/jax-ml/jax/issues/24026

PiperOrigin-RevId: 680754948
2024-09-30 16:12:11 -07:00
Ayaka
23ce5a11cc [Pallas TPU] Consolidate OpsExtraTest into OpsTest
Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.

PiperOrigin-RevId: 680747902
2024-09-30 15:50:23 -07:00
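A minimal sketch of the per-test skipping pattern this commit describes; the test name and skip condition below are illustrative, not taken from the change:
```py
import jax
from absl.testing import absltest

class OpsTest(absltest.TestCase):

  def test_some_op(self):
    # Tests that fail on TPU are now skipped individually instead of
    # living in a separate OpsExtraTest class.
    if jax.default_backend() == "tpu":
      self.skipTest("Not yet supported on TPU")
    # ... actual op checks ...
```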
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
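For context, jnp.ldexp(x, n) computes x * 2**n, so its derivative with respect to x is 2**n; the custom JVP makes autodiff produce that value. A quick check (hedged, not taken from the PR itself):
```py
import jax
import jax.numpy as jnp

# d/dx ldexp(x, 3) = 2**3 = 8, independent of x.
print(jax.grad(lambda x: jnp.ldexp(x, 3))(1.0))  # expected: 8.0
```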
Jake VanderPlas
36782e8319 jnp.mask_indices: add docs & tests 2024-09-30 15:13:41 -07:00
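jnp.mask_indices mirrors numpy.mask_indices: it returns the indices where mask_func applied to an all-ones (n, n) array is nonzero. A small usage sketch:
```py
import jax.numpy as jnp

# Indices of the upper triangle of a 3x3 array.
rows, cols = jnp.mask_indices(3, jnp.triu)
print(rows)  # [0 0 0 1 1 2]
print(cols)  # [0 1 2 1 2 2]
```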
Jevin Jiang
4a596aee1e [Mosaic TPU] Force offset to 0 when an inferred input has an offset outside the first tile.
We still have this temporary check in apply-vector-layout, but in infer-vector-layout we should just reset the offset to zero instead of throwing an error, because some ops that have relaxed this restriction might be passed as inputs to un-relaxed ops and cause failures.

PiperOrigin-RevId: 680706301
2024-09-30 13:52:48 -07:00
jax authors
cdc72787fc Merge pull request #24025 from jakevdp:gradient-doc
PiperOrigin-RevId: 680703792
2024-09-30 13:48:09 -07:00
Jake VanderPlas
36d6bb9013 Better docs for jnp.gradient
Also remove skip_params option from util.implements, as this was its last usage.
2024-09-30 13:07:52 -07:00
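jnp.gradient follows numpy.gradient: central differences in the interior and one-sided differences at the boundaries. A small example:
```py
import jax.numpy as jnp

f = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])
print(jnp.gradient(f))  # [1.  1.5 2.5 3.5 4. ]
```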
jax authors
3766f887d3 Merge pull request #23505 from sergachev:fix_cudnn_fusion_test
PiperOrigin-RevId: 680685919
2024-09-30 12:58:44 -07:00
Jevin Jiang
7e2f487ada [Mosaic TPU] Canonicalize arith.select's condition to a vector when the other operand types are vectors.
This fixes a failure in the elementwise rule of the apply-vector-layout pass.

If the condition is a static scalar, MLIR will simplify the select to the corresponding vector taken from the true or false value.

If the condition is a dynamic scalar, we want to use vselect over scf.if anyway, because the latter creates an inner region.

PiperOrigin-RevId: 680674560
2024-09-30 12:26:44 -07:00
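A hedged JAX-level illustration of the pattern involved: a scalar predicate selecting between vector values, which reaches Mosaic as an arith.select whose condition must be broadcast to a vector. The function below is illustrative, not from the change:
```py
import jax.numpy as jnp

def f(flag, x, y):
  # A scalar condition with vector operands; after this change Mosaic
  # canonicalizes the select's condition to a vector to match them.
  return jnp.where(flag, x, y)
```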
Dan Foreman-Mackey
ff1c2ac152 Add a test for 64-bit precision of IFFT on GPU.
Fixes https://github.com/jax-ml/jax/issues/23827. The precision fix was in https://github.com/openxla/xla/pull/17598, which has now been integrated into JAX, but we add a test here based on the repro from https://github.com/jax-ml/jax/issues/23827.

PiperOrigin-RevId: 680633622
2024-09-30 10:38:16 -07:00
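A sketch in the spirit of the repro from issue #23827, assuming 64-bit mode is enabled (not the test's actual code):
```py
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

x = np.random.randn(1024) + 1j * np.random.randn(1024)
# Double-precision IFFT on GPU should now agree with NumPy at tight tolerances.
np.testing.assert_allclose(jnp.fft.ifft(x), np.fft.ifft(x), rtol=1e-10)
```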
Peter Hawkins
45cd77ad8c Simplify CI configuration.
PiperOrigin-RevId: 680607105
2024-09-30 09:32:09 -07:00
Yash Katariya
203cda6f98 Move test_aot_device_implicit_transfer to pjit_test.py
This test is not specific to compute offload and is more relevant to pjit.

PiperOrigin-RevId: 680599882
2024-09-30 09:10:17 -07:00
Sergei Lebedev
a046e21a1e [pallas:mosaic_gpu] Do not do mgpu.commit_shared if all outputs are invariant wrt sequential axes
PiperOrigin-RevId: 680565753
2024-09-30 07:25:46 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed, so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
jax authors
411928b966 Rollback because of breakages
Reverts 21fea5b0db7a8d3fcd9d6918b430b0ebdd4da3e5

PiperOrigin-RevId: 680552566
2024-09-30 07:23:36 -07:00
Ilia Sergachev
b320dc2e5e Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
2024-09-30 14:03:55 +00:00
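One way to express that in a test environment, assuming the standard XLA debug option xla_gpu_cublas_fallback is the knob involved (an assumption, not copied from the commit):
```py
import os

# Disable XLA autotuning's fallback to cuBLAS so cuDNN fusions are exercised.
os.environ["XLA_FLAGS"] = "--xla_gpu_cublas_fallback=false"
```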
Adam Paszke
21fea5b0db [Pallas/MGPU] Undo transforms on refs before giving them back to the users
This change makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.

PiperOrigin-RevId: 680520185
2024-09-30 04:43:08 -07:00
Sergei Lebedev
38d2a573fc Exposed sequential iteration index via pl.program_id in Pallas Mosaic GPU
PiperOrigin-RevId: 680502214
2024-09-30 03:35:58 -07:00
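A minimal sketch of what this enables inside a Mosaic GPU kernel; the kernel itself is illustrative:
```py
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # The sequential grid index is now visible via pl.program_id on Mosaic GPU.
  step = pl.program_id(0)
  o_ref[...] = x_ref[...] + step
```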
Dan Foreman-Mackey
d80a89d86b Add support for FFI calls with side effects via ffi_call 2024-09-27 19:46:35 -04:00
Peter Hawkins
061f435b73 Bump test tolerance on FFT test that started failing in CI after an XLA change.
PiperOrigin-RevId: 679715691
2024-09-27 13:49:58 -07:00
Peter Hawkins
366c823857 Fix test failure when shardy is not enabled.
PiperOrigin-RevId: 679713372
2024-09-27 13:42:20 -07:00
Peter Hawkins
5969e79908 Fix tests that ask for an accelerator but don't use it.
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.

PiperOrigin-RevId: 679711664
2024-09-27 13:36:23 -07:00
jax authors
df042fded2 Merge pull request #23870 from Zantares:tenglu/flush_output
PiperOrigin-RevId: 679639244
2024-09-27 10:21:25 -07:00
jax authors
b762291183 Merge pull request #23965 from zhenying-liu:weight-offloading-test
PiperOrigin-RevId: 679631125
2024-09-27 10:01:30 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
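A sketch of the selection semantics described above; the function and names are illustrative, not the Bazel macro's actual implementation:
```py
def select_test_configs(all_configs, enable_backends, disable_configs,
                        enable_configs):
  # Start from every configuration whose backend is enabled...
  selected = {c for c in all_configs if c.backend in enable_backends}
  # ...prune explicitly disabled configurations...
  selected -= set(disable_configs)
  # ...then add back explicitly enabled ones.
  selected |= set(enable_configs)
  return selected
```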
Sergei Lebedev
afaf8b823d Run Pallas Mosaic GPU tests on internal CI
PiperOrigin-RevId: 679508320
2024-09-27 02:43:35 -07:00
jax authors
ea86251a60 [Pallas:TPU] Fix lowering of convert_element_type(int32) -> bool.
We need to add a condition on the vector type, since both operands of arith::CmpIOp must have the same type.

PiperOrigin-RevId: 679500783
2024-09-27 02:15:35 -07:00
Jane Liu
57bef447c6 Enable weight offloading tests that are supported on GPUs now 2024-09-26 23:26:27 -07:00
jax authors
5a1549cccf Merge pull request #23853 from zhenying-liu:remat-scan
PiperOrigin-RevId: 679365040
2024-09-26 18:12:30 -07:00
Justin Fu
9f4e8d0039 [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
2024-09-26 13:57:45 -07:00
jax authors
96cf2b81e6 Merge pull request #23921 from rajasekharporeddy:testbranch4
PiperOrigin-RevId: 679203931
2024-09-26 10:32:44 -07:00
Adam Paszke
076287fb5c [Pallas/MGPU] Implement block spec evaluation correctly
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).

PiperOrigin-RevId: 679175212
2024-09-26 09:15:12 -07:00
Bart Chrzaszcz
a3284bd8a3 #sdy Add CPU targets in JAX.
PiperOrigin-RevId: 679174535
2024-09-26 09:13:34 -07:00
rajasekharporeddy
6072f97961 Raise ValueError when axis1==axis2 for jnp.trace 2024-09-26 21:38:14 +05:30
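A quick illustration of the new behavior:
```py
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
jnp.trace(x, axis1=1, axis2=1)  # now raises ValueError (duplicate axes)
```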
Bart Chrzaszcz
e62a50cd34 #sdy add JAX Shardy support for shard_map.
For example the following JAX program:
```py
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```cpp
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
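A usage sketch of the new public alias; the failing function below is hypothetical:
```py
import jax

try:
    run_computation()  # hypothetical call that fails inside the runtime
except jax.errors.JaxRuntimeError as e:
    print("runtime failure:", e)
```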
Sergei Lebedev
5cef547eab Added support for lax.cond_p to Pallas Mosaic GPU lowering
PiperOrigin-RevId: 679156819
2024-09-26 08:20:53 -07:00
Adam Paszke
0a66e2d0a4 [Pallas/MGPU] Fix a race in the pipelining code
We never checked if the output windows are done writing before we reused them.
Also, rename num_stages to max_concurrent_steps, since we only ever have 2 stages but might be running multiple iterations at a time.

Also fix the test for this, which had been passing for reasons I don't understand (it didn't even write to all entries in the output?).

PiperOrigin-RevId: 679148961
2024-09-26 07:57:54 -07:00
Adam Paszke
8599dbc9b2 [Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing
I annotated a number of issues in the test. To make the test run, I also needed to add support for accumulator reference allocation and discharge in the main lowering part. Ideally we'd defer it all to run_scoped, but run_scoped can't allocate barriers...

PiperOrigin-RevId: 679143948
2024-09-26 07:40:15 -07:00
Adam Paszke
3c25da2c59 [Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list
PiperOrigin-RevId: 679079269
2024-09-26 03:41:22 -07:00
Christos Perivolaropoulos
b6d668e0d7 [pallas::mosaic_gpu] Turn the accumulator into a reference
* Changes the accumulator into a reference.
* Creates a discharged flavor of the wgmma op.
* run_scoped lowering discharges the input jaxpr.
* Dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged.
* The deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.

PiperOrigin-RevId: 679056765
2024-09-26 02:18:27 -07:00
jax authors
70346bda74 [Pallas] Add scalar f32 downcast test cases.
PiperOrigin-RevId: 678779025
2024-09-25 11:25:59 -07:00
jax authors
0f84c2c6be Merge pull request #23917 from dfm:gh23895
PiperOrigin-RevId: 678759331
2024-09-25 10:41:44 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Dan Foreman-Mackey
96268dcae6 Fix dtype bug in jax.scipy.fft.idct 2024-09-25 12:55:43 -04:00
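A minimal check of the kind this fix enables, assuming the bug was a mismatched result dtype (the commit message does not spell it out):
```py
import jax.numpy as jnp
from jax.scipy import fft

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
assert fft.idct(x).dtype == jnp.float32  # result dtype matches the input
```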
Sergei Lebedev
b49d8b2615 Fixed pl.debug_printing of scalar fragmented arrays under Mosaic GPU
PiperOrigin-RevId: 678726245
2024-09-25 09:10:48 -07:00
Peter Hawkins
1949413739 Increase sharding of checkify_test on TPU to fix CI flakes.
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Sergei Lebedev
a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed the element type was floating-point and used addf.

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
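A sketch of the dispatch the fix implies, branching on the MLIR element type; illustrative only, not the actual diff:
```py
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith

def _add(elem_ty: ir.Type, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
  # Integer element types must use addi; the old code unconditionally used addf.
  if ir.IntegerType.isinstance(elem_ty):
    return arith.addi(lhs, rhs)
  return arith.addf(lhs, rhs)
```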