25441 Commits

Author SHA1 Message Date
Bill Varcho
0abd9538ce [JAX] disable flaky parameter permutations for sparse_bcoo_bcsr test.
PiperOrigin-RevId: 722832212
2025-02-03 16:02:06 -08:00
Sergei Lebedev
f58207a28d [pallas:triton] Fixed dispatch tablee for lax.pow_p
PiperOrigin-RevId: 722817510
2025-02-03 15:17:58 -08:00
Kanglan Tang
59a3552ae6 Remove portpicker for free threaded python 3.13t in test-requirements.txt
PiperOrigin-RevId: 722776783
2025-02-03 13:30:01 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
jax authors
2c10a65b73 Merge pull request #25655 from carlosgmartin:simplify_random_orthogonal
PiperOrigin-RevId: 722770603
2025-02-03 13:12:09 -08:00
jax authors
17d0b86c7c Merge pull request #26275 from dfm:effects-in-custom-linear-solve
PiperOrigin-RevId: 722751986
2025-02-03 12:18:55 -08:00
carlosgmartin
c478f44e9d Simplify implementation of random.orthogonal. 2025-02-03 15:02:17 -05:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
jax authors
40d35b4219 Merge pull request #26277 from justinjfu:sourcemap_windows_fix
PiperOrigin-RevId: 722736034
2025-02-03 11:37:27 -08:00
jax authors
a5d2d89b4d Merge pull request #26246 from jakevdp:lax-docs
PiperOrigin-RevId: 722735618
2025-02-03 11:35:37 -08:00
Sergei Lebedev
bf6489ff5b [pallas:triton] Fallback lowering rules for math functions now use general dtypes
Previously, it was necessary to list all dtypes explicitly, which is why
we had separate fallback rules for float16 and bfloat16 for some functions.

PiperOrigin-RevId: 722729554
2025-02-03 11:21:11 -08:00
Sergei Lebedev
2d7e4ab2dc [mosaic_gpu] LayoutTest now correctly resets the value of MOSAIC_GPU_DUMP_SASS
PiperOrigin-RevId: 722711341
2025-02-03 10:32:14 -08:00
Jake VanderPlas
49c4020f0a jax.lax: improve docs for floor, ceil, round. 2025-02-03 10:19:22 -08:00
Justin Fu
6d7b03572c Format sourcemap directory names to work on windows 2025-02-03 09:39:56 -08:00
jax authors
7e353913f2 Merge pull request #26262 from gnecula:debug_info_one
PiperOrigin-RevId: 722684417
2025-02-03 09:17:13 -08:00
jax authors
aa64372b81 Merge pull request #26261 from vfdev-5:use-tsan-numpy-in-tsan-ci-2
PiperOrigin-RevId: 722683976
2025-02-03 09:15:29 -08:00
jax authors
12c76cdeaa Merge pull request #25474 from jaro-sevcik:compilation-cache-mock-doc
PiperOrigin-RevId: 722681339
2025-02-03 09:07:31 -08:00
Dan Foreman-Mackey
d42e3650d0 Handle effects in lax.custom_linear_solve. 2025-02-03 11:14:48 -05:00
jax authors
7164c6ba3e Merge pull request #25812 from Cjkkkk:segment_ids
PiperOrigin-RevId: 722650439
2025-02-03 07:28:25 -08:00
jax authors
49821e81de Update XLA dependency to use revision
b5db81467f.

PiperOrigin-RevId: 722632959
2025-02-03 06:21:42 -08:00
Dan Foreman-Mackey
28afd25259 Add FFI example demonstrating the use of XLA's FFI state.
Support for this was added in JAX v0.5.0.

PiperOrigin-RevId: 722597704
2025-02-03 04:06:10 -08:00
vfdev-5
476b2398ff Add TSAN numpy step 2025-02-03 11:40:39 +01:00
cjkkkk
8c4d6d6903 fix lint 2025-02-03 06:09:05 +00:00
Parker Schuh
cb188a0cb1 Reject invalid None in jax.NamedSharding(spec=None).
PiperOrigin-RevId: 722500631
2025-02-02 21:29:33 -08:00
Christos Perivolaropoulos
b48d15d788 [pallas_mgpu] For loops can have **non-ref** accumulators for carries.
The user has access only to accumulator references and they can't pass them as caries to loops. However when they are discharged these accumulators become values and become part of the carry. Before this CL this would surprise the loop lowering code.

This was never a problem for pallas mgpu until we added pipelining loops instead of sequential bloc axes.

