2139 Commits

Author SHA1 Message Date
Matthew Johnson
26f70c9c16 remove busted example from shmap jep 2024-11-01 16:37:46 +00:00
jax authors
5a3ed6c792 Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass
PiperOrigin-RevId: 691984161
2024-10-31 17:16:31 -07:00
Emily Fertig
467bd09f03 Add a register_dataclass example to the pytree tutorial. 2024-10-31 16:26:42 -07:00
Dan Foreman-Mackey
ce8dba98fb Move the CUDA end-to-end example to FFI examples workflow + hosted
runner.
2024-10-31 12:21:51 -04:00
Sergei Lebedev
85662f6dd8 [pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM
Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Jake VanderPlas
abf14323dc Adjust copyright notice.
Previously we had been pulling-in NumPy and SciPy docs at runtime, but
after the work in #21461 this is no longer the case.
2024-10-28 18:53:38 -07:00
minigoel
68428488c8 Add a link to Intel plugin for JAX 2024-10-28 10:47:59 -07:00
Sergei Lebedev
dfa6fcd56b [pallas:mosaic_gpu] Extracted a basic emit_pipeline API from the in kernel pipelining test
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
jax authors
6e06110e1e Merge pull request #24538 from jakevdp:cumulative-prod
PiperOrigin-RevId: 690606656
2024-10-28 07:45:15 -07:00
Jim Lin
e4eca9ec59 #jax Adds a missing comma to Pallas Quickstart
PiperOrigin-RevId: 689907976
2024-10-25 14:14:11 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
jax authors
3b42a6b413 Merge pull request #24391 from keshavb96:remat_documentation
PiperOrigin-RevId: 689888674
2024-10-25 13:13:01 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
jax authors
8c9dc21e30 Update hermetic CUDA docs.
PiperOrigin-RevId: 689463215
2024-10-24 11:51:02 -07:00
jax authors
7ad73e44ce Merge pull request #24446 from gnecula:export_doc
PiperOrigin-RevId: 688886756
2024-10-23 02:50:57 -07:00
Peter Hawkins
e4f3f8f064 Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
jax authors
a2e4aff897 Merge pull request #24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
George Necula
7ed65c89a4 [docs] Added two new APIs to the export API docs 2024-10-22 09:11:56 +02:00
Hernan Moraldo
5d3cac6603 Fix documentation.
PiperOrigin-RevId: 688293390
2024-10-21 15:29:59 -07:00
jax authors
11eeff072f Merge pull request #22410 from garymm:patch-1
PiperOrigin-RevId: 688265373
2024-10-21 14:05:22 -07:00
Justin Fu
0b46a236c1 Update Pallas distributed tutorials with jax.make_mesh 2024-10-21 12:49:56 -07:00
Keshav Balasubramanian
5750766898 more detail 2024-10-21 12:08:58 -07:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
Gary Miguel
dc908b4843 Update installation instructions
Apple GPUs and Mac x86_64 is a non-existent combination.
Mac x86_64 with AMD GPU is supported.

It's a bit of a confusing situation so hard to summarize, but hopefully this is more accurate and less confusing

Fixes: #24408
2024-10-21 10:20:09 -07:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
Yash Katariya
ca2d1584f8 Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
Keshav Balasubramanian
2789b0d4db minor change 2024-10-18 11:45:27 -07:00
keshavb96
5b8e4db855 document jax config to disable remat HLO pass 2024-10-18 11:44:51 -07:00
jax authors
919f7c8684 Merge pull request #24345 from phu0ngng:cuda_custom_call
PiperOrigin-RevId: 687034466
2024-10-17 13:57:15 -07:00
Sergei Lebedev
de7beb91a7 [pallas:mosaic_gpu] Added layout_cast
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
jax authors
3bdc57dd29 Merge pull request #24300 from ROCm:ci_rocm_readme
PiperOrigin-RevId: 686872994
2024-10-17 05:21:13 -07:00
George Necula
9aa79bffba [export] Fix github links in the export documentation
Reflects the repo change google/jax -> jax-ml/jax.
Also changes the error message to put the link to the documentation
in a more visible place.
2024-10-17 08:30:28 +01:00
Jake VanderPlas
e1f280c843 CI: enable additional ruff formatting checks 2024-10-16 16:09:54 -07:00
Ruturaj4
3c3b08dfd6 [ROCm] Fix README.md to update AMD JAX installation instructions 2024-10-16 17:15:32 -05:00
jax authors
089e4aa904 Merge pull request #24341 from phu0ngng:cuda_graph_ex
PiperOrigin-RevId: 686577115
2024-10-16 11:23:28 -07:00
jax authors
ead1c05ada Merge pull request #23831 from nouiz:doc_policies
PiperOrigin-RevId: 686576725
2024-10-16 11:21:41 -07:00
Phuong Nguyen
d4bbb4fd84 added cmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:37:49 -07:00
Phuong Nguyen
82113cd047 rm CmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:27:09 -07:00
Phuong Nguyen
f3775aa233 added cudaGraph traits + use register_ffi_target()
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:01:20 -07:00
Sergei Lebedev
4c0d82824f [pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Praveen Batra
3a3190fbce Fix typo in Pallas TPU matmul doc. I think the logical layout of the input array is non-transposed, rather than transposed?
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
Yash Katariya
a2973be051 Don't add mhlo.layout_mode = "default" since that is the default even in PJRT and will help reduce cruft in the IR
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
Frédéric Bastien
e9011940d8
Update docs/gradient-checkpointing.md
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-10-11 12:33:10 -04:00
jax authors
bc3df0e3f5 Merge pull request #24241 from hawkinsp:autodidax
PiperOrigin-RevId: 684811631
2024-10-11 06:08:32 -07:00
Peter Hawkins
c0efa86bdc Port autodidax to use StableHLO instead of classic HLO. 2024-10-11 08:25:05 -04:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00