13610 Commits

Author SHA1 Message Date
jax authors
42724ebc73 Merge pull request #22222 from mattjj:xla-computation-deprecation
PiperOrigin-RevId: 648719949
2024-07-02 08:03:23 -07:00
jax authors
92ebb533bd Merge pull request #22181 from jakevdp:xla-abbrevs
PiperOrigin-RevId: 648701764
2024-07-02 06:45:51 -07:00
Matthew Johnson
bd166e1d99 add more info to xla_computation deprecation warning 2024-07-02 13:31:07 +00:00
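For context, a hedged sketch of the migration the deprecation warning points toward: `jax.xla_computation(f)(*args)` can generally be replaced with the AOT lowering APIs (the function `f` below is purely illustrative).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# Roughly what jax.xla_computation(f)(jnp.ones(4)) used to produce:
lowered = jax.jit(f).lower(jnp.ones(4))
print(lowered.as_text())  # StableHLO text; lowered.compiler_ir("hlo") gives HLO
```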
Sergei Lebedev
da76ebf095 Removed inspect.Signature hackery from `pl.BlockSpec`
I realized it is unnecessary: it is no different from listing the parameters
in `__init__` with relaxed types (to allow the old argument order).

PiperOrigin-RevId: 648696510
2024-07-02 06:21:59 -07:00
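A minimal sketch (not the actual implementation) of what "listing the parameters in `__init__` with relaxed types" can look like: dispatch on which positional argument is callable, so both orders keep working.

```python
class BlockSpecSketch:
    """Toy stand-in for pl.BlockSpec, for illustration only."""

    def __init__(self, block_shape=None, index_map=None):
        if callable(block_shape):
            # Old argument order: BlockSpec(index_map, block_shape).
            block_shape, index_map = index_map, block_shape
        self.block_shape = block_shape
        self.index_map = index_map

# Both argument orders construct the same spec:
a = BlockSpecSketch((42, 24), lambda i, j: (i, j))
b = BlockSpecSketch(lambda i, j: (i, j), (42, 24))
assert a.block_shape == b.block_shape == (42, 24)
```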
Sergei Lebedev
8bc11138ad Migrated `pl.BlockSpec` uses in JAX to the new argument order
See #22209.

PiperOrigin-RevId: 648681171
2024-07-02 05:16:54 -07:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
Peter Hawkins
063581a374 Small optimization when forming IR constant.
It's faster and simpler to write `0 in x.strides` than `np.any(np.equal(0, x.strides))`.

No behavior changes intended.

PiperOrigin-RevId: 648539396
2024-07-01 17:53:36 -07:00
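A minimal check of the equivalence (broadcasting is one common way a zero stride arises):

```python
import numpy as np

x = np.broadcast_to(np.arange(3.0), (4, 3))  # x.strides == (0, 8)
assert (0 in x.strides) == bool(np.any(np.equal(0, x.strides)))
```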
jax authors
db13e6fc0e Merge pull request #22119 from dfm:cond-linear
PiperOrigin-RevId: 648535400
2024-07-01 17:36:59 -07:00
jax authors
68e9683040 Merge pull request #22167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 648535148
2024-07-01 17:33:37 -07:00
jax authors
b669ab7bb1 Merge pull request #21925 from dfm:ffi-call
PiperOrigin-RevId: 648532673
2024-07-01 17:24:10 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec` to accept `block_shape` before `index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Yash Katariya
94ba6c3f98 Run single-controller device_put via an efficient reshard when the device set of the input and of the sharding are the same. The transfer_guard in the test fails before this CL.
PiperOrigin-RevId: 648464214
2024-07-01 13:14:43 -07:00
George Necula
9def0f1c00 [pallas] Add limited support for shape polymorphism for TPU
The main change is to pass the `result_shapes` to the
hlo.CustomCallOp when the output shapes contain dimension
variables. Everything else is already handled by the
support for dynamic bounds sizes for TPU.

Note that this CL only adds limited support for shape
polymorphism: only on TPU, and only when the block
sizes are static.
PiperOrigin-RevId: 648409699
2024-07-01 10:18:06 -07:00
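A hedged sketch of what this enables, using a trivial Pallas kernel and the `jax.export` APIs; the leading dimension is symbolic while the block sizes stay static, as required, and lowering needs a TPU backend.

```python
import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2

def double(x):
    return pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 8,),  # the grid may involve dimension variables...
        in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],  # ...blocks stay static
        out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
    )(x)

# Export with a symbolic (shape-polymorphic) leading dimension.
args = jax.ShapeDtypeStruct(export.symbolic_shape("8*b, 128"), jnp.float32)
exported = export.export(jax.jit(double))(args)
```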
Peter Hawkins
55589fbf41 Don't use the trace context in the prune_closed_jaxpr_outputs cache.
This code only manipulates jaxprs, and does not trace anything.

PiperOrigin-RevId: 648398046
2024-07-01 09:41:37 -07:00
Justin Fu
9653f58fa2 [Pallas] Add vmap/batched key support to to_pallas_key. This is helpful for workflows where the key is split before being passed into the kernel.
PiperOrigin-RevId: 648381795
2024-07-01 08:46:18 -07:00
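A hedged sketch of the batched usage this enables; the import path for `to_pallas_key` is an assumption (it ships with the Pallas TPU PRNG utilities).

```python
import jax
# Assumed import path, for illustration only:
from jax.experimental.pallas.ops.tpu.random import to_pallas_key

keys = jax.random.split(jax.random.key(0), num=8)  # split first...
pallas_keys = jax.vmap(to_pallas_key)(keys)        # ...then convert, now batched
```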
jax authors
7cd5b5c65e Merge pull request #22132 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 648362572
2024-07-01 07:30:10 -07:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
rajasekharporeddy
6713e9dc8e Improve docs for jnp.var and std 2024-07-01 19:42:08 +05:30
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
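A hedged sketch of what a call through this interface might look like; "rms_norm" is a hypothetical FFI target that would already have been registered, and the exact module path is an assumption.

```python
import numpy as np
import jax
import jax.extend as jex

def rms_norm(x, eps=1e-5):
    # As with pure_callback, a ShapeDtypeStruct declares the result type.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jex.ffi.ffi_call("rms_norm", out_type, x, eps=np.float32(eps))
```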
Andreas Steiner
a386af446c Avoid a different return value when running on a single host.
This PR reverts #22076.

Reverts 817eb7a9ee5343ba05df1faefbf41d2c0de2d31f

PiperOrigin-RevId: 648339674
2024-07-01 05:48:25 -07:00
rajasekharporeddy
9623dcfa95 Update jnp.mean doc 2024-07-01 12:07:52 +05:30
jax authors
5fac179f2f Merge pull request #22134 from gnecula:pallas_doc
PiperOrigin-RevId: 648147118
2024-06-30 09:15:16 -07:00
Yash Katariya
89c404e703 Improve the error message when a global jax.Array is closed over by a jitted function in McJAX.
PiperOrigin-RevId: 648010704
2024-06-29 14:36:44 -07:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00
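A small self-contained example in the spirit of the new documentation, showing a 1D grid with matching `BlockSpec`s (new argument order: `block_shape` first, then `index_map`):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):  # x: (256, 128)
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(2,),  # two invocations over the leading axis
        in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((128, 128), lambda i: (i, 0)),
    )(x)

print(add_one(jnp.zeros((256, 128))).sum())  # 32768.0
```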
jax authors
071da567fb Merge pull request #22187 from selamw1:array_equal_doc
PiperOrigin-RevId: 647803461
2024-06-28 14:47:00 -07:00
selamw1
74d59019b5 array_equal_docstring_added 2024-06-28 14:04:26 -07:00
Yash Katariya
75e7172e23 Simplify the ignored key for util.cache when trace_context_in_key is False
PiperOrigin-RevId: 647787828
2024-06-28 13:50:25 -07:00
Jake VanderPlas
251dfcad3c Deprecate jax.interpreters xb, xc, xe abbreviations.
Instead, import directly as jax.lib.xla_bridge, jax.lib.xla_client, jax.lib.xla_extension.
2024-06-28 12:38:43 -07:00
jax authors
8c889b50c0 Reverts 56526b46aad726ec4632fc3b18c79d26a8f399ef
PiperOrigin-RevId: 647759974
2024-06-28 12:13:01 -07:00
Yash Katariya
061f4df82a Make device_put work with inputs that are host-local when the sharding is global, i.e. a sharding spanning multiple hosts.
Use `multihost_utils.assert_equal` to check that the input is the same across all hosts.

Do some formatting fixes too ;)

PiperOrigin-RevId: 647711853
2024-06-28 09:46:49 -07:00
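A hedged multi-process sketch of the newly supported case: a host-local numpy input placed with a sharding that spans all hosts, where every process must pass the same value.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))  # devices across all hosts
global_sharding = NamedSharding(mesh, P("x"))
x = np.arange(len(jax.devices()) * 8.0).reshape(len(jax.devices()), 8)
y = jax.device_put(x, global_sharding)  # x must be identical on every host
```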
Jake VanderPlas
56526b46aa Rollback of #22157 due to internal breakage
Reverts d577e29998c8b9a3db5d835bbf9af8c13f453a01

PiperOrigin-RevId: 647709501
2024-06-28 09:36:06 -07:00
Yash Katariya
ba5b3c7941 Make lowering aware of compute_type so that we choose the correct lowering code.
For example, if you have 2 `lax.linalg.qr` calls (one on `TPU` and another on `device_host`), we should lower the `device_host` qr decomposition to CPU.

PiperOrigin-RevId: 647705015
2024-06-28 09:21:34 -07:00
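A hedged sketch of the scenario described, using the experimental `compute_on` API; the exact usage here is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on("device_host")
@jax.jit
def host_qr(x):
    return jnp.linalg.qr(x)[0]  # should pick up the host/CPU lowering

@jax.jit
def f(x):
    q_dev = jnp.linalg.qr(x)[0]  # lowered for the accelerator
    return q_dev + host_qr(x)

f(jnp.eye(4))
```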
Jake VanderPlas
fbcb157ad3 Finalize deprecation of several previously-deprecated jax.core functions:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
2024-06-28 07:28:28 -07:00
Adam Paszke
648b9519cf Be stricter about dtype handling in the splash_attention mask function
We previously took a logical_and of a mix of boolean and integer inputs, which isn't allowed
under some of the strict dtype modes. This has been causing some JAX tests to fail.

PiperOrigin-RevId: 647669850
2024-06-28 07:23:08 -07:00
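An illustration of the failure mode and the fix: under strict dtype promotion, implicitly combining boolean and integer operands is rejected, so the cast has to be explicit.

```python
import jax
import jax.numpy as jnp

with jax.numpy_dtype_promotion("strict"):
    mask = jnp.array([True, False, True])
    bits = jnp.array([1, 0, 1], dtype=jnp.int32)
    ok = jnp.logical_and(mask, bits.astype(bool))  # explicit bool cast
```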
jax authors
44071f8595 Merge pull request #22159 from jakevdp:arange-doc
PiperOrigin-RevId: 647669022
2024-06-28 07:18:53 -07:00
Adam Paszke
b19ad5b315 [Mosaic GPU] Add support for non-128B swizzles in WGMMA
PiperOrigin-RevId: 647667550
2024-06-28 07:12:10 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Adam Paszke
3ebebdfb76 [Mosaic GPU] Stop using nvgpu for TMA
It seems like the nvgpu dialect bakes in a bunch of overly restrictive checks in its verifiers
and doesn't really buy us much in this case. nvvm works just fine.

PiperOrigin-RevId: 647653684
2024-06-28 06:08:36 -07:00
jax authors
ad4c9ab85a [Mosaic GPU] Avoid failing when importing profiler.py even if lib.mosaic_gpu is unavailable.
PiperOrigin-RevId: 647626620
2024-06-28 04:04:01 -07:00
Adam Paszke
54edaf6e6c [Mosaic GPU] Add a warp specialized kernel with a separate TMA warp
With this kernel we're able to significantly improve the performance
of large head_dim kernels, reaching ~62% utilization for 4k sequence
length and ~71% for 32k.

TODO: the two kernels are quite similar and it should be possible to
collapse them into one
PiperOrigin-RevId: 647597865
2024-06-28 01:50:07 -07:00
George Necula
24b42eed5e [export] Clean up BUILD targets for jax.experimental.export
jax.experimental.export is deprecated and will be removed in a future version of JAX.

See migration guide at: https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export

PiperOrigin-RevId: 647562073
2024-06-27 23:08:48 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Mark Sandler
fdb1c14433 Switches make_array_from_callback to use batched_device_put
PiperOrigin-RevId: 647537267
2024-06-27 21:00:05 -07:00
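The user-facing API is unchanged (the change is an internal optimization); a usage example for reference:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))
shape = (len(jax.devices()) * 2, 4)
data = np.arange(np.prod(shape)).reshape(shape)
# Per-shard callbacks now feed a single batched_device_put internally.
arr = jax.make_array_from_callback(shape, sharding, lambda idx: data[idx])
```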
Peter Hawkins
ac3cb6f954 Simplify mlir.dense_int_array.
The NumPy array conversion here is pointless and slightly slower than not doing it.

PiperOrigin-RevId: 647520922
2024-06-27 19:33:06 -07:00
Jake VanderPlas
3d7c53afec Improved documentation for jnp.arange 2024-06-27 19:09:07 -07:00
Yash Katariya
e1a496d3b6 Add a concrete layout API to JAX. The API takes `major_to_minor: tuple[int, ...]` and `tiling: tuple[tuple[int, ...], ...]` as arguments, and allows users to pass layouts to `with_sharding_constraint` to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering-only detail for now (since we know the dtype of the aval, JAX can add the appropriate value). We can expose it in the user API if required.

Memory space is exposed via the JAX memories API, so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add constructors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
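A heavily hedged sketch of the described API; the module path and constructor names below are assumptions based on this message, not confirmed signatures.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding
# Assumed import path, for illustration only:
from jax.experimental.layout import DeviceLocalLayout, Layout

@jax.jit
def f(x):
    layout = Layout(DeviceLocalLayout(major_to_minor=(0, 1)),
                    SingleDeviceSharding(jax.devices()[0]))
    return jax.lax.with_sharding_constraint(x, layout)

f(jnp.ones((8, 128)))
```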
jax authors
d577e29998 Merge pull request #22157 from jakevdp:cudnn-version
PiperOrigin-RevId: 647467418
2024-06-27 15:32:53 -07:00
Peter Hawkins
61703690ee Add a direct HLO lowering of remat_p that doesn't call eval_jaxpr.
This turns out to be faster, not least because we don't need to use the tracing machinery again.

PiperOrigin-RevId: 647462045
2024-06-27 15:15:11 -07:00
Jake VanderPlas
337cefa7bb Bump minimum supported CUDNN version to 9.0
This should have been changed for the 0.4.30 release; updating this value
will lead to better errors when attempting to install with an older cudnn.
2024-06-27 14:56:51 -07:00