578 Commits

Praveen Narayanan
b6d4fe5387 Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
PiperOrigin-RevId: 735471245
2025-03-10 12:25:22 -07:00
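For context, a minimal sketch of the `lax.ragged_dot` call that this change re-expresses in terms of the new `lax.ragged_dot_general` (shapes follow the documented `[m, k] x [g, k, n]` grouped-matmul convention; the concrete numbers are illustrative):
```
import jax.numpy as jnp
from jax import lax

# Grouped matmul: rows of `lhs` are partitioned into g contiguous groups, and
# group i is contracted against rhs[i].
m, k, n, g = 6, 4, 3, 2
lhs = jnp.ones((m, k), dtype=jnp.float32)         # [m, k]
rhs = jnp.ones((g, k, n), dtype=jnp.float32)      # [g, k, n]
group_sizes = jnp.array([2, 4], dtype=jnp.int32)  # [g], sums to m

out = lax.ragged_dot(lhs, rhs, group_sizes)       # [m, n]
```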
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks a little less strict when the current or aval mesh is empty/unset. Some more changes are listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if `use_mesh` is not used in explicit sharding mode, computation-follows-data works

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* The check in partial_eval that compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any of the avals has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474, we need to relax the typing and sharding rule checks because inserting `mesh_cast`s leads to the creation of unnecessary residuals (for literals, numpy arrays, and basically anything with an empty mesh), which is undesirable.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
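Below is a rough sketch of the relaxed behavior described above: an explicitly sharded array is multiplied with a plain numpy array (which has an empty/unset mesh) without any `mesh_cast`. The explicit-mode spellings (`AxisType.Explicit`, `Mesh(..., axis_types=...)`) are assumptions based on recent JAX releases and may differ at this revision:
```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P, AxisType

# Build a mesh (single-device here for portability) whose axis is Explicit.
mesh = Mesh(np.array(jax.devices()[:1]), ("x",), axis_types=(AxisType.Explicit,))
explicit_sharded_arr = jax.device_put(np.arange(8.0), NamedSharding(mesh, P("x")))
np_array = np.ones(8)

f = lambda x, y: x * y
# With the relaxed typing/sharding checks, mixing the explicitly sharded array
# with a numpy array (empty mesh) is allowed directly, even in Explicit mode.
out = f(explicit_sharded_arr, np_array)
```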
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
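The commit title is terse, so here is a hedged illustration of what a monoidal reduction in `lax` looks like, using the long-standing `lax.reduce` (the reducer plus its identity element form a monoid); the names of the newly exposed public wrappers are not shown here:
```
import jax.numpy as jnp
from jax import lax

# A monoidal reduction: (lax.add, 0) form a monoid, reduced over axis 1.
x = jnp.arange(12.0).reshape(3, 4)
row_sums = lax.reduce(x, jnp.float32(0), lax.add, dimensions=(1,))  # shape (3,)
```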
Gunhyun Park
6b19bb2091 Allow composites to provide default kwargs with None value
The current behavior crashes when trying to convert NoneType to an MLIR attribute. This change allows a composite to have optional attributes that are omitted when no value is provided, similar to how default values in MLIR are not shown in the IR.

PiperOrigin-RevId: 725786442
2025-02-11 15:05:50 -08:00
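As a hedged sketch of the new behavior (assuming the kwargs-as-attributes convention of `lax.composite`; the `my.scaled_tanh` composite and its `scale` attribute are hypothetical), an attribute with a `None` default can now simply be omitted, mirroring how MLIR hides default-valued attributes:
```
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(lax.composite, name="my.scaled_tanh")
def scaled_tanh(x, *, scale=None):
  # `scale` is a static composite attribute; when left as None it is simply
  # not emitted on the stablehlo.composite op (previously this crashed while
  # converting NoneType to an MLIR attribute).
  y = lax.tanh(x)
  return y if scale is None else y * scale

out = jax.jit(scaled_tanh)(jnp.ones(4))
```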
Kevin Gleason
319f7f5a2d [StableHLO] Allow composites to use dtype values in attributes
PiperOrigin-RevId: 725305384
2025-02-10 12:21:56 -08:00
Gunhyun Park
2828bce2e6 Add check to lax.composite to prevent DynamicJaxprTracer type errors.
There's some confusion about whether JAX arrays can be used inside the attributes, so I made the error more explicit.

PiperOrigin-RevId: 724316766
2025-02-07 06:02:24 -08:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit the abstract_eval rules and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` is currently a bit more permissive because we need to keep the current code at HEAD working, but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Gunhyun Park
c4e176328f Move ragged_all_to_all test under appropriate test file
PiperOrigin-RevId: 721947980
2025-01-31 16:44:04 -08:00
Gunhyun Park
20555f63da Lower np.ndarray to DenseElementsAttr instead of ArrayAttr.
PiperOrigin-RevId: 721833949
2025-01-31 11:06:06 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
Dan Foreman-Mackey
aa8c0010e2 Add support for constants in the decomposition of lax.composites.
This change adds support for including "consts" (i.e. closed-over arrays) in the body of the decomposition definition for a `lax.composite` op. The caveat is that, since the signature of the decomposition must match the composite itself, the values of any consts must be known at lowering time so that they can be inlined into the decomposition's HLO. Therefore, there is no support for closing over tracers. Since `lax.composite` doesn't support most transformations anyway, this typically isn't a major limitation except with `jax.jit` (as demonstrated in the tests).

PiperOrigin-RevId: 718048021
2025-01-21 13:28:42 -08:00
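A hedged sketch of what this enables (the `my.weighted_sum` composite and its closed-over array are hypothetical): the decomposition closes over a concrete array whose value is inlined into the decomposition's HLO at lowering time, which is why closing over tracers is not supported:
```
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

weights = jnp.array([0.25, 0.25, 0.25, 0.25])  # concrete closed-over "const"

@partial(lax.composite, name="my.weighted_sum")
def weighted_sum(x):
  # `weights` is not an argument of the composite; its value must be known at
  # lowering time so it can be inlined into the decomposition's HLO.
  return jnp.sum(x * weights)

out = jax.jit(weighted_sum)(jnp.arange(4.0))
```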
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Pearu Peterson
50670bd907 Fix log10 and log2 for large inputs. 2025-01-01 12:45:39 +02:00
Gunhyun Park
38747a7a5d Move ragged tests under a new class.
PiperOrigin-RevId: 708811348
2024-12-22 07:50:19 -08:00
Jake VanderPlas
beee98ab4a Add int4/uint4 support to bitcast_convert_type 2024-12-20 12:45:24 -08:00
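A short sketch of the new support, relying on the documented `bitcast_convert_type` semantics where bitcasting to a narrower type adds a trailing dimension (and the reverse removes it):
```
import jax.numpy as jnp
from jax import lax

x = jnp.array([0x12, 0x34], dtype=jnp.int8)        # shape (2,)
nibbles = lax.bitcast_convert_type(x, jnp.uint4)   # shape (2, 2): two 4-bit values per byte
back = lax.bitcast_convert_type(nibbles, jnp.int8) # shape (2,)
```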
Adam Paszke
ad00ec1dc9 [Mosaic TPU] Guard tests for new features by the libtpu version
PiperOrigin-RevId: 707875450
2024-12-19 05:04:09 -08:00
Gunhyun Park
d206cc3b50 Add lax.composite primitive
A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op are implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.

This is what the HLO looks like for a tangent composite:
```
module @jit_my_tangent_composite {
  func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
    %0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
    return %0 : tensor<4xf64>
  }
  func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.sine %arg0 : tensor<4xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4xf64>
    %2 = stablehlo.divide %0, %1 : tensor<4xf64>
    return %2 : tensor<4xf64>
  }
}
```

Similarly, this can scale to something like Attention. Preserving such an abstraction greatly simplifies pattern matching: instead of matching the set of ops that represent Attention, the matcher can simply look for a uniquely identifying composite op like "MyAttention".

This is useful for preserving a high-level abstraction that would otherwise be lost during lowering. A hardware-aware compiler can recognize the single composite op and emit efficient code for it, rather than pattern-matching a generic lowering and replacing it with its own efficient lowering; the decomposition function can then be DCE'd away. If the hardware does not have an efficient lowering, it can inline the `decomposition`, which implements the semantics of the abstraction.

For more details on the API, refer to the documentation.

PiperOrigin-RevId: 707750633
2024-12-18 19:38:37 -08:00
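For reference, a hedged sketch of Python code that could produce a module like the one above (assuming `lax.composite` is used as a decorator via `functools.partial`, per its documented usage):
```
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # the module above uses f64 inputs

@partial(lax.composite, name="my.tangent")
def my_tangent_composite(x):
  # The decomposition defines the op's semantics; lowering emits a single
  # stablehlo.composite "my.tangent" op that references it.
  return lax.sin(x) / lax.cos(x)

x = jnp.linspace(0.0, 1.0, 4, dtype=jnp.float64)
print(jax.jit(my_tangent_composite).lower(x).as_text())
```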
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
5beb4794b7 Merge pull request #23830 from pearu:pearu/acos
PiperOrigin-RevId: 707525735
2024-12-18 06:17:59 -08:00
Pearu Peterson
f592173c6c Use StableHLO acos and update complex acos accuracy tests. 2024-12-18 15:19:38 +02:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
jax authors
0fa541972e Merge pull request #25456 from jakevdp:xla-abstractify
PiperOrigin-RevId: 707175097
2024-12-17 11:13:18 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
Oleg Shyshkov
6d82a6fc90 Allow lax.ragged_all_to_all input and output operands to have different ragged dimension sizes.
We need to guarantee that the outermost dimension of the output is big enough to fit all received elements, but it's not necessary for the input and output outermost dimensions to be exactly equal.

PiperOrigin-RevId: 707011916
2024-12-17 02:20:10 -08:00
Peter Hawkins
6548caf239 Relax test tolerance for complex128 pow in lax_test.py.
This is failing in CI in some CPU configurations.

PiperOrigin-RevId: 705558897
2024-12-12 10:50:44 -08:00
Jake VanderPlas
65d2ca632c jax.lax: raise TypeError for mismatched dtypes 2024-12-11 11:59:10 -08:00
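A minimal illustration of the stricter behavior (`lax` ops do not implicitly promote, unlike their `jax.numpy` counterparts):
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones(3, dtype=jnp.float32)
y = jnp.ones(3, dtype=jnp.int32)

try:
  lax.add(x, y)                # mismatched dtypes now raise TypeError
except TypeError as err:
  print(err)

out = lax.add(x, y.astype(x.dtype))  # cast explicitly instead
```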
Gunhyun Park
12c30578b2 Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes the `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` defaults to all replicas (i.e. there is only one group covering all axis indices, like the iota in the example above).

The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all-to-all.

PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Yash Katariya
944d822ce6 Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
Dan Foreman-Mackey
1f4d184ac8 Temporarily allow bfloat16 dot algorithms on CPU.
Since XLA:CPU doesn't (yet!) support explicit algorithms for controlling the precision of dot products, we have a check in JAX that fails when a non-trivial algorithm is specified on CPU. In order to support downstream use cases, this change allows some bfloat16 algorithms to pass through. XLA:CPU "emulates" these algorithms using `F32_F32_F32` with the appropriate casting, which means that CPU numerics will differ from other platforms with explicit algorithm support, but it is useful to be able to use these algorithms with the correct input and output casting without requiring platform-dependent logic in user code.

PiperOrigin-RevId: 703834889
2024-12-07 11:14:09 -08:00
Pearu Peterson
504c738781 Use next to tiny as smallest floating point value on Mac ARM 2024-11-26 16:50:49 +00:00
Jake VanderPlas
17825882d2 jax.lax.pad: improve input validation 2024-11-20 16:21:45 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that operations like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
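For reference, a hedged sketch of the new API itself (assuming the signature `lax.split(operand, sizes, axis=0)`); `jnp.unstack` benefits because the transpose of this split is now a single `lax.concatenate`, as in the second jaxpr above:
```
import jax.numpy as jnp
from jax import lax

xs = jnp.ones((5, 3))
# Split into five (1, 3) pieces along axis 0.
pieces = lax.split(xs, sizes=(1, 1, 1, 1, 1), axis=0)
```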
Peter Hawkins
461a2507f8 Disable some complex function accuracy tests that fail on Mac ARM. 2024-11-18 17:04:52 -05:00
Pearu Peterson
4d0a007d57 Add square_p 2024-11-13 20:14:37 +02:00
Dan Foreman-Mackey
9bb6366741 Allow more output storage types for some dot algorithms.
As reported in https://github.com/jax-ml/jax/issues/24794, there were some dot products that were resulting in an unnecessary conversion. This change makes the output storage type selection more flexible.

Fixes https://github.com/jax-ml/jax/issues/24794

PiperOrigin-RevId: 695694179
2024-11-12 05:31:50 -08:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from Python to C++ transfer APIs so that device_put works correctly in the presence of user-specified copy/donate options.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Dan Foreman-Mackey
52ad60521c Run dot algorithm tests with PJRT plugin. 2024-10-31 06:01:11 -04:00
Dan Foreman-Mackey
03854cfce4 Allow dot algorithms in default_matmul_precision config. 2024-10-29 10:48:21 -04:00
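Presumably this means the config flag now accepts a dot-algorithm name in addition to the classic precision strings; the exact spelling below is an assumption:
```
import jax

# Assumed spelling of the algorithm preset; classic values such as "bfloat16",
# "float32", and "highest" continue to work.
jax.config.update("jax_default_matmul_precision", "BF16_BF16_F32")
```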
Christos Perivolaropoulos
6235158582 Dot algorithms are now supported for all types; change the test to reflect this.
PiperOrigin-RevId: 689036316
2024-10-23 11:17:55 -07:00
Praveen Narayanan
ad1aff098d Respect dot algorithm spec on TPU backends.
PiperOrigin-RevId: 688274131
2024-10-21 14:30:48 -07:00
Dan Foreman-Mackey
5ed2f4ef1c Remove checks for jaxlib v0.4.33 in tests 2024-10-11 15:39:24 -04:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in practice, we found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
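A hedged sketch of the updated interface described above: the algorithm is passed via `precision` (here as a `DotAlgorithmPreset`), and `preferred_element_type` can still pick the output dtype; whether a given preset is supported remains backend-dependent:
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 64), dtype=jnp.float32)
y = jnp.ones((64, 32), dtype=jnp.float32)

# Inputs are cast to the algorithm's storage types as needed; the result is
# cast back to match the inputs unless preferred_element_type says otherwise.
out = lax.dot(x, y, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)
out32 = lax.dot(x, y, precision=lax.DotAlgorithmPreset.BF16_BF16_F32,
                preferred_element_type=jnp.float32)
```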
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33. 2024-10-04 11:39:05 -04:00