George Necula
9088adda68
[jax2tf] Disable jax2tf with non-native serialization.
...
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.
This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.
PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
George Necula
0bc70bbd73
Disable jax2tf test recently added in cl/688976685.
...
See failure: https://github.com/jax-ml/jax/actions/runs/11514933009/job/32054580529?pr=24183
PiperOrigin-RevId: 689703645
2024-10-25 02:12:34 -07:00
jax authors
3823612ebf
Merge pull request #24505 from gnecula:jax2tf_bug
...
PiperOrigin-RevId: 689662727
2024-10-24 23:36:55 -07:00
Ayaka
5c614470ad
[Pallas TPU] Add lowerings for scalar absf
and rsqrt
...
This PR is similar to https://github.com/jax-ml/jax/pull/24284
PiperOrigin-RevId: 689546724
2024-10-24 15:59:34 -07:00
Kanglan Tang
af28595909
Add a jax_wheel Bazel rule to build jax pip packages
...
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Parker Schuh
9500bd451a
Fix float0 behavior inside shard_map transpose under scan.
...
PiperOrigin-RevId: 689512880
2024-10-24 14:15:40 -07:00
jax authors
0d68a2bf3b
Merge pull request #24511 from mattjj:improve-concreteness-error-in-remat
...
PiperOrigin-RevId: 689488766
2024-10-24 13:05:36 -07:00
jax authors
8c9dc21e30
Update hermetic CUDA docs.
...
PiperOrigin-RevId: 689463215
2024-10-24 11:51:02 -07:00
jax authors
bd417ba6d0
Update XLA dependency to use revision
...
1f6bd971dd
.
PiperOrigin-RevId: 689457454
2024-10-24 11:36:55 -07:00
Matthew Johnson
4231128535
improve concreteness error message in remat
2024-10-24 18:13:42 +00:00
jax authors
afc78524e1
Remove silent data corruption runtime flags from persistent cache key.
...
These flags have no effect on the compiled executable, just the runtime execution.
PiperOrigin-RevId: 689442580
2024-10-24 10:59:44 -07:00
Adam Paszke
6634f5a348
[Mosaic GPU] Use absl::StrCat instead std::string::operator+
...
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.
PiperOrigin-RevId: 689416587
2024-10-24 09:49:51 -07:00
Yash Katariya
6c8e56f43f
Finish 0.4.35 release by removing dead code
...
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
George Necula
e5bbf3dca1
[jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.
...
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).
The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowd in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
2024-10-24 17:41:32 +03:00
jax authors
644f881a51
Merge pull request #24490 from hawkinsp:searchsorted
...
PiperOrigin-RevId: 689364122
2024-10-24 06:56:32 -07:00
jax authors
c311c7387e
Merge pull request #24427 from kaixih:tolerence_jax_sdpa
...
PiperOrigin-RevId: 689328282
2024-10-24 04:36:48 -07:00
Sergei Lebedev
717467a82f
[pallas] input_output_aliases
now only include refs which have been written to
...
PiperOrigin-RevId: 689323778
2024-10-24 04:18:01 -07:00
Adam Paszke
bb2e2303d7
[Pallas:MGPU] Treat each warpgroup as a single logical thread.
...
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.
PiperOrigin-RevId: 689289365
2024-10-24 01:54:10 -07:00
Andrey Portnoy
14e0f0e7fa
[Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM
...
Originally proposed in #24021 . Slightly rewritter to make testing with internal LLVM toolchains better.
Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.
Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
2024-10-24 01:46:29 -07:00
Jevin Jiang
b8bacda2d9
[Mosaic TPU] Use native vector tiling to load and store with untiled memref.
...
PiperOrigin-RevId: 689142734
2024-10-23 16:22:16 -07:00
jax authors
df6e5e76cc
Merge pull request #24487 from jakevdp:block-doc
...
PiperOrigin-RevId: 689071677
2024-10-23 12:55:53 -07:00
jax authors
16f8958ece
Update XLA dependency to use revision
...
ffcd64e30e
.
PiperOrigin-RevId: 689062763
2024-10-23 12:30:09 -07:00
Ayaka
ea1fc65c69
[Pallas TPU] Fix OpsTest.test_elementwise
test for bf16 inputs
...
For bf16 inputs, the shape must be (8, 128)
PiperOrigin-RevId: 689060557
2024-10-23 12:23:46 -07:00
Peter Hawkins
a7d711513c
Perform searchsorted binary search using unsigned intermediate values.
...
Midpoint computation for a binary search should be performed unsigned, see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/
In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
2024-10-23 15:11:55 -04:00
Jake VanderPlas
9bf1516abe
Improve docs for jnp.block
2024-10-23 11:37:19 -07:00
Christos Perivolaropoulos
6235158582
Dot algorithms are now supported for all types, change the test to reflect it.
...
PiperOrigin-RevId: 689036316
2024-10-23 11:17:55 -07:00
jax authors
7556f66e20
Merge pull request #24476 from jakevdp:cov-doc
...
PiperOrigin-RevId: 689034000
2024-10-23 11:13:08 -07:00
Jake VanderPlas
148f9d6559
Better docs for jnp.cov & jnp.corrcoef
2024-10-23 10:17:00 -07:00
jax authors
84cd3567b5
Avoid querying metadata query to check if it's GCE if TPU_SKIP_MDS_QUERY
is set.
...
PiperOrigin-RevId: 689009215
2024-10-23 10:09:02 -07:00
Tzu-Wei Sung
11faf68018
[Pallas:TPU] Match lax.pow(float, int) behavior in Pallas.
...
Both math::PowF and Exp2Op require a floating point exponent so casting it to x.dtype for parity of lax.pow.
PiperOrigin-RevId: 688997089
2024-10-23 09:38:03 -07:00
jax authors
3c2c60e4d0
Merge pull request #24482 from jakevdp:unwrap-doc
...
PiperOrigin-RevId: 688995901
2024-10-23 09:34:31 -07:00
Adam Paszke
88d231a3f2
[Pallas] Allow core_map's mesh to discharge backend specific effects
...
Backends often have custom effectful primitives, but their effects do not extend
beyond the scope of a single kernel, so we should remove them in core_map's abstract eval.
PiperOrigin-RevId: 688990275
2024-10-23 09:16:17 -07:00
Adam Paszke
5b3b6e84db
[Pallas:MGPU] Allow initializing accumulators with values in registers
...
This is useful to avoid unnecessary shared stores and fences in some kernels like
flash attention.
PiperOrigin-RevId: 688977199
2024-10-23 08:36:39 -07:00
Dan Foreman-Mackey
5ea6215436
Add test for jax2tf conversion of dot general with algorithm.
...
Fixes https://github.com/jax-ml/jax/issues/24236
To be fair, the fix was actually in https://github.com/openxla/xla/pull/18222 , but this adds a test to JAX to confirm.
PiperOrigin-RevId: 688976685
2024-10-23 08:34:52 -07:00
Christos Perivolaropoulos
40c92c1f8c
[pallas:mosaic_gpu] An extremely specific heuristic to allow swiglu.
...
PiperOrigin-RevId: 688973012
2024-10-23 08:24:49 -07:00
Christos Perivolaropoulos
155aa6caa4
[pallas:mosaic_gpu] Memref reshape Transform to allow the user to reshape references.
...
It is not possible for primitives to return references so in order to support reshaping we need to use TransformRef. This CL introduces both a reshape memref transform and a function for the user to create transformed refs of that type.
PiperOrigin-RevId: 688966337
2024-10-23 08:04:20 -07:00
Sharad Vikram
ce8ecbd16d
Add an extension mechanism to run_state that allows:
...
* Uninitialized values
* Custom ref aval construction
This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089
2024-10-23 08:00:56 -07:00
Jake VanderPlas
d6f4ce1612
Better docs for jnp.unwrap
2024-10-23 07:58:31 -07:00
jax authors
abc6c00460
Merge pull request #24473 from hawkinsp:postrelease
...
PiperOrigin-RevId: 688951659
2024-10-23 07:13:27 -07:00
Peter Hawkins
2aeda17829
Merge branch 'release/0.4.35'
2024-10-23 08:50:31 -04:00
jax authors
7ad73e44ce
Merge pull request #24446 from gnecula:export_doc
...
PiperOrigin-RevId: 688886756
2024-10-23 02:50:57 -07:00
jax authors
81bee4219d
Merge pull request #24469 from jakevdp:indices-doc
...
PiperOrigin-RevId: 688763174
2024-10-22 18:36:00 -07:00
Yash Katariya
4688da3118
Fix jax2tf failure coming from dot_general
...
PiperOrigin-RevId: 688738110
2024-10-22 16:53:47 -07:00
Jake VanderPlas
9038bb2664
Better documentation for jnp.indices
2024-10-22 16:48:36 -07:00
Yash Katariya
f8a1f02d6b
[sharding_in_types][Take 2] Add out_type
argument to einsum
and dot_general
to allow specifying for the output type. Right now, it only accept a NamedSharding
but in the future we can allow a polymorphic type of: jax.ShapeDtypeStruct | Sharding | Layout
.
...
Reverts 0b3f0e11fb0c37342b3c05ad5d53f3435b6ca44c
PiperOrigin-RevId: 688663504
2024-10-22 13:10:43 -07:00
jax authors
32be1992ee
Update XLA dependency to use revision
...
bf8dafb2a7
.
PiperOrigin-RevId: 688648145
2024-10-22 12:29:58 -07:00
Peter Hawkins
81991d87c8
JAX release 0.4.35
2024-10-22 15:00:23 -04:00
Peter Hawkins
e4f3f8f064
Use libtpu releases rather than libtpu-nightly for jax[tpu].
...
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
jax authors
1c6b0a9193
Merge pull request #24465 from jakevdp:fix-mypy
...
PiperOrigin-RevId: 688632024
2024-10-22 11:45:27 -07:00
jax authors
9a2dd19a92
Merge pull request #21524 from andportnoy:aportnoy/unknown-platform-lowering-warning
...
PiperOrigin-RevId: 688630259
2024-10-22 11:40:39 -07:00