rajasekharporeddy
7e6fa3ed28
Improve docs for jax.numpy: sinh, cosh and tanh
2024-09-26 23:27:23 +05:30
jax authors
6f7ad641d7
Merge pull request #23940 from jakevdp:jacobian-doc
...
PiperOrigin-RevId: 679203936
2024-09-26 10:34:25 -07:00
jax authors
96cf2b81e6
Merge pull request #23921 from rajasekharporeddy:testbranch4
...
PiperOrigin-RevId: 679203931
2024-09-26 10:32:44 -07:00
Peter Hawkins
0e082f978b
Deprecate jax.lib.xla_client.Device.
...
jax.Device is a longstanding public name for this class.
PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
jax authors
140a8c70b4
Update XLA dependency to use revision
...
0e732d65bd.
PiperOrigin-RevId: 679196598
2024-09-26 10:13:46 -07:00
Adam Paszke
dd2ee8c7b2
[Pallas/MGPU] Skip outgoing TMA when the output is being revisited
...
Otherwise we end up with programs that race on writes to the same GMEM location.
PiperOrigin-RevId: 679189227
2024-09-26 09:54:34 -07:00
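A toy sketch of the idea behind this fix (hypothetical helper names, not the Pallas/MGPU implementation): when the grid revisits the same output block across steps, only the last visit should issue the outgoing copy, so concurrent steps never write the same GMEM region.

```python
# Hypothetical sketch: map each output block to the last grid step that
# visits it; the outgoing copy (TMA) is issued only on that step.
def last_visit(steps, index_map):
    last = {}
    for step in steps:
        last[index_map(step)] = step
    return last

def should_copy_out(step, index_map, last):
    # Skip the copy on all but the final visit to this output block.
    return last[index_map(step)] == step
```

For a grid iterating over (i, k) where the output block depends only on i, only the final k step for each i issues the copy.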
Adam Paszke
076287fb5c
[Pallas/MGPU] Implement block spec evaluation correctly
...
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).
PiperOrigin-RevId: 679175212
2024-09-26 09:15:12 -07:00
Bart Chrzaszcz
a3284bd8a3
#sdy Add CPU targets in JAX.
...
PiperOrigin-RevId: 679174535
2024-09-26 09:13:34 -07:00
rajasekharporeddy
6072f97961
Raise ValueError when axis1==axis2 for jnp.trace
2024-09-26 21:38:14 +05:30
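The validation this adds can be sketched roughly as follows (a hypothetical helper, not JAX's actual code): normalize negative axes first, then reject the case where both arguments refer to the same axis.

```python
# Hypothetical sketch of the axis check behind the ValueError: jnp.trace
# (like np.trace) takes axis1/axis2, which must name different axes after
# negative values are normalized.
def check_trace_axes(ndim, axis1, axis2):
    a1 = axis1 % ndim  # e.g. -1 -> ndim - 1
    a2 = axis2 % ndim
    if a1 == a2:
        raise ValueError(
            f"axis1 and axis2 must refer to different axes; got {axis1} and {axis2}")
    return a1, a2
```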
Bart Chrzaszcz
e62a50cd34
#sdy add JAX Shardy support for shard_map.
...
For example the following JAX program:
```py
import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```
prints:
```cpp
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <["x"=8]>
func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
return %0 : tensor<8x8xi32>
}
func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
%1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
sdy.return %1 : tensor<1x8xi32>
} : (tensor<8x8xi32>) -> tensor<8x8xi32>
return %0 : tensor<8x8xi32>
}
}
```
PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Peter Hawkins
7b53c2f39d
Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
...
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.
PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Sergei Lebedev
5cef547eab
Added support for lax.cond_p to Pallas Mosaic GPU lowering
...
PiperOrigin-RevId: 679156819
2024-09-26 08:20:53 -07:00
Adam Paszke
0a66e2d0a4
[Pallas/MGPU] Fix a race in the pipelining code
...
We never checked if the output windows are done writing before we reused them.
Also, rename num_stages to max_concurrent_steps since we always only have 2 stages,
but might be running multiple iterations at a time.
Also fix the test for this, which had been passing for unclear reasons
(it did not even write to all entries in the output).
PiperOrigin-RevId: 679148961
2024-09-26 07:57:54 -07:00
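The scheduling constraint described here can be modeled with a toy loop (an assumption-laden sketch, not the Pallas pipelining code): with max_concurrent_steps buffer slots, step i reuses the slot written at step i - max_concurrent_steps, so it must first wait for that earlier write to complete.

```python
# Toy model of the fix: emit an explicit wait for the earlier write to a
# slot before that slot is overwritten by a later pipeline step.
def schedule(num_steps, max_concurrent_steps):
    events = []
    for step in range(num_steps):
        slot = step % max_concurrent_steps
        prev = step - max_concurrent_steps
        if prev >= 0:
            # The previously missing synchronization point.
            events.append(("wait_write_done", slot, prev))
        events.append(("write", slot, step))
    return events
```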
Adam Paszke
8599dbc9b2
[Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing
...
I annotated a number of issues in the test. To make the test run I also needed to add support
for the accumulator reference allocation and discharge in the main lowering part. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...
PiperOrigin-RevId: 679143948
2024-09-26 07:40:15 -07:00
jax authors
c07652fd46
Merge pull request #23927 from jakevdp:pad-doc
...
PiperOrigin-RevId: 679142773
2024-09-26 07:37:03 -07:00
Adam Paszke
57887732be
[Pallas/Mosaic GPU] Disable inference of sequential axis shapes
...
They should just be specified in the grid, so we don't need to do this. It's
also incorrect, because it's not guaranteed that each input is sliced in the
same dimension by the sequential axis.
PiperOrigin-RevId: 679114626
2024-09-26 05:53:15 -07:00
jax authors
a6b4648e4d
Merge pull request #23928 from nvcastet:fix_mosaic_barrier
...
PiperOrigin-RevId: 679114112
2024-09-26 05:51:25 -07:00
Jake VanderPlas
cf51ee7ef0
Improve documentation for jax.jacobian
2024-09-26 05:09:47 -07:00
Adam Paszke
3c25da2c59
[Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list
...
PiperOrigin-RevId: 679079269
2024-09-26 03:41:22 -07:00
Christos Perivolaropoulos
b6d668e0d7
[pallas::mosaic_gpu] Turn the accumulator into a reference
...
* Changes the accumulator into a reference
* Creates a discharged flavor of the wgmma op
* run_scoped lowering discharges the input jaxpr
* dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged
* the deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.
PiperOrigin-RevId: 679056765
2024-09-26 02:18:27 -07:00
jax authors
f6fdfb4518
Merge pull request #23899 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 679039498
2024-09-26 01:22:24 -07:00
rajasekharporeddy
8ffeb2388a
Better doc for jnp.trace
2024-09-26 09:39:18 +05:30
jax authors
911acf1bbf
Merge pull request #23924 from jburnim:jburnim_doc_typo_fix
...
PiperOrigin-RevId: 678896733
2024-09-25 16:50:17 -07:00
Jevin Jiang
e4ca4f5a57
Roll back cl/678765762 [Mosaic TPU] Support bitcast without forcing retiling.
...
Reverts 37641dd4fade625563321b7e1e87165df23cf4a8
PiperOrigin-RevId: 678881199
2024-09-25 16:02:58 -07:00
Abhinav Goel
ec7a7791aa
Merge branch 'main' into patch-4
2024-09-25 15:46:50 -07:00
Abhinav Goel
b5bb30329d
changed default
2024-09-25 15:43:59 -07:00
jax authors
f1b3251bf9
Change CLANG_CUDA_COMPILER_PATH set order. Add --config=cuda_clang to build.py
...
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after cuda_nvcc configuration
Add `--config=cuda_clang` when `--nouse_cuda_nvcc` flag set
PiperOrigin-RevId: 678873849
2024-09-25 15:39:44 -07:00
Jake VanderPlas
ad6c3a7f64
Improve docs for jnp.pad
2024-09-25 14:41:13 -07:00
Nicolas Castet
08629a4233
[Mosaic GPU] Fix mbarrier inline ptx for newer CTK
2024-09-25 16:39:02 -05:00
Jacob Burnim
a1f2edc968
Fix make_remote_async_copy -> make_async_remote_copy in async doc.
2024-09-25 13:39:39 -07:00
jax authors
ce99c18a74
Remove CC="/usr/lib/llvm-18/bin/clang" from clang config in .bazelrc
...
Restore `cuda_clang` config in .bazelrc
PiperOrigin-RevId: 678828039
2024-09-25 13:35:01 -07:00
jax authors
5d4cae07c8
Merge pull request #23916 from jakevdp:interp-doc
...
PiperOrigin-RevId: 678784484
2024-09-25 11:37:15 -07:00
jax authors
70346bda74
[Pallas] Add scalar f32 downcast test cases.
...
PiperOrigin-RevId: 678779025
2024-09-25 11:25:59 -07:00
Jake VanderPlas
e05c37c667
Finalize deprecation of pretty-printing utils in jax.core.pp_*
...
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Jevin Jiang
37641dd4fa
[Mosaic TPU] Support bitcast without forcing retiling.
...
PiperOrigin-RevId: 678765762
2024-09-25 10:57:09 -07:00
jax authors
c93b272b78
Update XLA dependency to use revision
...
a473d30392.
PiperOrigin-RevId: 678764121
2024-09-25 10:53:22 -07:00
Jake VanderPlas
ee6fd5aeb2
Improve documentation for jnp.interp
2024-09-25 10:47:46 -07:00
jax authors
0f84c2c6be
Merge pull request #23917 from dfm:gh23895
...
PiperOrigin-RevId: 678759331
2024-09-25 10:41:44 -07:00
Tom Natan
6cf09f8c24
Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
...
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
jax authors
a20f79d88b
Merge pull request #23874 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 678754449
2024-09-25 10:28:48 -07:00
jax authors
f126705dd0
Merge pull request #23914 from rajasekharporeddy:testbranch3
...
PiperOrigin-RevId: 678752363
2024-09-25 10:26:32 -07:00
jax authors
8806e0b697
Merge pull request #23382 from jax-ml:dependabot/github_actions/actions/setup-python-5.2.0
...
PiperOrigin-RevId: 678752022
2024-09-25 10:24:34 -07:00
jax authors
8edc0fca5b
Merge pull request #23910 from 8bitmp3:patch-3
...
PiperOrigin-RevId: 678751882
2024-09-25 10:22:46 -07:00
Peter Hawkins
111f13e279
Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
...
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Dan Foreman-Mackey
96268dcae6
Fix dtype bug in jax.scipy.fft.idct
2024-09-25 12:55:43 -04:00
Sergei Lebedev
b49d8b2615
Fixed pl.debug_print-ing of scalar fragmented arrays under Mosaic GPU
...
PiperOrigin-RevId: 678726245
2024-09-25 09:10:48 -07:00
rajasekharporeddy
13774d1382
Fix Typos
2024-09-25 21:26:05 +05:30
Peter Hawkins
1949413739
Increase sharding of checkify_test on TPU to fix CI flakes.
...
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Sergei Lebedev
a373e37be2
Fixed mgpu.FragmentedArray.reduce_sum for integer types
...
The implementation previously assumed the element type was floating-point and used addf.
PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
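The class of bug is easy to illustrate with a sketch (illustrative only; the real lowering emits MLIR arith ops, not NumPy calls): the add operation must be chosen from the element type rather than hard-coded to the floating-point addf.

```python
import numpy as np

def reduce_sum(arr):
    # Pick the add op from the element type instead of always using addf.
    if np.issubdtype(arr.dtype, np.floating):
        op = "addf"  # floating-point add (arith.addf in MLIR)
    elif np.issubdtype(arr.dtype, np.integer):
        op = "addi"  # integer add (arith.addi)
    else:
        raise TypeError(f"unsupported dtype: {arr.dtype}")
    return op, arr.sum()
```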
8bitmp3
60a06fd4c9
Update pillow version in JAX build test-requirements.txt
2024-09-25 14:55:46 +00:00