PiperOrigin-RevId: 722495749
2025-02-02 21:03:26 -08:00
Parker Schuh
da97ee2591 Stop passing None into jax.NamedSharding in preparation for followup which bans passing None in (in favor of PartitionSpec()
PiperOrigin-RevId: 722477002
2025-02-02 19:35:33 -08:00
Peter Hawkins
58d8a97f5e Reverts 3f737588fd49fcac36d6bd00cfdabae07df1fbb2
PiperOrigin-RevId: 722457330
2025-02-02 17:52:39 -08:00
jax authors
57fa37214c Merge pull request #26243 from jakevdp:einsum-asarray
PiperOrigin-RevId: 722455518
2025-02-02 17:42:47 -08:00
jax authors
af84143e61 Update XLA dependency to use revision
492a921843.

PiperOrigin-RevId: 722343063
2025-02-02 06:24:33 -08:00
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
2025-02-02 06:23:03 +02:00
jax authors
de48ce2a4c Merge pull request #26174 from skye:cpu_configs
PiperOrigin-RevId: 722199229
2025-02-01 16:20:16 -08:00
Christos Perivolaropoulos
b23f8f414b [pallas/pallas_mgpu] Discharging run_scoped should not be discharging the intermediates
When we do run_scoped[jaxpr, R1,R2], it can't be assumed that references
corresponding to R1 and R2 can be safely discharged. Sometimes they can (eg
Accumulator) but sometimes they can't (eg SMEM scratch). It should be up to the
lowering rule to do such discharging.

This further means that during lowering there is no guarantee that the
references will not be used/returned by nested scoped blocks so we also remove
that check.

PiperOrigin-RevId: 722137352
2025-02-01 09:37:03 -08:00
Christos Perivolaropoulos
8649132d86 [pallas] Support DMA start partial discharge and run_scoped() does its own partial discharge.
This CL lays the ground for a future CL that makes run_scoped discharge to not request the discharge of the temporary buffers it creates. This causes issues becausa

a) dma_start can't discharge some but not all its references
b) run_scoped() lowering depends on run_scoped discharge to remove the run_scoped operation (or it goes in an infinite loop).

PiperOrigin-RevId: 722126566
2025-02-01 08:23:23 -08:00
jax authors
eb04fcbe5a Update XLA dependency to use revision
05e4f40276.

PiperOrigin-RevId: 722112325
2025-02-01 06:52:02 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Jevin Jiang
d8b9211359 [Mosaic TPU] Support dynamic gather along axis 0 or 1 for 32-bit vreg-sized vector.
PiperOrigin-RevId: 721980453
2025-01-31 18:47:25 -08:00
Gunhyun Park
c4e176328f Move ragged_all_to_all test under appropriate test file
PiperOrigin-RevId: 721947980
2025-01-31 16:44:04 -08:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
jax authors
872e6c0ec4 Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
jax authors
a9f4dd7182 Merge pull request #26249 from jakevdp:fix-sterling
PiperOrigin-RevId: 721922732
2025-01-31 15:26:37 -08:00
carlosgmartin
96d3447e89 Add mode='fan_geo_avg' to nn.initializers.variance_scaling. 2025-01-31 17:52:22 -05:00
Emily Fertig
3b2410f77c Reverts bb951136e9b91a584bb422119ada76cc69c86024
PiperOrigin-RevId: 721908669
2025-01-31 14:42:22 -08:00
jax authors
3f737588fd Merge pull request #26195 from vfdev-5:use-tsan-numpy-in-tsan-ci
PiperOrigin-RevId: 721899249
2025-01-31 14:13:14 -08:00
Jake VanderPlas
216bd9a6cc Fix dtype issue in stirling approximation 2025-01-31 14:13:02 -08:00
jax authors
0bf8462c88 Merge pull request #26247 from hawkinsp:tsan2
PiperOrigin-RevId: 721895954
2025-01-31 14:04:42 -08:00
Peter Hawkins
d4c6657888 Add additional tsan suppressions for races.
These were found in https://github.com/jax-ml/jax/actions/runs/13072427817/job/36476819907 and reported in https://github.com/python/cpython/issues/129533 and https://github.com/python/cpython/issues/128714.
2025-01-31 13:43:08 -08:00
jax authors
0ef2ccfdb4 Merge pull request #26238 from hawkinsp:coretest
PiperOrigin-RevId: 721884935
2025-01-31 13:33:04 -08:00
jax authors
b039f976d5 Merge pull request #26239 from jakevdp:rotation-test
PiperOrigin-RevId: 721884090
2025-01-31 13:30:43 -08:00
jax authors
70aed64a15 Merge pull request #26245 from jakevdp:fix-line
PiperOrigin-RevId: 721879951
2025-01-31 13:16:51 -08:00
jax authors
cb79ff4d85 Merge pull request #26194 from jax-ml:fix-dist-init-runtime-error
PiperOrigin-RevId: 721878712
2025-01-31 13:12:46 -08:00