15048 Commits

Author SHA1 Message Date
Jake VanderPlas
2a35b0b9a3 Remove int4 from jtu.dtypes.all_dtypes.
Why? It's not included in the supported() enumeration on any platforms, so
there is no need to mention it here. I tried fixing this by including it in
supported(), but this led to many errors. It's better to not list it here,
because it might mislead us into thinking it's being tested.
2024-10-28 06:15:46 -07:00
Adam Paszke
343cf18e09 [Pallas:MGPU] Wire up the Mosaic GPU profiler into Pallas
PiperOrigin-RevId: 690574747
2024-10-28 05:40:08 -07:00
Peter Hawkins
6f3c01238e [mosaic] Directly build IR in _device_id_to_logical, rather than using lower_fun.
This is just as simple and faster.

PiperOrigin-RevId: 690196495
2024-10-26 16:58:42 -07:00
jax authors
2b01affd12 Merge pull request #24537 from jakevdp:doc-examples
PiperOrigin-RevId: 690156005
2024-10-26 12:19:54 -07:00
Dan Foreman-Mackey
ad1d864b05 Fix lint at head 2024-10-26 07:41:44 -04:00
Parker Schuh
6b065579d4 Support None in PmapSharding as a replacement for device_put_replicated.
e.g.:
`jax.device_put(x, PmapSharding.default(x.shape, None, jax.local_devices()))`
PiperOrigin-RevId: 689956669
2024-10-25 17:01:03 -07:00
jax authors
47bacfab5e Merge pull request #24031 from garymm:garymm/vmap-error-msg
PiperOrigin-RevId: 689940504
2024-10-25 15:59:57 -07:00
jax authors
eaef7d7279 Merge pull request #24495 from zinccat:pallas_attn_fix
PiperOrigin-RevId: 689934045
2024-10-25 15:37:17 -07:00
Gary Miguel
9f7f08eccb Fix vmap error message when args passed by keyword
See the new test for a case that used to produce the wrong message.

Fixes: #24406
2024-10-25 15:17:03 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
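To illustrate the semantics this commit refers to, here is a minimal pure-Python sketch of a one-dimensional cumulative product. It is not the JAX implementation: the helper name `cumulative_prod_1d` is hypothetical, and the `include_initial` keyword is assumed to behave as in the NumPy 2.1 function of the same family (prepending the multiplicative identity).

```python
from itertools import accumulate
from operator import mul


def cumulative_prod_1d(values, include_initial=False):
    """Running product of a 1-D sequence (illustrative sketch only)."""
    out = list(accumulate(values, mul))
    # With include_initial=True the multiplicative identity (1) is prepended,
    # so the result has one more element than the input.
    return [1] + out if include_initial else out


print(cumulative_prod_1d([1, 2, 3, 4]))
print(cumulative_prod_1d([1, 2, 3, 4], include_initial=True))
```

The real function also accepts an axis and dtype; this sketch only covers the 1-D case.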
Gunhyun Park
94440c74c8 Register acos primitive to lower to CHLO acos.
Related: https://github.com/openxla/stablehlo/pull/2496
PiperOrigin-RevId: 689890774
2024-10-25 13:20:36 -07:00
Jake VanderPlas
adf1492843 Add some missing jax.numpy documentation 2024-10-25 13:14:44 -07:00
jax authors
5a41093970 Merge pull request #24524 from dfm:jax2tf-dot-algorithm-version
PiperOrigin-RevId: 689875086
2024-10-25 12:28:20 -07:00
jax authors
eb13e68b53 Merge pull request #24433 from dfm:ffi-call-input-output-aliases
PiperOrigin-RevId: 689872997
2024-10-25 12:22:11 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
jax authors
6f371212d9 Implements an alternate version of ragged_attention in which the attention kernel itself is dense. This kernel has neither the compute savings (the @when-wrapped kernel) nor the prefetch/index skipping (via index rewriting) built in. Instead, the kernel is invoked with a Jumble (a ragged type representation), and Pallas takes care of applying the correct work skipping and index rewriting.
Performance-wise, we should be at parity, although this has not yet been tested.

Authoring-wise, the new kernel is significantly smaller and simpler to write.

A major known limitation of this approach is the invariant that `seq_len % grid_size == 0`; we plan to relax this limitation in following CLs.

PiperOrigin-RevId: 689868468
2024-10-25 12:07:34 -07:00
Brian Wieder
7db4b254e0 Clear extra_jit_context when exiting.
For some reason, extra_jit_context was leaking when `pallas.core` no longer imported `pallas.pallas_call`, leading to leaked XLA clients.

PiperOrigin-RevId: 689857071
2024-10-25 11:35:47 -07:00
Yash Katariya
34611be53d Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00
jax authors
ee0292c6d2 Merge pull request #24470 from jakevdp:is-unspecified
PiperOrigin-RevId: 689820689
2024-10-25 09:49:30 -07:00
Dan Foreman-Mackey
21f3353544 Add support for layouts and other advanced features in ffi_call. 2024-10-25 12:31:07 -04:00
Jake VanderPlas
d4c46825d6 Finalize deprecation of xb, xc, & xe symbols in jax.interpreters.xla
PiperOrigin-RevId: 689792265
2024-10-25 08:12:44 -07:00
jax authors
8c6164a492 Merge pull request #24500 from gnecula:poly_choice
PiperOrigin-RevId: 689792194
2024-10-25 08:10:52 -07:00
Dan Foreman-Mackey
33a46e8f68 Re-enable jax2tf test for dot algorithm with stricter TF version check. 2024-10-25 08:26:19 -04:00
Jake VanderPlas
8948e6de58 sharding cleanup: use inline checks for unimplemented and auto 2024-10-25 04:22:40 -07:00
Peter Hawkins
bb5fbec64b [mosaic] Use .clone() to duplicate a module, rather than printing and parsing it.
PiperOrigin-RevId: 689708462
2024-10-25 02:32:49 -07:00
George Necula
9088adda68 [jax2tf] Disable jax2tf with non-native serialization.
Using jax2tf with native_serialization=False or with enable_xla=False has been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
George Necula
0bc70bbd73 Disable jax2tf test recently added in cl/688976685.
See failure: https://github.com/jax-ml/jax/actions/runs/11514933009/job/32054580529?pr=24183

PiperOrigin-RevId: 689703645
2024-10-25 02:12:34 -07:00
jax authors
3823612ebf Merge pull request #24505 from gnecula:jax2tf_bug
PiperOrigin-RevId: 689662727
2024-10-24 23:36:55 -07:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Parker Schuh
9500bd451a Fix float0 behavior inside shard_map transpose under scan.
PiperOrigin-RevId: 689512880
2024-10-24 14:15:40 -07:00
jax authors
0d68a2bf3b Merge pull request #24511 from mattjj:improve-concreteness-error-in-remat
PiperOrigin-RevId: 689488766
2024-10-24 13:05:36 -07:00
Matthew Johnson
4231128535 improve concreteness error message in remat 2024-10-24 18:13:42 +00:00
jax authors
afc78524e1 Remove silent data corruption runtime flags from persistent cache key.
These flags have no effect on the compiled executable, just the runtime execution.

PiperOrigin-RevId: 689442580
2024-10-24 10:59:44 -07:00
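The idea behind this commit can be sketched as follows: flags that only affect runtime execution are filtered out before the cache key is hashed, so toggling them does not invalidate cached executables. This is a hedged illustration, not JAX's cache-key code. The function `cache_key` and the flag names in `RUNTIME_ONLY_FLAGS` are hypothetical.

```python
import hashlib
import json

# Hypothetical flag names, used only for illustration.
RUNTIME_ONLY_FLAGS = frozenset({"runtime_check_a", "runtime_check_b"})


def cache_key(program_text, flags):
    """Hash the program plus only those flags that affect compilation."""
    compile_flags = {k: v for k, v in flags.items() if k not in RUNTIME_ONLY_FLAGS}
    # sort_keys makes the serialization independent of dict insertion order.
    payload = json.dumps({"program": program_text, "flags": compile_flags}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


base = cache_key("module @m {}", {"opt_level": 2, "runtime_check_a": False})
toggled = cache_key("module @m {}", {"opt_level": 2, "runtime_check_a": True})
print(base == toggled)
```

Changing a flag outside the runtime-only set (here `opt_level`) still produces a different key.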
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
George Necula
e5bbf3dca1 [jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).

The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowed in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
2024-10-24 17:41:32 +03:00
George Necula
e5f4be5564 [shape_poly] Expands support for random.choice
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.

The case when the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` on
arrays with non-constant size is not supported.

Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
2024-10-24 17:20:09 +03:00
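The replacement of array-based insertion with tuple slicing and concatenation can be sketched as below. This is an assumption-laden illustration rather than the code from the commit: the helper name `replace_axis_with_shape` is hypothetical, and the string `"b"` merely stands in for a non-constant (symbolic) dimension that must pass through without being coerced to an integer.

```python
def replace_axis_with_shape(arr_shape, axis, new_shape):
    """Replace dimension `axis` of `arr_shape` with the dimensions in `new_shape`."""
    arr_shape = tuple(arr_shape)
    # Slicing and concatenating tuples never inspects or converts the elements,
    # so symbolic dimension objects survive unchanged.
    return arr_shape[:axis] + tuple(new_shape) + arr_shape[axis + 1:]


print(replace_axis_with_shape((3, 7, 5), 1, (2, "b")))
```

Because only tuple operations are used, the result can mix integer and symbolic dimensions freely.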
jax authors
644f881a51 Merge pull request #24490 from hawkinsp:searchsorted
PiperOrigin-RevId: 689364122
2024-10-24 06:56:32 -07:00
Sergei Lebedev
717467a82f [pallas] input_output_aliases now only include refs which have been written to
PiperOrigin-RevId: 689323778
2024-10-24 04:18:01 -07:00
Adam Paszke
bb2e2303d7 [Pallas:MGPU] Treat each warpgroup as a single logical thread.
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.

PiperOrigin-RevId: 689289365
2024-10-24 01:54:10 -07:00
ZincCat
bd9a10e4eb fix the wrong output of pallas attention kernel when q_len!=kv_len 2024-10-24 02:20:54 -04:00
Peter Hawkins
a7d711513c Perform searchsorted binary search using unsigned intermediate values.
Midpoint computation for a binary search should be performed unsigned, see https://research.google/blog/extra-extra-read-all-about-it-nearly-all-binary-searches-and-mergesorts-are-broken/

In addition, we can avoid the somewhat verbose floor_divide HLO since we know the values in question are positive.
2024-10-23 15:11:55 -04:00
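The midpoint issue this commit addresses can be demonstrated with a small pure-Python simulation of 32-bit arithmetic. This is a sketch under stated assumptions, not the searchsorted implementation: the helpers `wrap_int32`, `midpoint_signed`, and `midpoint_unsigned` are hypothetical names, and 32-bit width is chosen only for illustration.

```python
def wrap_int32(x):
    """Interpret the low 32 bits of x as a signed two's-complement integer."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def midpoint_signed(low, high):
    # The sum wraps around in signed 32-bit arithmetic before the division.
    return wrap_int32(low + high) // 2


def midpoint_unsigned(low, high):
    # The sum is kept as an unsigned 32-bit value; since it is non-negative,
    # a right shift by one replaces the floor division.
    return ((low + high) & 0xFFFFFFFF) >> 1


low, high = 2**30, 2**31 - 1  # both are valid non-negative int32 indices
print(midpoint_signed(low, high))    # negative: the signed sum overflowed
print(midpoint_unsigned(low, high))  # matches the exact midpoint
```

The unsigned form works because the sum of two non-negative signed 32-bit values always fits in 32 unsigned bits.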
Jake VanderPlas
9bf1516abe Improve docs for jnp.block 2024-10-23 11:37:19 -07:00
Jake VanderPlas
148f9d6559 Better docs for jnp.cov & jnp.corrcoef 2024-10-23 10:17:00 -07:00
jax authors
84cd3567b5 Avoid the metadata query that checks whether it's GCE if TPU_SKIP_MDS_QUERY is set.
PiperOrigin-RevId: 689009215
2024-10-23 10:09:02 -07:00
Tzu-Wei Sung
11faf68018 [Pallas:TPU] Match lax.pow(float, int) behavior in Pallas.
Both math::PowF and Exp2Op require a floating point exponent, so cast it to x.dtype for parity with lax.pow.

PiperOrigin-RevId: 688997089
2024-10-23 09:38:03 -07:00
jax authors
3c2c60e4d0 Merge pull request #24482 from jakevdp:unwrap-doc
PiperOrigin-RevId: 688995901
2024-10-23 09:34:31 -07:00
Adam Paszke
88d231a3f2 [Pallas] Allow core_map's mesh to discharge backend specific effects
Backends often have custom effectful primitives, but their effects do not extend
beyond the scope of a single kernel, so we should remove them in core_map's abstract eval.

PiperOrigin-RevId: 688990275
2024-10-23 09:16:17 -07:00
Adam Paszke
5b3b6e84db [Pallas:MGPU] Allow initializing accumulators with values in registers
This is useful to avoid unnecessary shared stores and fences in some kernels like
flash attention.

PiperOrigin-RevId: 688977199
2024-10-23 08:36:39 -07:00
Dan Foreman-Mackey
5ea6215436 Add test for jax2tf conversion of dot general with algorithm.
Fixes https://github.com/jax-ml/jax/issues/24236

To be fair, the fix was actually in https://github.com/openxla/xla/pull/18222, but this adds a test to JAX to confirm.

PiperOrigin-RevId: 688976685
2024-10-23 08:34:52 -07:00
Christos Perivolaropoulos
40c92c1f8c [pallas:mosaic_gpu] An extremely specific heuristic to allow swiglu.
PiperOrigin-RevId: 688973012
2024-10-23 08:24:49 -07:00