23928 Commits

Author SHA1 Message Date
jax authors
5fd4ea9054 Merge pull request #24525 from gnecula:fix_changelog
PiperOrigin-RevId: 689766317
2024-10-25 06:26:26 -07:00
George Necula
c62b19883f Fix copy and paste error in CHANGELOG. 2024-10-25 16:11:35 +03:00
Dan Foreman-Mackey
33a46e8f68 Re-enable jax2tf test for dot algorithm with stricter TF version check. 2024-10-25 08:26:19 -04:00
Jake VanderPlas
8948e6de58 sharding cleanup: use inline checks for unimplemented and auto 2024-10-25 04:22:40 -07:00
Ruturaj4
bfd7075c39 [ROCm] ci build fixes 2024-10-25 05:01:44 -05:00
Peter Hawkins
bb5fbec64b [mosaic] Use .clone() to duplicate a module, rather than printing and parsing it.
PiperOrigin-RevId: 689708462
2024-10-25 02:32:49 -07:00
George Necula
9088adda68 [jax2tf] Disable jax2tf with non-native serialization.
Using jax2tf with native_serialization=False or with enable_xla=False has been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
George Necula
0bc70bbd73 Disable jax2tf test recently added in cl/688976685.
See failure: https://github.com/jax-ml/jax/actions/runs/11514933009/job/32054580529?pr=24183

PiperOrigin-RevId: 689703645
2024-10-25 02:12:34 -07:00
jax authors
3823612ebf Merge pull request #24505 from gnecula:jax2tf_bug
PiperOrigin-RevId: 689662727
2024-10-24 23:36:55 -07:00
Ayaka
5c614470ad [Pallas TPU] Add lowerings for scalar absf and rsqrt
This PR is similar to https://github.com/jax-ml/jax/pull/24284

PiperOrigin-RevId: 689546724
2024-10-24 15:59:34 -07:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Parker Schuh
9500bd451a Fix float0 behavior inside shard_map transpose under scan.
PiperOrigin-RevId: 689512880
2024-10-24 14:15:40 -07:00
jax authors
0d68a2bf3b Merge pull request #24511 from mattjj:improve-concreteness-error-in-remat
PiperOrigin-RevId: 689488766
2024-10-24 13:05:36 -07:00
jax authors
8c9dc21e30 Update hermetic CUDA docs.
PiperOrigin-RevId: 689463215
2024-10-24 11:51:02 -07:00
jax authors
bd417ba6d0 Update XLA dependency to use revision
1f6bd971dd.

PiperOrigin-RevId: 689457454
2024-10-24 11:36:55 -07:00
Matthew Johnson
4231128535 improve concreteness error message in remat 2024-10-24 18:13:42 +00:00
jax authors
afc78524e1 Remove silent data corruption runtime flags from persistent cache key.
These flags have no effect on the compiled executable, just the runtime execution.

PiperOrigin-RevId: 689442580
2024-10-24 10:59:44 -07:00
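The reasoning in the commit above (runtime-only flags should not change the cache key) can be illustrated with a minimal sketch. This is not JAX's actual cache-key code. The flag names, the `RUNTIME_ONLY_FLAGS` set, and the `cache_key` helper are all hypothetical and exist only for illustration.

```python
import hashlib
import json

# Hypothetical flag names: options that only change runtime behavior and
# therefore should not influence the key of a compiled executable.
RUNTIME_ONLY_FLAGS = frozenset({"runtime_sdc_check_a", "runtime_sdc_check_b"})


def cache_key(module_text, flags):
    """Return a hex digest over the program text and compile-relevant flags."""
    # Drop runtime-only flags before hashing so that toggling them
    # still hits the same cache entry.
    relevant = {k: v for k, v in flags.items() if k not in RUNTIME_ONLY_FLAGS}
    payload = json.dumps({"module": module_text, "flags": relevant}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


key_a = cache_key("module @m {}", {"opt_level": 2, "runtime_sdc_check_a": True})
key_b = cache_key("module @m {}", {"opt_level": 2, "runtime_sdc_check_a": False})
key_c = cache_key("module @m {}", {"opt_level": 3, "runtime_sdc_check_a": False})
```

Here `key_a` and `key_b` are equal because they differ only in a runtime-only flag, while `key_c` differs because a compile-relevant option changed.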
Adam Paszke
6634f5a348 [Mosaic GPU] Use absl::StrCat instead of std::string::operator+
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.

PiperOrigin-RevId: 689416587
2024-10-24 09:49:51 -07:00
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
George Necula
e5bbf3dca1 [jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).

The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowed in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
2024-10-24 17:41:32 +03:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
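The tuple slicing and concatenation mentioned in the commit above can be sketched as follows. This is an illustration, not the code from the commit. It assumes the goal is to replace the sampled axis of `arr.shape` with the requested sample shape. `SymbolicDim` and `replace_axis` are hypothetical names, with `SymbolicDim` standing in for a non-constant dimension.

```python
class SymbolicDim:
    """Hypothetical stand-in for a non-constant (symbolic) dimension size."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def replace_axis(arr_shape, axis, new_shape):
    """Return arr_shape with the entry at `axis` replaced by `new_shape`.

    Only tuple operations are used, so the elements are never coerced to
    NumPy constants and symbolic dimensions pass through untouched.
    """
    arr_shape = tuple(arr_shape)
    return arr_shape[:axis] + tuple(new_shape) + arr_shape[axis + 1:]


b = SymbolicDim("b")
out_shape = replace_axis((3, 5, 7), 1, (b, 2))
```

Because the result is built purely from tuples, `out_shape` still holds the same `SymbolicDim` object that was passed in, whereas an array-based helper would first try to convert every entry to a number.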
jax authors
644f881a51 Merge pull request #24490 from hawkinsp:searchsorted
PiperOrigin-RevId: 689364122
2024-10-24 06:56:32 -07:00
jax authors
c311c7387e Merge pull request #24427 from kaixih:tolerence_jax_sdpa
PiperOrigin-RevId: 689328282
2024-10-24 04:36:48 -07:00
Sergei Lebedev
717467a82f [pallas] input_output_aliases now only include refs which have been written to
PiperOrigin-RevId: 689323778
2024-10-24 04:18:01 -07:00
Adam Paszke
bb2e2303d7 [Pallas:MGPU] Treat each warpgroup as a single logical thread.
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.

PiperOrigin-RevId: 689289365
2024-10-24 01:54:10 -07:00
Andrey Portnoy
14e0f0e7fa [Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM
Originally proposed in #24021. Slightly rewritten to make testing with internal LLVM toolchains better.

Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.

Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
2024-10-24 01:46:29 -07:00
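The first two steps described in the commit above can be sketched in a few lines. This is a rough illustration, not the Mosaic GPU implementation. The `known_targets` argument is a hypothetical stand-in for asking LLVM which SM names it recognizes, and the function names are invented for this example. The final step, mapping the SM string to a PTX ISA version, is left out because it depends on LLVM's own tables.

```python
def base_sm(major, minor):
    """Build the "base" SM string from the compute capability, e.g. sm_90."""
    return f"sm_{major}{minor}"


def select_sm(major, minor, known_targets):
    """Prefer the architecture-specific variant (suffix "a") when available.

    `known_targets` is a hypothetical set of SM names that the toolchain
    reports as supported.
    """
    base = base_sm(major, minor)
    upgraded = base + "a"
    return upgraded if upgraded in known_targets else base


hopper = select_sm(9, 0, {"sm_80", "sm_90", "sm_90a"})
ampere = select_sm(8, 0, {"sm_80", "sm_90", "sm_90a"})
```

With this hypothetical target set, `hopper` is upgraded to the architecture-specific name and `ampere` keeps its base name.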
ZincCat
bd9a10e4eb fix the wrong output of pallas attention kernel when q_len!=kv_len 2024-10-24 02:20:54 -04:00
Jevin Jiang
b8bacda2d9 [Mosaic TPU] Use native vector tiling to load and store with untiled memref.
PiperOrigin-RevId: 689142734
2024-10-23 16:22:16 -07:00
jax authors
df6e5e76cc Merge pull request #24487 from jakevdp:block-doc
PiperOrigin-RevId: 689071677
2024-10-23 12:55:53 -07:00
jax authors
16f8958ece Update XLA dependency to use revision
ffcd64e30e.

PiperOrigin-RevId: 689062763
2024-10-23 12:30:09 -07:00
Ayaka
ea1fc65c69 [Pallas TPU] Fix OpsTest.test_elementwise test for bf16 inputs
For bf16 inputs, the shape must be (8, 128)

PiperOrigin-RevId: 689060557
2024-10-23 12:23:46 -07:00
Peter Hawkins
a7d711513c Perform searchsorted binary search using unsigned intermediate values.
Midpoint computation for a binary search should be performed unsigned, see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/

In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
2024-10-23 15:11:55 -04:00
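The overflow that motivates the commit above can be reproduced with a small sketch. Python integers do not overflow, so 32-bit wraparound is simulated explicitly here. The helper names are hypothetical, and this is not the HLO that JAX emits.

```python
def to_int32(x):
    """Wrap a Python int to a signed 32-bit two's complement value."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def midpoint_signed(lo, hi):
    """(lo + hi) // 2 where the sum wraps around as a signed 32-bit value."""
    return to_int32(lo + hi) // 2


def midpoint_unsigned(lo, hi):
    """Same sum treated as unsigned 32-bit, halved with a logical shift."""
    return ((lo + hi) & 0xFFFFFFFF) >> 1


lo, hi = 2**30, 2**31 - 1          # both fit in a signed 32-bit index
bad = midpoint_signed(lo, hi)      # the signed sum overflows, result is negative
good = midpoint_unsigned(lo, hi)   # the unsigned sum still fits in 32 bits
```

Since both bounds are non-negative and below 2**31, their sum always fits in an unsigned 32-bit value, and a shift can replace floor division.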
Jake VanderPlas
9bf1516abe Improve docs for jnp.block 2024-10-23 11:37:19 -07:00
Christos Perivolaropoulos
6235158582 Dot algorithms are now supported for all types, change the test to reflect it.
PiperOrigin-RevId: 689036316
2024-10-23 11:17:55 -07:00
jax authors
7556f66e20 Merge pull request #24476 from jakevdp:cov-doc
PiperOrigin-RevId: 689034000
2024-10-23 11:13:08 -07:00
Jake VanderPlas
148f9d6559 Better docs for jnp.cov & jnp.corrcoef 2024-10-23 10:17:00 -07:00
jax authors
84cd3567b5 Avoid the metadata query that checks whether we are running on GCE if TPU_SKIP_MDS_QUERY is set.
PiperOrigin-RevId: 689009215
2024-10-23 10:09:02 -07:00
Tzu-Wei Sung
11faf68018 [Pallas:TPU] Match lax.pow(float, int) behavior in Pallas.
Both math::PowF and Exp2Op require a floating-point exponent, so we cast it to x.dtype for parity with lax.pow.

PiperOrigin-RevId: 688997089
2024-10-23 09:38:03 -07:00
jax authors
3c2c60e4d0 Merge pull request #24482 from jakevdp:unwrap-doc
PiperOrigin-RevId: 688995901
2024-10-23 09:34:31 -07:00
Adam Paszke
88d231a3f2 [Pallas] Allow core_map's mesh to discharge backend specific effects
Backends often have custom effectful primitives, but their effects do not extend
beyond the scope of a single kernel, so we should remove them in core_map's abstract eval.

PiperOrigin-RevId: 688990275
2024-10-23 09:16:17 -07:00
Adam Paszke
5b3b6e84db [Pallas:MGPU] Allow initializing accumulators with values in registers
This is useful to avoid unnecessary shared stores and fences in some kernels like
flash attention.

PiperOrigin-RevId: 688977199
2024-10-23 08:36:39 -07:00
Dan Foreman-Mackey
5ea6215436 Add test for jax2tf conversion of dot general with algorithm.
Fixes https://github.com/jax-ml/jax/issues/24236

To be fair, the fix was actually in https://github.com/openxla/xla/pull/18222, but this adds a test to JAX to confirm.

PiperOrigin-RevId: 688976685
2024-10-23 08:34:52 -07:00
Christos Perivolaropoulos
40c92c1f8c [pallas:mosaic_gpu] An extremely specific heuristic to allow swiglu.
PiperOrigin-RevId: 688973012
2024-10-23 08:24:49 -07:00
Christos Perivolaropoulos
155aa6caa4 [pallas:mosaic_gpu] Memref reshape Transform to allow the user to reshape references.
It is not possible for primitives to return references, so in order to support reshaping we need to use TransformRef. This CL introduces both a reshape memref transform and a function for the user to create transformed refs of that type.

PiperOrigin-RevId: 688966337
2024-10-23 08:04:20 -07:00
Sharad Vikram
ce8ecbd16d Add an extension mechanism to run_state that allows:
* Uninitialized values
* Custom ref aval construction

This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
2024-10-23 08:00:56 -07:00
Jake VanderPlas
d6f4ce1612 Better docs for jnp.unwrap 2024-10-23 07:58:31 -07:00
jax authors
abc6c00460 Merge pull request #24473 from hawkinsp:postrelease
PiperOrigin-RevId: 688951659
2024-10-23 07:13:27 -07:00
Peter Hawkins
2aeda17829 Merge branch 'release/0.4.35' 2024-10-23 08:50:31 -04:00
jax authors
7ad73e44ce Merge pull request #24446 from gnecula:export_doc
PiperOrigin-RevId: 688886756
2024-10-23 02:50:57 -07:00
jax authors
81bee4219d Merge pull request #24469 from jakevdp:indices-doc
PiperOrigin-RevId: 688763174
2024-10-22 18:36:00 -07:00