23755 Commits

Author SHA1 Message Date
rajasekharporeddy
7e6fa3ed28 Improve docs for jax.numpy: sinh, cosh and tanh 2024-09-26 23:27:23 +05:30
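As a quick illustration of the three functions this commit documents (not taken from the new docs themselves):

```py
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 5)
# tanh is the ratio of the other two hyperbolic functions.
assert jnp.allclose(jnp.tanh(x), jnp.sinh(x) / jnp.cosh(x))
```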
jax authors
6f7ad641d7 Merge pull request #23940 from jakevdp:jacobian-doc
PiperOrigin-RevId: 679203936
2024-09-26 10:34:25 -07:00
jax authors
96cf2b81e6 Merge pull request #23921 from rajasekharporeddy:testbranch4
PiperOrigin-RevId: 679203931
2024-09-26 10:32:44 -07:00
Peter Hawkins
0e082f978b Deprecate jax.lib.xla_client.Device.
jax.Device is a longstanding public name for this class.
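A minimal sketch of the migration this implies (the alias relationship is per this message):

```py
import jax

# Preferred: the long-standing public name.
d: jax.Device = jax.devices()[0]

# Deprecated: reaching the same class through the non-public path
# jax.lib.xla_client.Device.
```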

PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
jax authors
140a8c70b4 Update XLA dependency to use revision 0e732d65bd.

PiperOrigin-RevId: 679196598
2024-09-26 10:13:46 -07:00
Adam Paszke
dd2ee8c7b2 [Pallas/MGPU] Skip outgoing TMA when the output is being revisited
Otherwise we end up with programs that race on writes to the same GMEM location.

PiperOrigin-RevId: 679189227
2024-09-26 09:54:34 -07:00
Adam Paszke
076287fb5c [Pallas/MGPU] Implement block spec evaluation correctly
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).
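
For readers unfamiliar with block specs, a minimal sketch of what "evaluating" one means, assuming the usual Pallas convention that the index map returns block indices which are scaled by the block shape (the `block_offsets` helper is hypothetical, for illustration only):

```py
def block_offsets(index_map, block_shape, grid_indices):
    # index_map maps grid indices to block indices; scaling by the
    # block shape yields the element offsets of the selected window.
    block_indices = index_map(*grid_indices)
    return tuple(i * b for i, b in zip(block_indices, block_shape))

# With block_shape (64, 64) and index_map lambda i, j: (i, j),
# grid step (2, 3) selects the window starting at element (128, 192).
assert block_offsets(lambda i, j: (i, j), (64, 64), (2, 3)) == (128, 192)
```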

PiperOrigin-RevId: 679175212
2024-09-26 09:15:12 -07:00
Bart Chrzaszcz
a3284bd8a3 #sdy Add CPU targets in JAX.
PiperOrigin-RevId: 679174535
2024-09-26 09:13:34 -07:00
rajasekharporeddy
6072f97961 Raise ValueError when axis1==axis2 for jnp.trace 2024-09-26 21:38:14 +05:30
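A quick illustration of the new check:

```py
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
jnp.trace(x, axis1=0, axis2=1)  # fine: distinct axes
jnp.trace(x, axis1=0, axis2=0)  # now raises ValueError
```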
Bart Chrzaszcz
e62a50cd34 #sdy Add JAX Shardy support for shard_map.
For example the following JAX program:
```py
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(fwd.lower(a).as_text())
```

prints:

```mlir
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.
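
A minimal sketch of the intended usage (no attempt to trigger a real runtime failure here; `run` is a hypothetical wrapper):

```py
import jax

def run(f, *args):
    # Catch runtime failures via the public alias rather than the
    # non-public jax.lib.xla_client.XlaRuntimeError.
    try:
        return f(*args)
    except jax.errors.JaxRuntimeError as e:
        raise RuntimeError(f"XLA runtime failure: {e}") from e
```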

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Sergei Lebedev
5cef547eab Added support for lax.cond_p to Pallas Mosaic GPU lowering
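Schematically, the kind of kernel body this enables (a hedged sketch; interpret mode is used so it runs anywhere, while the commit itself concerns the Mosaic GPU lowering of the same pattern):

```py
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
    # lax.cond inside the kernel body lowers via lax.cond_p.
    o_ref[...] = jax.lax.cond(
        x_ref[0] > 0,
        lambda: x_ref[...] * 2.0,
        lambda: x_ref[...] + 1.0,
    )

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), interpret=True
)(x)
```
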
PiperOrigin-RevId: 679156819
2024-09-26 08:20:53 -07:00
Adam Paszke
0a66e2d0a4 [Pallas/MGPU] Fix a race in the pipelining code
We never checked whether the output windows were done being written before we reused them.
Also, rename num_stages to max_concurrent_steps, since we always have only 2 stages
but might be running multiple iterations at a time.

Also fix the test for this, which had been passing for reasons I don't understand
(it didn't even write to all entries in the output??).

PiperOrigin-RevId: 679148961
2024-09-26 07:57:54 -07:00
Adam Paszke
8599dbc9b2 [Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing
I annotated a number of issues in the test. To make the test run I also needed to add support
for the accumulator reference allocation and discharge in the main lowering part. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...

PiperOrigin-RevId: 679143948
2024-09-26 07:40:15 -07:00
jax authors
c07652fd46 Merge pull request #23927 from jakevdp:pad-doc
PiperOrigin-RevId: 679142773
2024-09-26 07:37:03 -07:00
Adam Paszke
57887732be [Pallas/Mosaic GPU] Disable inference of sequential axis shapes
They should just be specified in the grid, so we don't need to do this. It's
also incorrect, because it's not guaranteed that each input is sliced in the
same dimension by the sequential axis.

PiperOrigin-RevId: 679114626
2024-09-26 05:53:15 -07:00
jax authors
a6b4648e4d Merge pull request #23928 from nvcastet:fix_mosaic_barrier
PiperOrigin-RevId: 679114112
2024-09-26 05:51:25 -07:00
Jake VanderPlas
cf51ee7ef0 Improve documentation for jax.jacobian 2024-09-26 05:09:47 -07:00
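For context, a standard usage example of the function being documented (not taken from the new docs):

```py
import jax
import jax.numpy as jnp

f = lambda x: jnp.array([x[0] ** 2, x[0] * x[1]])
J = jax.jacobian(f)(jnp.array([2.0, 3.0]))
# J[i, j] = d f_i / d x_j:
# [[4., 0.],
#  [3., 2.]]
```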
Adam Paszke
3c25da2c59 [Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list
PiperOrigin-RevId: 679079269
2024-09-26 03:41:22 -07:00
Christos Perivolaropoulos
b6d668e0d7 [pallas::mosaic_gpu] Turn the accumulator into a reference
* Changes the accumulator into a reference.
* Creates a discharged flavor of the wgmma op.
* run_scoped lowering discharges the input jaxpr.
* Dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged.
* The deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.

PiperOrigin-RevId: 679056765
2024-09-26 02:18:27 -07:00
jax authors
f6fdfb4518 Merge pull request #23899 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 679039498
2024-09-26 01:22:24 -07:00
rajasekharporeddy
8ffeb2388a Better doc for jnp.trace 2024-09-26 09:39:18 +05:30
jax authors
911acf1bbf Merge pull request #23924 from jburnim:jburnim_doc_typo_fix
PiperOrigin-RevId: 678896733
2024-09-25 16:50:17 -07:00
Jevin Jiang
e4ca4f5a57 Roll back cl/678765762 [Mosaic TPU] Support bitcast without forcing retiling.
Reverts 37641dd4fade625563321b7e1e87165df23cf4a8

PiperOrigin-RevId: 678881199
2024-09-25 16:02:58 -07:00
Abhinav Goel
ec7a7791aa Merge branch 'main' into patch-4 2024-09-25 15:46:50 -07:00
Abhinav Goel
b5bb30329d changed default 2024-09-25 15:43:59 -07:00
jax authors
f1b3251bf9 Change CLANG_CUDA_COMPILER_PATH set order. Add --config=cuda_clang to build.py
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after the cuda_nvcc configuration.
Add `--config=cuda_clang` when the `--nouse_cuda_nvcc` flag is set.

PiperOrigin-RevId: 678873849
2024-09-25 15:39:44 -07:00
Jake VanderPlas
ad6c3a7f64 Improve docs for jnp.pad 2024-09-25 14:41:13 -07:00
Nicolas Castet
08629a4233 [Mosaic GPU] Fix mbarrier inline ptx for newer CTK 2024-09-25 16:39:02 -05:00
Jacob Burnim
a1f2edc968 Fix make_remote_async_copy -> make_async_remote_copy in async doc. 2024-09-25 13:39:39 -07:00
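For context, the corrected name in use (a schematic sketch of the Pallas TPU API; the semaphore and device-id wiring here is illustrative, not taken from the doc being fixed):

```py
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, y_ref, send_sem, recv_sem):
    # The doc previously misspelled this as make_remote_async_copy.
    copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=y_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,  # hypothetical neighbor device
        device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    copy.start()
    copy.wait()
```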
jax authors
ce99c18a74 Remove CC="/usr/lib/llvm-18/bin/clang" from clang config in .bazelrc
Restore `cuda_clang` config in .bazelrc

PiperOrigin-RevId: 678828039
2024-09-25 13:35:01 -07:00
jax authors
5d4cae07c8 Merge pull request #23916 from jakevdp:interp-doc
PiperOrigin-RevId: 678784484
2024-09-25 11:37:15 -07:00
jax authors
70346bda74 [Pallas] Add scalar f32 downcast test cases.
PiperOrigin-RevId: 678779025
2024-09-25 11:25:59 -07:00
Jake VanderPlas
e05c37c667 Finalize deprecation of pretty-printing utils in jax.core.pp_*
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Jevin Jiang
37641dd4fa [Mosaic TPU] Support bitcast without forcing retiling.
PiperOrigin-RevId: 678765762
2024-09-25 10:57:09 -07:00
jax authors
c93b272b78 Update XLA dependency to use revision a473d30392.

PiperOrigin-RevId: 678764121
2024-09-25 10:53:22 -07:00
Jake VanderPlas
ee6fd5aeb2 Improve documentation for jnp.interp 2024-09-25 10:47:46 -07:00
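For context, basic behavior of the function being documented (standard NumPy-compatible semantics):

```py
import jax.numpy as jnp

xp = jnp.array([0.0, 1.0, 2.0])
fp = jnp.array([0.0, 10.0, 20.0])
jnp.interp(jnp.array([0.5, 1.5]), xp, fp)  # -> [ 5., 15.]
```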
jax authors
0f84c2c6be Merge pull request #23917 from dfm:gh23895
PiperOrigin-RevId: 678759331
2024-09-25 10:41:44 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
jax authors
a20f79d88b Merge pull request #23874 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 678754449
2024-09-25 10:28:48 -07:00
jax authors
f126705dd0 Merge pull request #23914 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 678752363
2024-09-25 10:26:32 -07:00
jax authors
8806e0b697 Merge pull request #23382 from jax-ml:dependabot/github_actions/actions/setup-python-5.2.0
PiperOrigin-RevId: 678752022
2024-09-25 10:24:34 -07:00
jax authors
8edc0fca5b Merge pull request #23910 from 8bitmp3:patch-3
PiperOrigin-RevId: 678751882
2024-09-25 10:22:46 -07:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Dan Foreman-Mackey
96268dcae6 Fix dtype bug in jax.scipy.fft.idct 2024-09-25 12:55:43 -04:00
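Assuming the bug concerned the result dtype (the message says only "dtype bug"), a minimal check of the expected behavior after the fix:

```py
import jax.numpy as jnp
from jax.scipy import fft

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
y = fft.idct(x)
assert y.dtype == jnp.float32  # result dtype follows the input
```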
Sergei Lebedev
b49d8b2615 Fixed pl.debug_print for scalar fragmented arrays under Mosaic GPU
PiperOrigin-RevId: 678726245
2024-09-25 09:10:48 -07:00
rajasekharporeddy
13774d1382 Fix Typos 2024-09-25 21:26:05 +05:30
Peter Hawkins
1949413739 Increase sharding of checkify_test on TPU to fix CI flakes.
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Sergei Lebedev
a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed the type was floating-point and used addf.

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
8bitmp3
60a06fd4c9 Update pillow version in JAX build test-requirements.txt 2024-09-25 14:55:46 +00:00