15738 Commits

Author SHA1 Message Date
David Boetius
6e9a34f791
Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
fb832afc00 Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes https://github.com/jax-ml/jax/issues/25769
PiperOrigin-RevId: 713446871
2025-01-08 16:03:03 -08:00
Matthew Johnson
f0392a1535 fix grad(logsumexp) to produce 0s where where is False 2025-01-08 23:38:06 +00:00
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in  pjit_test.py for Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Yash Katariya
b3833dc705 Match the behavior on single host wrt multi-host if tiled=False. Fixes https://github.com/jax-ml/jax/issues/25783
PiperOrigin-RevId: 713398173
2025-01-08 13:31:35 -08:00
Justin Fu
d99a637d8b [Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
2025-01-08 11:21:19 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Sharad Vikram
c1a60c676a [Pallas] Add empty/empty_like helper functions
PiperOrigin-RevId: 713344151
2025-01-08 10:49:11 -08:00
Bart Chrzaszcz
5c097c8f62 #sdy Move Shardy mesh lift inlining pass after verification.
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
2025-01-08 09:17:54 -08:00
Peter Hawkins
0389d617c8 Add a unittest test extension that runs test cases in parallel using threads.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.

This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.

We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.

We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.

PiperOrigin-RevId: 713312937
2025-01-08 09:11:47 -08:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Sergei Lebedev
f1f98afee8 [pallas:mosaic_gpu] Fix the tests following the changes to pl.core_map
PiperOrigin-RevId: 713283207
2025-01-08 07:24:08 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Adam Paszke
f96339be1e [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
2025-01-08 06:30:36 -08:00
jax authors
4718121efe Merge pull request #25754 from andportnoy:patch-4
PiperOrigin-RevId: 713222111
2025-01-08 02:57:20 -08:00
Sergei Lebedev
90201ce2b7 Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
2025-01-08 01:39:38 -08:00
jax authors
1bd781d992 Add JAX events that have time spans, not only durations.
Log such events for log_elapsed_time.

The rationale for not replacing durations with it is that it appears that
record_event_duration_secs() is widely used outside of the code of JAX itself.

PiperOrigin-RevId: 713167192
2025-01-07 23:08:30 -08:00
Yash Katariya
755d6cdad8 [sharding_in_types] Aval sharding under full auto mode should contain None and not UNCONSTRAINED because axis_types + pspec give the full picture.
PiperOrigin-RevId: 713105375
2025-01-07 18:04:20 -08:00
Sharad Vikram
7be127f23c [Pallas] Improvements to core_map
PiperOrigin-RevId: 713075852
2025-01-07 16:18:30 -08:00
Zixuan Jiang
64c0f62ec4 Sort manual axes when lowering jax.shard_map to sdy.manual_computation, which ensures the determinism in the generated sdy.manual_computation.
PiperOrigin-RevId: 712973327
2025-01-07 11:02:55 -08:00
Justin Fu
8c9a539405 [Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
Addresses: https://github.com/jax-ml/jax/issues/25714
PiperOrigin-RevId: 712930760
2025-01-07 09:01:51 -08:00
jax authors
4023810565 [AutoPGLE] FIx PGLE kokoro test failures.
PiperOrigin-RevId: 712930537
2025-01-07 08:59:59 -08:00
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
Andrey Portnoy
5b80892169
[Mosaic GPU] Use num_q_heads=2 in flash_attention.py
Previously with 4 heads the reference function `ref` would allocate 32 GiB since it materializes large intermediate tensors. That causes CI on an 80GB H100 to run out of memory when 4 tests run in parallel. `num_q_heads=2` allows us to test multiple heads while cutting memory in half.
2025-01-07 10:31:56 -05:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
853af56007 Merge pull request #25748 from shoyer:divmod
PiperOrigin-RevId: 712864349
2025-01-07 04:44:23 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Stephan Hoyer
7fb68cac20 Fix type signature for __divmod__ 2025-01-07 00:24:24 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.

We also generalize the tests to cases where the size if fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
Yash Katariya
23eaf2160a Make inspect_array_sharding work without mesh context manager too.
PiperOrigin-RevId: 712702329
2025-01-06 17:32:15 -08:00
jax authors
b304b9efd5 Merge pull request #25740 from jakevdp:remove-array-api
PiperOrigin-RevId: 712689888
2025-01-06 16:32:54 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Parker Schuh
b49ba6553c Remove the need for check_rep for with_sharding_constraint.
PiperOrigin-RevId: 712630197
2025-01-06 12:59:22 -08:00
jax authors
52cc5c7f05 Merge pull request #25214 from jakevdp:einsum-optimize
PiperOrigin-RevId: 712603103
2025-01-06 11:37:54 -08:00
jax authors
634b45bf00 Merge pull request #25699 from yliu120:fix_iota
PiperOrigin-RevId: 712594991
2025-01-06 11:13:28 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
John QiangZhang
c39e38fe5a bazel: export serialization.fbs for downstream usage
PiperOrigin-RevId: 712587802
2025-01-06 10:57:35 -08:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Mark Sandler
6c87bf389f Fixes tril/triu comments (they were flipped)
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Yunlong Liu
3ff000ee3e fix the degenerated case 2025-01-06 16:08:07 +00:00
George Necula
e87a2a5929 [shape_poly] Remove old non_negative support.
This was deprecated in January 2024, replaced by
`core_max_dim(..., 0)`.

PiperOrigin-RevId: 712523579
2025-01-06 07:36:11 -08:00
jax authors
54fd738ecb Add SMEM as a supported Pallas output memory space.
PiperOrigin-RevId: 712144883
2025-01-04 19:33:18 -08:00
Jake VanderPlas
330606320a jax.debug.print: respect local np.printoptions 2025-01-02 16:10:54 -08:00
Zac Mustin
df36c29803 Compute cost-analysis on only one HLO module.
There was historically a goal to support multiple HLOs in an executable, but this work was never finished and is no longer planned so we don't need this support.

This will soon enable us to return only a dict, instead of a list of dicts with only one item.

PiperOrigin-RevId: 711477481
2025-01-02 11:24:52 -08:00
jax authors
82001ed5b3 Merge pull request #25706 from pearu:pearu/log10-large
PiperOrigin-RevId: 711411578
2025-01-02 06:50:54 -08:00
Adam Paszke
dbe9ccd6dc Reverts 83e60a9697ec20023f4e11169edf64e910b93031
PiperOrigin-RevId: 711403091
2025-01-02 06:04:14 -08:00