24618 Commits

Author SHA1 Message Date
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
jax authors
ccfef7a549 Merge pull request #25424 from jakevdp:dedupe-broadcast
PiperOrigin-RevId: 705261094
2024-12-11 15:25:02 -08:00
Jake VanderPlas
c40780b957 internal: dedupe lax broadcasting logic 2024-12-11 15:03:39 -08:00
Gleb Pobudzey
e92ca9bbae Use boolean values for partial mask blocks in the splash attention kernel.
The values are guaranteed to be 0 or 1 since we create this array ourselves when processing the masks into a MaskInfo object.

PiperOrigin-RevId: 705252534
2024-12-11 14:59:30 -08:00
jax authors
b7af1eb905 Merge pull request #25381 from jakevdp:mypy-np22
PiperOrigin-RevId: 705248189
2024-12-11 14:47:37 -08:00
jax authors
e55bbc778a Merge pull request #25422 from jakevdp:broadcast-rank
PiperOrigin-RevId: 705245013
2024-12-11 14:38:24 -08:00
Jake VanderPlas
f4f4bf6a19 Fix type annotations for NumPy 2.2 2024-12-11 14:24:58 -08:00
jax authors
fb53971802 Merge pull request #25419 from jakevdp:lax-dtypes
PiperOrigin-RevId: 705230631
2024-12-11 13:59:38 -08:00
Jake VanderPlas
76d8b9c5a4 internal: simplify broadcast_shapes logic 2024-12-11 13:50:20 -08:00
jax authors
b8d2e9383a Update XLA dependency to use revision 209cbfa31a.

PiperOrigin-RevId: 705215149
2024-12-11 13:16:26 -08:00
jax authors
5e887b446b Merge pull request #25414 from jakevdp:finalize-deps
PiperOrigin-RevId: 705197214
2024-12-11 12:24:13 -08:00
jax authors
8c4d3db99a Merge pull request #23225 from snadampal:aarch64_jax
PiperOrigin-RevId: 705196103
2024-12-11 12:20:32 -08:00
Jake VanderPlas
65d2ca632c jax.lax: raise TypeError for mismatched dtypes 2024-12-11 11:59:10 -08:00
jax authors
5fe8bcc734 Merge pull request #25407 from ROCm:remove-cuda-import-in-plugin-upstream
PiperOrigin-RevId: 705168796
2024-12-11 11:07:19 -08:00
Jake VanderPlas
0fe97f97c7 jax.core: remove private API
PiperOrigin-RevId: 705155279
2024-12-11 10:31:34 -08:00
jax authors
6c45d3131d Merge pull request #25401 from momostein:type-hint-scipy-rotation-mul
PiperOrigin-RevId: 705153025
2024-12-11 10:25:36 -08:00
Gleb Pobudzey
20236f1083 Increase shard count after adding more tests
PiperOrigin-RevId: 705146601
2024-12-11 10:08:50 -08:00
Charles Hofer
8d42fa0b0b Remove cuda include from gpu plugin extension and BUILD 2024-12-11 11:55:51 -06:00
Jake VanderPlas
f858a71461 Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client. 2024-12-11 09:50:33 -08:00
jax authors
01206f839b Merge pull request #25395 from gnecula:poly_better_eq
PiperOrigin-RevId: 705105803
2024-12-11 07:51:40 -08:00
jax authors
98c405569a Merge pull request #25403 from gnecula:readme1
PiperOrigin-RevId: 705102635
2024-12-11 07:39:20 -08:00
Benjamin Chetioui
4ef7706abb [Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
2024-12-11 07:31:34 -08:00
jax authors
354bd52710 Merge pull request #25387 from jakevdp:more-core-deps
PiperOrigin-RevId: 705094013
2024-12-11 07:07:20 -08:00
Benjamin Chetioui
07a3515065 [Mosaic GPU] Add an initial skeleton for a layout inference pass.
Layouts are added as annotations on MLIR ops, using the `in_layouts` and
`out_layouts` attributes.

At this point, layout inference is done in two passes: one "backwards" pass
(root-to-parameters), and one "forward" pass (parameters-to-root).

Each pass goes through all the ops in the specified order, and infers a
possible layout from the layout information that is available. We expect to
need two passes because partial layout annotations may be provided on
intermediate nodes (e.g. `wgmma`), and a single pass from the root to the
parameters is therefore insufficient to properly annotate all the operations.

We do not check whether the inferred layouts can be lowered correctly; the
produced IR may therefore fail to lower later.

Layouts are only inferred for ops involving at least one operand or result of
type `VectorType`/`RankedTensorType`.

When layouts can't be inferred for an op that should have them, we default to
annotating it with strided fragmented layouts.

PiperOrigin-RevId: 705092403
2024-12-11 07:01:06 -08:00
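The two-pass scheme described above can be sketched in plain Python. This is a hypothetical toy model (the `Op` class, layout names, and linear-chain propagation are illustrative only, not the actual Mosaic GPU pass, which works on MLIR ops):

```python
# Toy sketch of two-pass layout inference over a linear chain of ops.
# Partial annotations (e.g. on a wgmma-like op) seed both passes; any
# op still without a layout afterwards defaults to "strided".

from dataclasses import dataclass
from typing import Optional

@dataclass
class Op:
    name: str
    layout: Optional[str] = None  # pre-existing annotation, if any

def infer_layouts(ops: list[Op]) -> list[str]:
    """ops are listed in parameters-to-root order."""
    layouts: list[Optional[str]] = [op.layout for op in ops]
    # Backward pass (root-to-parameters): pull layouts from consumers.
    for i in range(len(ops) - 2, -1, -1):
        if layouts[i] is None:
            layouts[i] = layouts[i + 1]
    # Forward pass (parameters-to-root): push layouts from producers.
    for i in range(1, len(ops)):
        if layouts[i] is None:
            layouts[i] = layouts[i - 1]
    # Default for ops whose layout could not be inferred at all.
    return [l if l is not None else "strided" for l in layouts]

chain = [Op("load"), Op("wgmma", layout="wgmma_fragmented"), Op("store")]
print(infer_layouts(chain))  # every op inherits the wgmma annotation
```

A single pass in either direction would miss the partial annotation on the intermediate op, which is why the commit uses both.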
jax authors
b79dae8288 Merge pull request #25390 from jakevdp:matvec
PiperOrigin-RevId: 705077854
2024-12-11 06:00:35 -08:00
Ayaka
13ce51785d [Pallas] Remove grid=1 in tests
Remove `grid=1` in tests because it's the same as not specifying `grid`.

PiperOrigin-RevId: 705077047
2024-12-11 05:56:32 -08:00
George Necula
9f6cf62a48 Minor change in the README, remove "expect bugs" 2024-12-11 14:50:14 +01:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Brecht Ooms
257b033e8c Add Rotation return type hint to Rotation.__mul__()
Without this type hint, some tools (including PyCharm) infer the more
generic return type from typing.NamedTuple.
To improve user experience, I've added a narrower type hint.

However, the typing of this method is still 'flawed', as the only properly
supported input is another Rotation. Annotating that narrower input type
would violate the Liskov substitution principle, so I left the input
parameter untyped.

For more info:
https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
2024-12-11 12:48:02 +01:00
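The typing issue this commit works around can be shown with a minimal self-contained example (a stand-in `Rot` class, not the actual jax.scipy Rotation):

```python
# Minimal illustration of the commit's fix: on a NamedTuple subclass,
# an explicit return annotation on __mul__ narrows what type checkers
# infer. The parameter is deliberately left untyped, as in the commit,
# to avoid an override that violates the Liskov substitution principle.

from typing import NamedTuple

class Rot(NamedTuple):
    w: float

    def __mul__(self, other) -> "Rot":
        return Rot(self.w * other.w)

r = Rot(2.0) * Rot(3.0)
print(r.w)  # 6.0
```

Without the `-> "Rot"` annotation, tools like PyCharm may fall back to the broader return type inherited from the tuple machinery.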
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
Dimitar (Mitko) Asenov
3d9c720d42 [Mosaic GPU] Automatically format the Mosaic GPU dialect test python code
This lets me keep using the formatter going forward instead of formatting code manually.

PiperOrigin-RevId: 705024602
2024-12-11 02:04:08 -08:00
George Necula
60f9da5d58 [shape_poly] Improve reasoning for >= in presence of == constraints.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b=c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4).

Now we add the equality constraints also in the inequality
reasoning state.
2024-12-11 10:51:49 +01:00
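The gain can be sanity-checked numerically with a trivial sketch (not the actual shape_poly decision procedure): once the equality 4*b == c is visible to the inequality reasoning, c - b rewrites to 3*b, which is at least 3 because dimension variables satisfy b >= 1, so b <= c follows.

```python
# Trivial numeric check of the reasoning above: under the constraint
# 4*b == c and the dimension bound b >= 1, c - b = 3*b >= 3 > 0,
# hence b <= c.

def proves_b_le_c(b: int) -> bool:
    assert b >= 1  # dimension variables are at least 1
    c = 4 * b      # the equality constraint "4*b = c"
    return c - b == 3 * b and b <= c

assert all(proves_b_le_c(b) for b in range(1, 100))
```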
Dimitar (Mitko) Asenov
66f45d039f [Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.
The op API is still in flux so I'm leaving some of the verification code untested.

PiperOrigin-RevId: 705020066
2024-12-11 01:47:29 -08:00
jax authors
cfdac00e5e Merge pull request #25392 from carlosgmartin:add_nn_relu_grad_at_zero_test_update_paper_link
PiperOrigin-RevId: 704947960
2024-12-10 20:37:55 -08:00
Jake VanderPlas
59b9eefd06 jax.core: more API deprecations 2024-12-10 20:27:28 -08:00
Yash Katariya
41f490aef4 [sharding_in_types] Default axis_types to Auto for all axis_names if the user does not set any AxisType. Also resolve some TODOs now that we have a way for the user to set the mesh.
PiperOrigin-RevId: 704944255
2024-12-10 20:20:23 -08:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further, which is not what we want if some of the dims are user-sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
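The substitution described above can be sketched in plain Python (hypothetical names and data structures; the real logic lives in the `ShapedArray` constructor and JAX's sharding internals):

```python
# Hypothetical sketch of overriding Auto-axis entries in a spec with an
# "unconstrained" marker while keeping User-sharded axes. UNCONSTRAINED
# here is a stand-in for PartitionSpec.UNCONSTRAINED.

AUTO, USER = "auto", "user"
UNCONSTRAINED = object()

def resolve_spec(spec, axis_types):
    """spec: tuple of mesh-axis names (or None); axis_types: name -> type."""
    return tuple(
        UNCONSTRAINED if (axis is not None and axis_types[axis] == AUTO) else axis
        for axis in spec
    )

axis_types = {"x": USER, "y": AUTO}
print(resolve_spec(("x", "y", None), axis_types))
```

Only the Auto axis is replaced; the User-sharded axis and the unsharded dimension pass through unchanged, mirroring why not all dims are marked unspecified during lowering.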
carlosgmartin
08801147f1 Add test of relu grad at zero. Update paper links. 2024-12-10 19:39:47 -05:00
Ayaka
e88b578356 [Pallas TPU] Add WeirdOp to TPU dialect and add lowering for lax.is_finite
PiperOrigin-RevId: 704888940
2024-12-10 16:38:04 -08:00
jax authors
3ca9f14107 Merge pull request #25361 from Rifur13:regression
PiperOrigin-RevId: 704885039
2024-12-10 16:25:36 -08:00
Jake VanderPlas
f6d58761d1 jax.numpy: implement matvec & vecmat 2024-12-10 16:03:19 -08:00
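These functions follow the NumPy 2.2 APIs of the same name; their broadcasting semantics can be sketched with einsum (a sketch in plain NumPy rather than jax.numpy):

```python
# Sketch of matvec/vecmat semantics via einsum: both operate on the
# last axes and broadcast over leading batch dimensions; vecmat, like
# np.vecmat, conjugates its vector argument.

import numpy as np

def matvec(a, x):
    # matrix-vector product over the last two axes of `a`
    return np.einsum("...ij,...j->...i", a, x)

def vecmat(x, a):
    # vector-matrix product with the vector conjugated
    return np.einsum("...i,...ij->...j", np.conj(x), a)

A = np.arange(6.0).reshape(2, 3)
v = np.ones(3)
print(matvec(A, v))  # row sums of A
```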
Jacob Burnim
1c1a17e0f0 Only run tpu_all_gather_test on tpu_v5e_4x2
PiperOrigin-RevId: 704871583
2024-12-10 15:42:42 -08:00
jax authors
2ff90382d2 Update XLA dependency to use revision ce56ae1529.

PiperOrigin-RevId: 704835739
2024-12-10 13:55:02 -08:00
Michael Hudgins
ac92aaaf3a Enable New bazel presubmits for pull requests.
At this time, these workflows are non-blocking for submission.

PiperOrigin-RevId: 704825599
2024-12-10 13:26:48 -08:00
Jake VanderPlas
e6d6c4ef8a Delete non-public API jax.lib.xla_bridge._backends
This is doubly non-public: nothing under `jax.lib` is public, and also the object itself has a preceding underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in adaf1b2b75).

PiperOrigin-RevId: 704825268
2024-12-10 13:25:14 -08:00
Bixia Zheng
d4899f7b9b [jax:custom_partitioning] Make SdyShardingRule a user facing class.
Move the parsing of a sharding rule string to a free function
str_to_sdy_sharding_rule. Move the building of the MLIR sharding rule to a free
function sdy_sharding_rule_to_mlir.

PiperOrigin-RevId: 704818640
2024-12-10 13:05:43 -08:00
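The shape of the parsing step can be sketched with a toy parser for einsum-like rule strings (hypothetical and heavily simplified; the real str_to_sdy_sharding_rule handles much more, such as factor sizes and ellipses):

```python
# Toy sketch: split an einsum-like sharding rule string such as
# "i j, j k -> i k" into per-operand and per-result factor lists.

def parse_rule(rule: str):
    lhs, rhs = rule.split("->")
    operands = [part.split() for part in lhs.split(",")]
    results = [part.split() for part in rhs.split(",")]
    return operands, results

print(parse_rule("i j, j k -> i k"))
```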
jax authors
9c32fe8fbf Merge pull request #25357 from jakevdp:core-deps
PiperOrigin-RevId: 704808153
2024-12-10 12:35:05 -08:00
Tzu-Wei Sung
e418e88321 [Pallas] Add non-square pl.dot test cases.
PiperOrigin-RevId: 704788500
2024-12-10 11:38:28 -08:00
Dan Foreman-Mackey
593143e17e Deduplicate some GPU plugin definition code.
The `jaxlib/cuda_plugin_extension.cc` and `jaxlib/rocm_plugin_extension.cc` files were nearly identical so this change consolidates the shared implementation into a single target.

PiperOrigin-RevId: 704785926
2024-12-10 11:32:06 -08:00
jax authors
210bd3080d Merge pull request #25378 from hawkinsp:optbarrier
PiperOrigin-RevId: 704783658
2024-12-10 11:24:15 -08:00