15404 Commits

Author SHA1 Message Date
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
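A minimal sketch of the callback mechanism this describes (the `test_event` name comes from this message; the registry and helper below are illustrative assumptions, not JAX's actual implementation):
```python
import threading
from typing import Any, Callable

_callbacks: list[Callable[..., None]] = []
_lock = threading.Lock()

def register_event_callback(cb: Callable[..., None]) -> None:
    # test_utils registers its counters here instead of monkey-patching
    # the functions under test.
    with _lock:
        _callbacks.append(cb)

def test_event(name: str, *args: Any) -> None:
    # Called at points of interest in the program; effectively a no-op
    # when no test has registered a callback.
    with _lock:
        cbs = list(_callbacks)
    for cb in cbs:
        cb(name, *args)
```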
2024-12-12 09:58:14 -05:00
Adam Paszke
3630756e87 [Mosaic GPU] Use events as the default profiling method
JAX still supports old CUDA versions such as 12.0, where CUPTI leaks memory.

PiperOrigin-RevId: 705459909
2024-12-12 04:42:56 -08:00
jax authors
dda6b88864 Merge pull request #25425 from jax-ml:linearize-bugs-and-stuff
PiperOrigin-RevId: 705313000
2024-12-11 18:27:27 -08:00
Dougal
8fe8d241e8 Fixes to direct linearize
* Fix a bug in pjit linearization rule
  * Handle multiple results and zeros in fallback rule
  * Handle `has_aux`
  * Implement process_custom_vjp_call
2024-12-11 20:57:36 -05:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
jax authors
ccfef7a549 Merge pull request #25424 from jakevdp:dedupe-broadcast
PiperOrigin-RevId: 705261094
2024-12-11 15:25:02 -08:00
Jake VanderPlas
c40780b957 internal: dedupe lax broadcasting logic 2024-12-11 15:03:39 -08:00
Gleb Pobudzey
e92ca9bbae Use boolean values for partial mask blocks in the splash attention kernel.
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.

PiperOrigin-RevId: 705252534
2024-12-11 14:59:30 -08:00
jax authors
b7af1eb905 Merge pull request #25381 from jakevdp:mypy-np22
PiperOrigin-RevId: 705248189
2024-12-11 14:47:37 -08:00
jax authors
e55bbc778a Merge pull request #25422 from jakevdp:broadcast-rank
PiperOrigin-RevId: 705245013
2024-12-11 14:38:24 -08:00
Jake VanderPlas
f4f4bf6a19 Fix type annotations for NumPy 2.2 2024-12-11 14:24:58 -08:00
jax authors
fb53971802 Merge pull request #25419 from jakevdp:lax-dtypes
PiperOrigin-RevId: 705230631
2024-12-11 13:59:38 -08:00
Jake VanderPlas
76d8b9c5a4 internal: simplify broadcast_shapes logic 2024-12-11 13:50:20 -08:00
jax authors
5e887b446b Merge pull request #25414 from jakevdp:finalize-deps
PiperOrigin-RevId: 705197214
2024-12-11 12:24:13 -08:00
Jake VanderPlas
65d2ca632c jax.lax: raise TypeError for mismatched dtypes 2024-12-11 11:59:10 -08:00
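A quick illustration of the behavior this refers to: unlike `jax.numpy`, `jax.lax` ops do not promote dtypes, so the mismatched call below errors (now with TypeError, per this change):
```python
import jax.numpy as jnp
from jax import lax

lax.add(jnp.float32(1), jnp.float32(2))  # OK: dtypes match
lax.add(jnp.float32(1), jnp.int32(2))    # mismatched dtypes: TypeError
```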
Jake VanderPlas
0fe97f97c7 jax.core: remove private API
PiperOrigin-RevId: 705155279
2024-12-11 10:31:34 -08:00
jax authors
6c45d3131d Merge pull request #25401 from momostein:type-hint-scipy-rotation-mul
PiperOrigin-RevId: 705153025
2024-12-11 10:25:36 -08:00
Jake VanderPlas
f858a71461 Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client. 2024-12-11 09:50:33 -08:00
jax authors
01206f839b Merge pull request #25395 from gnecula:poly_better_eq
PiperOrigin-RevId: 705105803
2024-12-11 07:51:40 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00
jax authors
354bd52710 Merge pull request #25387 from jakevdp:more-core-deps
PiperOrigin-RevId: 705094013
2024-12-11 07:07:20 -08:00
Benjamin Chetioui
07a3515065 [Mosaic GPU] Add an initial skeleton for a layout inference pass.
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.

At this point, layout inference is done in two passes: one "backward" pass (root-to-parameters) and one "forward" pass (parameters-to-root).

Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.

We do not check whether the inferred layouts can actually be lowered correctly; the produced IR may therefore fail to lower later.

Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.

When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.
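As a toy model of the two-pass scheme (plain Python standing in for the MLIR pass; every name here is illustrative):
```python
from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    producers: list["Op"] = field(default_factory=list)
    layout: str | None = None  # stands in for in_layouts/out_layouts

def infer_layouts(ops: list[Op]) -> None:
    # Backward pass (root-to-parameters): push partial annotations,
    # e.g. from a wgmma op, upstream to their producers.
    for op in reversed(ops):
        for producer in op.producers:
            if producer.layout is None:
                producer.layout = op.layout
    # Forward pass (parameters-to-root): pull layouts downstream.
    for op in ops:
        if op.layout is None and op.producers:
            op.layout = op.producers[0].layout
    # Default for ops that should carry a layout but still have none.
    for op in ops:
        if op.layout is None:
            op.layout = "strided_fragmented"

param = Op("param")
mma = Op("wgmma", producers=[param], layout="wgmma_fragmented")
root = Op("root", producers=[mma])
infer_layouts([param, mma, root])  # ops listed parameters-to-root
print(param.layout, mma.layout, root.layout)
```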

PiperOrigin-RevId: 705092403
2024-12-11 07:01:06 -08:00
jax authors
b79dae8288 Merge pull request #25390 from jakevdp:matvec
PiperOrigin-RevId: 705077854
2024-12-11 06:00:35 -08:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Brecht Ooms
257b033e8c Add Rotation return type hint to Rotation.__mul__()
Without this type hint, some tools (including PyCharm) infer the more
generic return type from typing.NamedTuple.
To improve user experience, I've added a narrower type hint.

However, the typing of this method is still 'flawed', as the only properly
supported input is another Rotation. Annotating that narrower input type
would violate the Liskov substitution principle, so I left the input
parameter untyped.

For more info:
https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
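The shape of the change, sketched under the assumption that Rotation is the NamedTuple-based class in jax.scipy.spatial.transform (the body below is a placeholder):
```python
from typing import NamedTuple
import jax

class Rotation(NamedTuple):
    quat: jax.Array

    def __mul__(self, other) -> "Rotation":  # `other` deliberately untyped
        # Compose this rotation with `other` (placeholder body).
        ...
```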
2024-12-11 12:48:02 +01:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
George Necula
60f9da5d58 [shape_poly] Improve reasoning for >= in the presence of == constraints.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b=c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4).

Now we also add the equality constraints to the inequality
reasoning state.
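A hedged example via the jax.export symbolic shapes API (assuming that API is the intended entry point):
```python
from jax import export

# With the equality constraint in scope, b <= c is now decidable.
b, c = export.symbolic_shape("b, c", constraints=["4*b == c"])
assert b <= c
```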
2024-12-11 10:51:49 +01:00
jax authors
cfdac00e5e Merge pull request #25392 from carlosgmartin:add_nn_relu_grad_at_zero_test_update_paper_link
PiperOrigin-RevId: 704947960
2024-12-10 20:37:55 -08:00
Jake VanderPlas
59b9eefd06 jax.core: more API deprecations 2024-12-10 20:27:28 -08:00
Yash Katariya
41f490aef4 [sharding_in_types] Default axis_types to Auto for all axis_names if the user does not set any AxisType. Also resolve some TODOs now that we have a way for the user to set the mesh.
PiperOrigin-RevId: 704944255
2024-12-10 20:20:23 -08:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in the `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would allow XLA to shard them even further, which is not what we want if some of the dims are user-sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
carlosgmartin
08801147f1 Add test of relu grad at zero. Update paper links. 2024-12-10 19:39:47 -05:00
Ayaka
e88b578356 [Pallas TPU] Add WeirdOp to TPU dialect and add lowering for lax.is_finite
PiperOrigin-RevId: 704888940
2024-12-10 16:38:04 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
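These mirror NumPy 2.2's np.matvec and np.vecmat; a quick usage sketch:
```python
import jax.numpy as jnp

A = jnp.ones((3, 4))
print(jnp.matvec(A, jnp.ones(4)).shape)  # (3,): matrix-vector product
print(jnp.vecmat(jnp.ones(3), A).shape)  # (4,): vector-matrix product
```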
Jake VanderPlas
e6d6c4ef8a Delete non-public API jax.lib.xla_bridge._backends
This is doubly non-public: nothing under `jax.lib` is public, and the object itself has a leading underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in adaf1b2b75).

PiperOrigin-RevId: 704825268
2024-12-10 13:25:14 -08:00
Bixia Zheng
d4899f7b9b [jax:custom_partitioning] Make SdyShardingRule a user-facing class.
Move the parsing of a sharding rule string to a free function
str_to_sdy_sharding_rule. Move the building of the MLIR sharding rule to a free
function sdy_sharding_rule_to_mlir.

PiperOrigin-RevId: 704818640
2024-12-10 13:05:43 -08:00
jax authors
9c32fe8fbf Merge pull request #25357 from jakevdp:core-deps
PiperOrigin-RevId: 704808153
2024-12-10 12:35:05 -08:00
jax authors
210bd3080d Merge pull request #25378 from hawkinsp:optbarrier
PiperOrigin-RevId: 704783658
2024-12-10 11:24:15 -08:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
jax authors
e6dfe8f380 [AutoPGLE] Share FDO profile even when the compilation cache is disabled.
PiperOrigin-RevId: 704757991
2024-12-10 10:23:42 -08:00
jax authors
6dbafed7bc Fix mypy failure
PiperOrigin-RevId: 704748889
2024-12-10 10:01:43 -08:00
jax authors
8813973d96 [AutoPGLE] Cleanup compiler code.
PiperOrigin-RevId: 704741308
2024-12-10 09:37:35 -08:00
Peter Hawkins
acae2f0546 Remove code in jax2tf for compatibility with TF 2.10 or earlier. 2024-12-10 15:18:59 +00:00
jax authors
263d4d1462 Merge pull request #25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00
jax authors
90de28cd63 Merge pull request #25335 from gnecula:export_doc_call
PiperOrigin-RevId: 704589764
2024-12-10 00:45:20 -08:00
Gunhyun Park
12c30578b2 Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes that `split_axis` and `concat_axis` of `all_to_all` are the outermost (ragged) dim, and that `axis_index_groups` defaults to all replicas (i.e. there is only one group, covering all axis indices as an iota, like the example above).

The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html, which essentially also performs a ragged all-to-all.

PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Yash Katariya
944d822ce6 Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
Peter Hawkins
820f51dc53 Merge branch 'release/0.4.37' into main. 2024-12-09 20:21:43 -05:00
Peter Hawkins
ffb07cdadb Update versions for v0.4.37 release. 2024-12-09 15:39:59 -05:00