Jake VanderPlas
c73f306099
Finalize deprecation of jnp.round_
...
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00
jax authors
078c7e4444
Update XLA dependency to use revision
...
cd4b4e1ba3
.
PiperOrigin-RevId: 705976057
2024-12-13 13:00:37 -08:00
jax authors
a9c140d8c9
Merge pull request #24536 from MichaelHudgins:main
...
PiperOrigin-RevId: 705964888
2024-12-13 12:23:17 -08:00
Michael Hudgins
4de58e1af7
Merge branch 'jax-ml:main' into main
2024-12-13 14:30:30 -05:00
jax authors
99b390ce96
Merge pull request #25444 from justinjfu:gpu_docs_update
...
PiperOrigin-RevId: 705938411
2024-12-13 11:05:31 -08:00
jax authors
87b66f3c35
Merge pull request #25451 from jakevdp:undep-core
...
PiperOrigin-RevId: 705910242
2024-12-13 09:36:32 -08:00
Yash Katariya
80cf141863
Set device_assignment to None when only AbstractMesh exist in the computation
...
PiperOrigin-RevId: 705899088
2024-12-13 08:55:56 -08:00
Peter Hawkins
64eae324ee
Migrate JAX MLIR Python dialect extensions to nanobind.
...
Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.
PiperOrigin-RevId: 705872751
2024-12-13 07:08:28 -08:00
jax authors
5a3fa500b5
Merge pull request #25459 from hawkinsp:sort
...
PiperOrigin-RevId: 705869484
2024-12-13 06:55:32 -08:00
jax authors
6926356330
Merge pull request #25454 from sfvaroglu:sevin/slop_factor_doc
...
PiperOrigin-RevId: 705867485
2024-12-13 06:47:09 -08:00
Peter Hawkins
0922feb2f5
Use a broadcasted gather in the sort JVP, rather than forming explicit iotas.
...
Use an unsigned index and promise that it is in bounds.
2024-12-13 09:23:34 -05:00
jax authors
a123d4e39e
Remove autotune sharing.
...
xla_gpu_shard_autotuning can be used now instead and it is enabled by default.
PiperOrigin-RevId: 705792463
2024-12-13 01:22:27 -08:00
Yash Katariya
d0f63da4b5
Allow tracing and lowering (with lowering_platforms specified) to work with an AbstractMesh. Such a computation cannot be compiled.
...
This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.
PiperOrigin-RevId: 705764178
2024-12-12 23:17:27 -08:00
Parker Schuh
0e7f218eb0
Support axis_index inside shard_map(auto=...) by using iota and
...
then calling full_to_shard.
PiperOrigin-RevId: 705704369
2024-12-12 18:39:05 -08:00
Yash Katariya
1453a222d4
Add dynamic_arg_layouts to C++ cache and add a test in JAX which checks for cache miss if layouts of inputs arguments are different to the same jitted function.
...
PiperOrigin-RevId: 705703520
2024-12-12 18:35:29 -08:00
Ivy Zheng
ef06607735
Implement flatten one level with keys in C++ and use it for the prefix/equality error printing.
...
With this, we should be able to safely delete the python with-path registry after a new jaxlib release.
Also changed all `std::string_view` to `absl::string_view` per requirements of TF repository.
PiperOrigin-RevId: 705669465
2024-12-12 16:37:09 -08:00
jax authors
eb3ea985b7
Merge pull request #25452 from jakevdp:lax-abstractify
...
PiperOrigin-RevId: 705652650
2024-12-12 15:39:33 -08:00
Sevin Varoglu
c563e47314
Add flag desc to gpu_performance_tips.md
2024-12-12 15:39:07 -08:00
Ivy Zheng
26c40fadfd
Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
...
PiperOrigin-RevId: 705645570
2024-12-12 15:13:32 -08:00
Justin Fu
1021603f85
Remove deprecated XLA GPU flags.
2024-12-12 15:10:43 -08:00
Nitin Srinivasan
ecc2673e7b
Disable failing test cases when JAX_ENABLE_X64=1
in the Bazel CPU build
...
PiperOrigin-RevId: 705635799
2024-12-12 14:41:52 -08:00
Sergei Lebedev
a14e6968bf
[mosaic] Migrated the serialization pass from codegen to pass_boilerplate.h
...
This prepares teh generalization of the serialization pass to handle both
Mosaic TPU and GPU.
PiperOrigin-RevId: 705628923
2024-12-12 14:19:36 -08:00
Jake VanderPlas
67b3413b96
Cleanup: replace lax._abstractify with core.get_aval
2024-12-12 14:08:17 -08:00
Jake VanderPlas
d3406768f0
temporarily un-deprecate several jax.core APIs.
...
These were causing excessive log-spam for some users; I'll work to migrate
them to jax.extend before re-deprecating these.
2024-12-12 13:15:58 -08:00
jax authors
97459ba9aa
Update XLA dependency to use revision
...
fb8e7d579f
.
PiperOrigin-RevId: 705602301
2024-12-12 12:55:31 -08:00
Peter Hawkins
6548caf239
Relax test tolerance for complex128 pow in lax_test.py.
...
This is failing in CI in some CPU configurations.
PiperOrigin-RevId: 705558897
2024-12-12 10:50:44 -08:00
jax authors
ea63aeab01
Merge pull request #25442 from jakevdp:raise-to-shaped
...
PiperOrigin-RevId: 705556199
2024-12-12 10:43:17 -08:00
Fiona Lang
3f58337bbc
Fix jax.core deprecation warnings for jax.extend.core.primitives symbols.
...
PiperOrigin-RevId: 705546724
2024-12-12 10:16:37 -08:00
Jevin Jiang
3ff5706051
[Mosaic TPU][NFC] Create local namespace to prevent function name duplication error under global namespace mlir::tpu
...
PiperOrigin-RevId: 705538965
2024-12-12 09:53:39 -08:00
jax authors
3c649b134a
Merge pull request #25409 from gnecula:poly_mod
...
PiperOrigin-RevId: 705537824
2024-12-12 09:50:28 -08:00
Jake VanderPlas
40367a9eaf
Cleanup: remove uses of no-op raise_to_shaped
2024-12-12 09:49:06 -08:00
jax authors
99d675ac25
Merge pull request #25426 from hawkinsp:tls
...
PiperOrigin-RevId: 705510277
2024-12-12 08:15:34 -08:00
Peter Hawkins
62e66b684b
Don't monkey-patch functions in test_utils to count events for tests.
...
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Adam Paszke
3630756e87
[Mosaic GPU] Use events as the default profiling method
...
JAX still supports old CUDA versions such as 12.0, where CUPTI leaks memory.
PiperOrigin-RevId: 705459909
2024-12-12 04:42:56 -08:00
George Necula
27b024b240
[shape_poly] Improve handling of mod(e, k) == 0 constraints.
...
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.
Previously, the constraint `mod(e, k) == 0` was being useful
only to normalize away `mod(e, k)`. In particular it was not
useful for proving `k * floordiv(e, k)`. Now we add that
features.
2024-12-12 10:31:02 +01:00
jax authors
dda6b88864
Merge pull request #25425 from jax-ml:linearize-bugs-and-stuff
...
PiperOrigin-RevId: 705313000
2024-12-11 18:27:27 -08:00
Tzu-Wei Sung
21f6b401dd
[Mosaic] Pad trailing transposes chunks with zeros.
...
PiperOrigin-RevId: 705310340
2024-12-11 18:20:05 -08:00
Dougal
8fe8d241e8
Fixes to direct linearize
...
* Fix a bug in pjit linearization rule
* Handle multiple results and zeros in fallback rule
* Handle `has_aux`
* Implement process_custom_vjp_call
2024-12-11 20:57:36 -05:00
Yash Katariya
39e4f7f2ce
[sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
...
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
jax authors
ccfef7a549
Merge pull request #25424 from jakevdp:dedupe-broadcast
...
PiperOrigin-RevId: 705261094
2024-12-11 15:25:02 -08:00
Jake VanderPlas
c40780b957
internal: dedupe lax broadcasting logic
2024-12-11 15:03:39 -08:00
Gleb Pobudzey
e92ca9bbae
Use boolean values for partial mask blocks in the splash attention kernel.
...
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.
PiperOrigin-RevId: 705252534
2024-12-11 14:59:30 -08:00
jax authors
b7af1eb905
Merge pull request #25381 from jakevdp:mypy-np22
...
PiperOrigin-RevId: 705248189
2024-12-11 14:47:37 -08:00
jax authors
e55bbc778a
Merge pull request #25422 from jakevdp:broadcast-rank
...
PiperOrigin-RevId: 705245013
2024-12-11 14:38:24 -08:00
Jake VanderPlas
f4f4bf6a19
Fix type annotations for NumPy 2.2
2024-12-11 14:24:58 -08:00
jax authors
fb53971802
Merge pull request #25419 from jakevdp:lax-dtypes
...
PiperOrigin-RevId: 705230631
2024-12-11 13:59:38 -08:00
Jake VanderPlas
76d8b9c5a4
internal: simplify broadcast_shapes logic
2024-12-11 13:50:20 -08:00
jax authors
b8d2e9383a
Update XLA dependency to use revision
...
209cbfa31a
.
PiperOrigin-RevId: 705215149
2024-12-11 13:16:26 -08:00
jax authors
5e887b446b
Merge pull request #25414 from jakevdp:finalize-deps
...
PiperOrigin-RevId: 705197214
2024-12-11 12:24:13 -08:00
jax authors
8c4d3db99a
Merge pull request #23225 from snadampal:aarch64_jax
...
PiperOrigin-RevId: 705196103
2024-12-11 12:20:32 -08:00