816 Commits

Author SHA1 Message Date
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding-in-types mode.
Instead, users should use `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
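A minimal sketch of the intended pattern, assuming sharding-in-types mode with a mesh set up in the surrounding context (the function `f` is hypothetical):

```python
import jax

@jax.jit
def f(x):
  # Query the sharding through the abstract value rather than the tracer;
  # x.aval.sharding resolves the same way under jit tracing and in eager mode.
  s = x.aval.sharding
  return x * 2
```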
jax authors
aed9c6f149 Merge pull request #25969 from jakevdp:fix-util
PiperOrigin-RevId: 717104490
2025-01-18 18:02:43 -08:00
Yash Katariya
36daf36913 Add a sharding rule for reduce_precision_p and properly thread eqn.ctx in loops.py, where we create `pe.new_jaxpr_eqn`s
PiperOrigin-RevId: 716849111
2025-01-17 17:31:24 -08:00
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out in reshape for splits like: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation lowers to gather_p, error out and tell the user to use .at[...].get(out_spec=...) instead.
This effectively drops the gather operation into full auto mode and adds a sharding constraint on the output, given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
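A hedged sketch of the suggested workaround, with a hypothetical array `x`, index array `idx`, and mesh axis `'x'`:

```python
from jax.sharding import PartitionSpec as P

def gather_rows(x, idx):
  # Plain indexing like x[idx] would bind gather_p and error under
  # sharding-in-types; instead, request the output spec explicitly.
  return x.at[idx].get(out_spec=P('x', None))
```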
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on one dimension only; the dimension being split must be unsharded:

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merge into one dimension only; all dimensions being merged must be unsharded:

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions, i.e., adding extra dims of size 1:

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions, i.e., removing extra dims of size 1:

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle; see the sketch after this entry. In all other cases, we error out and ask the user to provide an out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
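As referenced above, a hedged sketch of the unambiguous split case (the operand here is hypothetical; under sharding-in-types it would carry its spec from a surrounding mesh):

```python
import jax.numpy as jnp

# Hypothetical operand; in real use it would be sharded as (4@x, 6@y, 8).
x = jnp.zeros((4, 6, 8))

# Splitting only the unsharded trailing dim (8 -> 2, 2, 2) is unambiguous:
# the sharded dims keep their specs. Ambiguous reshapes error out and ask
# for an explicit out_sharding.
y = jnp.reshape(x, (4, 6, 2, 2, 2))
```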
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by properly registering the sharding on the aval created by SDS in its pytype_aval_mapping.
Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Gunhyun Park
93ef0f13fe Clarify documentation of composites.
There was some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
2025-01-09 10:54:54 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type, and sharding_cast that take out_sharding as an argument should also accept a PartitionSpec, not just a NamedSharding.
If a PartitionSpec is passed, the mesh is read from the context. The primitives themselves still take only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if the `PartitionSpec` contains mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
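A hedged sketch, with hypothetical operands and a mesh with axis `'x'` assumed to be set in the surrounding context:

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

a = jnp.ones((8, 4))  # hypothetical operands
b = jnp.ones((4, 2))

# A bare PartitionSpec is accepted and converted to a NamedSharding above
# .bind, with the mesh read from the context.
c = jnp.einsum('ij,jk->ik', a, b, out_sharding=P('x', None))
```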
jax authors
44d67e1379 Merge pull request #25648 from hawkinsp:warnings3
PiperOrigin-RevId: 708415848
2024-12-20 13:41:43 -08:00
Jake VanderPlas
beee98ab4a Add int4/uint4 support to bitcast_convert_type 2024-12-20 12:45:24 -08:00
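A minimal sketch of what this enables; bitcasting to a narrower dtype adds a trailing dimension, per the usual bitcast_convert_type semantics:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8, dtype=jnp.int8)
# int8 -> int4 halves the bit width, so a trailing dim of size 2 is added.
y = lax.bitcast_convert_type(x, jnp.int4)  # shape (8, 2)
```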
Peter Hawkins
59e5ce22d3 Avoid calls to warnings.catch_warnings in JAX core code.
warnings.catch_warnings is not thread-safe. However, it is only ever used to suppress complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Gunhyun Park
d206cc3b50 Add lax.composite primitive
A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op are implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.

This is what the HLO looks like for a tangent composite:
```
module @jit_my_tangent_composite {
  func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
    %0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
    return %0 : tensor<4xf64>
  }
  func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.sine %arg0 : tensor<4xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4xf64>
    %2 = stablehlo.divide %0, %1 : tensor<4xf64>
    return %2 : tensor<4xf64>
  }
}
```

Similarly, this can scale to something like Attention. Preserving such an abstraction greatly simplifies pattern matching: instead of matching the set of ops that represents Attention, the matcher can simply look for a uniquely identifying composite op like "MyAttention".

This is useful for preserving a high-level abstraction that would otherwise be lost during lowering. A hardware-aware compiler can recognize the single composite op and emit efficient code for it, rather than pattern-matching a generic lowering and replacing it with its own efficient lowering; the decomposition function can then be DCE'd away. If the hardware does not have an efficient lowering, the compiler can inline the `decomposition`, which implements the semantics of the abstraction.

For more details on the API, refer to the documentation.

PiperOrigin-RevId: 707750633
2024-12-18 19:38:37 -08:00
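A hedged sketch of defining the tangent composite shown in the HLO above, assuming the `lax.composite(decomposition, name=...)` form described in the documentation:

```python
import jax
import jax.numpy as jnp
from jax import lax

def my_tangent_composite(x):
  # The decomposition implements the semantics; the name uniquely
  # identifies the composite op in the lowered StableHLO.
  return lax.composite(
      lambda y: lax.sin(y) / lax.cos(y), name='my.tangent'
  )(x)

out = jax.jit(my_tangent_composite)(jnp.linspace(0.0, 1.0, 4))
```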
Pearu Peterson
f592173c6c Use StableHLO acos and update complex acos accuracy tests. 2024-12-18 15:19:38 +02:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Peter Hawkins
0922feb2f5 Use a broadcasted gather in the sort JVP, rather than forming explicit iotas.
Use an unsigned index and promise that it is in bounds.
2024-12-13 09:23:34 -05:00
Jake VanderPlas
67b3413b96 Cleanup: replace lax._abstractify with core.get_aval 2024-12-12 14:08:17 -08:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar operand is present
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Jake VanderPlas
c40780b957 internal: dedupe lax broadcasting logic 2024-12-11 15:03:39 -08:00
jax authors
e55bbc778a Merge pull request #25422 from jakevdp:broadcast-rank
PiperOrigin-RevId: 705245013
2024-12-11 14:38:24 -08:00
Jake VanderPlas
76d8b9c5a4 internal: simplify broadcast_shapes logic 2024-12-11 13:50:20 -08:00
Jake VanderPlas
65d2ca632c jax.lax: raise TypeError for mismatched dtypes 2024-12-11 11:59:10 -08:00
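A minimal illustration of the stricter check, with hypothetical values:

```python
import jax.numpy as jnp
from jax import lax

lax.add(jnp.float32(1), jnp.float32(2))  # fine: dtypes match
lax.add(jnp.float32(1), jnp.int32(2))    # now raises TypeError
```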
jax authors
263d4d1462 Merge pull request #25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00
Yash Katariya
944d822ce6 Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
Dan Foreman-Mackey
1f4d184ac8 Temporarily allow bfloat16 dot algorithms on CPU.
Since XLA:CPU doesn't (yet!) support explicit algorithms for controlling the precision of dot products, we have a check in JAX that fails when a non-trivial algorithm is specified on CPU. In order to support downstream use cases, this change allows some bfloat16 algorithms to pass through. XLA:CPU "emulates" these algorithms using `F32_F32_F32` with the appropriate casting, so CPU numerics will differ from other platforms with explicit algorithm support, but it is useful to be able to use these algorithms, with the correct input and output casting, without requiring platform-dependent logic in user code.

PiperOrigin-RevId: 703834889
2024-12-07 11:14:09 -08:00
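A hedged sketch, assuming the `DotAlgorithmPreset` precision API and hypothetical bfloat16 operands:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 2), dtype=jnp.bfloat16)

# Request an explicit dot algorithm; on CPU this preset is emulated via
# F32_F32_F32 with the appropriate input/output casts.
out = lax.dot(a, b, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)
```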
Chris Jones
569c2a3c6c Reverts 73962b740890a728295fa09f515dcf96cb820822
PiperOrigin-RevId: 703100851
2024-12-05 07:01:36 -08:00
jax authors
73962b7408 Reverts a54319ec1886ed920d50cacf10e147a743888464
PiperOrigin-RevId: 702405512
2024-12-03 11:15:14 -08:00
Chris Jones
a54319ec18 [jax] Make DotAlgorithmPreset.supported_output_types a function of the input types.
PiperOrigin-RevId: 702342849
2024-12-03 08:05:26 -08:00
Chris Jones
abf8f43007 [jax] Improve naming of DotAlgorithmPreset properties and simplify return types.
PiperOrigin-RevId: 702317395
2024-12-03 06:26:32 -08:00
Chris Jones
5d5b06cf8a [jax] Canonicalize dtypes when checking if dtypes present in target dtypes list.
PiperOrigin-RevId: 701961663
2024-12-02 07:32:11 -08:00
Yash Katariya
59e13f8114 Add a sharding argument to reshape, since it also takes a shape argument for the output shape
PiperOrigin-RevId: 700163883
2024-11-25 18:16:08 -08:00
jax authors
030ee4a1b2 Merge pull request #25070 from jax-ml:pjit-lin-rule
PiperOrigin-RevId: 699304829
2024-11-22 15:25:58 -08:00
Dougal
b1d1dcf607 Add linearization rule for pjit_p 2024-11-22 14:24:46 -08:00
Yash Katariya
21f8885a9e [sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding a reduce_p sharding rule
PiperOrigin-RevId: 699244204
2024-11-22 12:00:22 -08:00
Yash Katariya
7635605262 Use with_spec where possible to clean up the code a bit
PiperOrigin-RevId: 699226058
2024-11-22 11:01:58 -08:00
Yash Katariya
355589f32b [sharding_in_types] Add scan support to sharding_in_types. There are a few changes here:
* Set the abstract_mesh context manager during pjit_p.bind at the top level too, since scan builds a jaxpr during its lowering in `_scan_impl` (do the same for the AOT path).

* Set the abstract mesh only once, if it's not already set; don't override an already-set context. This means that only the top-level jit sets the context manager.

* Add dynamic_slice and dynamic_update_slice sharding rules, since scan calls into them.

* scan only allows `xs` whose 0th dim is fully replicated, i.e., None; see the sketch after this entry.

PiperOrigin-RevId: 699014167
2024-11-21 20:13:23 -08:00
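As referenced above, a hedged sketch of the `xs` constraint, with hypothetical shapes (in real use, `xs` would carry a spec like `(None, 'x')`, replicated on the leading scan dimension):

```python
import jax
import jax.numpy as jnp

# Hypothetical xs of shape (T, N): the leading (scan) dimension must be
# fully replicated; sharding N over a mesh axis is fine.
xs = jnp.ones((10, 4))

def step(carry, x):
  return carry + x, carry

final_carry, ys = jax.lax.scan(step, jnp.zeros(4), xs)
```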
Dougal
170718c8d4 Change signature of linearization rules.
Give the rule the nonzero tangent pattern up-front. This is needed to make a
linearization rule for pjit_p. Also make the rules return the nonzero tangents,
an explicit residual, and a closed tangent function. Add a rule for sin_p to
test it out. We still need to figure out how to avoid having to precompute
`cos(x)`; I think we need to update our backward-pass code.
2024-11-21 19:03:42 -08:00
Yash Katariya
6568713a04 [sharding_in_types] Add concatenate_p support
PiperOrigin-RevId: 698621325
2024-11-20 20:12:44 -08:00
Yash Katariya
840cf3f7d2 [sharding_in_types] Add pad_p support to sharding_in_types, so that pad's transpose (which lowers to slice) is handled correctly.
PiperOrigin-RevId: 698573396
2024-11-20 17:14:22 -08:00
jax authors
f39392eaf4 Merge pull request #25020 from jakevdp:lax-pad-validation
PiperOrigin-RevId: 698568589
2024-11-20 16:55:39 -08:00
Jake VanderPlas
17825882d2 jax.lax.pad: improve input validation 2024-11-20 16:21:45 -08:00
Jake VanderPlas
2699e9507e DOC: add examples for jax.lax.pad 2024-11-20 15:13:14 -08:00
Yash Katariya
9b94180846 [sharding_in_types] Add slice_p and squeeze_p sharding rules to make flash attention work in the backward pass
For `slice_p`'s sharding rule, I error out if the operand dim is sharded and the output dim is not divisible by that axis size.

I am working on a design to make JAX support uneven sharding at the top level, after which slice_p's sharding rule can just `return operand.sharding`. Another option is to add `out_sharding` to `slice`, but after uneven sharding support lands, it won't be necessary.

PiperOrigin-RevId: 698522980
2024-11-20 14:31:07 -08:00
Chris Jones
1e9e85a39e Simplify handling of DotAlgorithmPreset output types.
Create a clear distinction between the type used for accumulation and possible output types.

PiperOrigin-RevId: 698399447
2024-11-20 08:26:44 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00