Matthew Johnson
26f70c9c16
remove busted example from shmap jep
2024-11-01 16:37:46 +00:00
jax authors
5a3ed6c792
Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass
...
PiperOrigin-RevId: 691984161
2024-10-31 17:16:31 -07:00
Emily Fertig
467bd09f03
Add a register_dataclass example to the pytree tutorial.
2024-10-31 16:26:42 -07:00
Dan Foreman-Mackey
ce8dba98fb
Move the CUDA end-to-end example to FFI examples workflow + hosted
...
runner.
2024-10-31 12:21:51 -04:00
Sergei Lebedev
85662f6dd8
[pallas:mosaic_gpu] plgpu.copy_smem_to_gmem
no longer transparently commits SMEM
...
Users are expected to call `pltpu.commit_smem` manually instead.
PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Jake VanderPlas
abf14323dc
Adjust copyright notice.
...
Previously we had been pulling-in NumPy and SciPy docs at runtime, but
after the work in #21461 this is no longer the case.
2024-10-28 18:53:38 -07:00
minigoel
68428488c8
Add a link to Intel plugin for JAX
2024-10-28 10:47:59 -07:00
Sergei Lebedev
dfa6fcd56b
[pallas:mosaic_gpu] Extracted a basic emit_pipeline
API from the in kernel pipelining test
...
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
jax authors
6e06110e1e
Merge pull request #24538 from jakevdp:cumulative-prod
...
PiperOrigin-RevId: 690606656
2024-10-28 07:45:15 -07:00
Jim Lin
e4eca9ec59
#jax Adds a missing comma to Pallas Quickstart
...
PiperOrigin-RevId: 689907976
2024-10-25 14:14:11 -07:00
Jake VanderPlas
02daf75f97
Add new jnp.cumulative_prod function.
...
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
jax authors
3b42a6b413
Merge pull request #24391 from keshavb96:remat_documentation
...
PiperOrigin-RevId: 689888674
2024-10-25 13:13:01 -07:00
Sergei Lebedev
5a2128e44b
[pallas] Removed deprecated aliases to CostEstimate
and run_scoped
...
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
jax authors
8c9dc21e30
Update hermetic CUDA docs.
...
PiperOrigin-RevId: 689463215
2024-10-24 11:51:02 -07:00
jax authors
7ad73e44ce
Merge pull request #24446 from gnecula:export_doc
...
PiperOrigin-RevId: 688886756
2024-10-23 02:50:57 -07:00
Peter Hawkins
e4f3f8f064
Use libtpu releases rather than libtpu-nightly for jax[tpu].
...
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
jax authors
a2e4aff897
Merge pull request #24425 from dfm:rename-vmap-methods
...
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
George Necula
7ed65c89a4
[docs] Added two new APIs to the export API docs
2024-10-22 09:11:56 +02:00
Hernan Moraldo
5d3cac6603
Fix documentation.
...
PiperOrigin-RevId: 688293390
2024-10-21 15:29:59 -07:00
jax authors
11eeff072f
Merge pull request #22410 from garymm:patch-1
...
PiperOrigin-RevId: 688265373
2024-10-21 14:05:22 -07:00
Justin Fu
0b46a236c1
Update Pallas distributed tutorials with jax.make_mesh
2024-10-21 12:49:56 -07:00
Keshav Balasubramanian
5750766898
more detail
2024-10-21 12:08:58 -07:00
Dan Foreman-Mackey
61701af4a2
Rename vmap methods for callbacks.
2024-10-21 15:03:04 -04:00
Gary Miguel
dc908b4843
Update installation instructions
...
Apple GPUs and Mac x86_64 is a non-existent combination.
Mac x86_64 with AMD GPU is supported.
It's a bit of a confusing situation so hard to summarize, but hopefully this is more accurate and less confusing
Fixes : #24408
2024-10-21 10:20:09 -07:00
Dan Foreman-Mackey
0b651f0f45
Make ffi_call return a callable
2024-10-21 12:16:57 -04:00
Yash Katariya
ca2d1584f8
Remove mesh_utils.create_device_mesh
from docs
...
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
Keshav Balasubramanian
2789b0d4db
minor change
2024-10-18 11:45:27 -07:00
keshavb96
5b8e4db855
document jax config to disable remat HLO pass
2024-10-18 11:44:51 -07:00
jax authors
919f7c8684
Merge pull request #24345 from phu0ngng:cuda_custom_call
...
PiperOrigin-RevId: 687034466
2024-10-17 13:57:15 -07:00
Sergei Lebedev
de7beb91a7
[pallas:mosaic_gpu] Added layout_cast
...
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
jax authors
3bdc57dd29
Merge pull request #24300 from ROCm:ci_rocm_readme
...
PiperOrigin-RevId: 686872994
2024-10-17 05:21:13 -07:00
George Necula
9aa79bffba
[export] Fix github links in the export documentation
...
Reflects the repo change google/jax -> jax-ml/jax.
Also changes the error message to put the link to the documentation
in a more visible place.
2024-10-17 08:30:28 +01:00
Jake VanderPlas
e1f280c843
CI: enable additional ruff formatting checks
2024-10-16 16:09:54 -07:00
Ruturaj4
3c3b08dfd6
[ROCm] Fix README.md to update AMD JAX installation instructions
2024-10-16 17:15:32 -05:00
jax authors
089e4aa904
Merge pull request #24341 from phu0ngng:cuda_graph_ex
...
PiperOrigin-RevId: 686577115
2024-10-16 11:23:28 -07:00
jax authors
ead1c05ada
Merge pull request #23831 from nouiz:doc_policies
...
PiperOrigin-RevId: 686576725
2024-10-16 11:21:41 -07:00
Phuong Nguyen
d4bbb4fd84
added cmdBuffer traits
...
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:37:49 -07:00
Phuong Nguyen
82113cd047
rm CmdBuffer traits
...
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:27:09 -07:00
Phuong Nguyen
f3775aa233
added cudaGraph traits + use register_ffi_target()
...
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:01:20 -07:00
Sergei Lebedev
4c0d82824f
[pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
...
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
Yash Katariya
66c6292e6a
Make committed a public property of jax.Array.
...
Why?
Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.
PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Praveen Batra
3a3190fbce
Fix typo in Pallas TPU matmul doc. I think the logical layout of the input array is non-transposed, rather than transposed?
...
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
Yash Katariya
a2973be051
Don't add mhlo.layout_mode = "default"
since that is the default even in PJRT and will help reduce cruft in the IR
...
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824
[Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
...
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Peter Hawkins
46f0a3eee7
Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
...
Change in preparation for removing HLO ops from the XLA Python bindings.
In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.
PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
Frédéric Bastien
e9011940d8
Update docs/gradient-checkpointing.md
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-10-11 12:33:10 -04:00
jax authors
bc3df0e3f5
Merge pull request #24241 from hawkinsp:autodidax
...
PiperOrigin-RevId: 684811631
2024-10-11 06:08:32 -07:00
Peter Hawkins
c0efa86bdc
Port autodidax to use StableHLO instead of classic HLO.
2024-10-11 08:25:05 -04:00
Sergei Lebedev
acd0e497af
[pallas:mosaic_gpu] GPUBlockSpec
no longer accepts swizzle
...
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.
PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Peter Hawkins
94abaf430e
Add lax.FftType.
...
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00