21660 Commits

Author SHA1 Message Date
Ruturaj4
332435e028 [ROCM] make mosaic dependency cuda specific 2024-07-02 11:05:42 -05:00
jax authors
42724ebc73 Merge pull request #22222 from mattjj:xla-computation-deprecation
PiperOrigin-RevId: 648719949
2024-07-02 08:03:23 -07:00
jax authors
2e0c100fef Merge pull request #22199 from gnecula:export_test_skip
PiperOrigin-RevId: 648705137
2024-07-02 07:01:26 -07:00
George Necula
cfa3c91c32 [export] Disable serialization in export_test if flatbuffers is not installed
This allows one to run most of export_test even if flatbuffers
is not installed. Only the serialization and deserialization are
skipped.
2024-07-02 15:46:38 +02:00
jax authors
92ebb533bd Merge pull request #22181 from jakevdp:xla-abbrevs
PiperOrigin-RevId: 648701764
2024-07-02 06:45:51 -07:00
Matthew Johnson
bd166e1d99 add more info to xla_computation deprecation warning 2024-07-02 13:31:07 +00:00
Sergei Lebedev
da76ebf095 Removed inspect.Signature hackery from `pl.BlockSpec`
I realized it is unnecessary and is no different than listing the parameters
in __init__ with relaxed types (to allow old argument order).

PiperOrigin-RevId: 648696510
2024-07-02 06:21:59 -07:00
jax authors
81fcc97e08 Merge pull request #21982 from ayaka14732:main
PiperOrigin-RevId: 648686086
2024-07-02 05:39:42 -07:00
Sergei Lebedev
8bc11138ad Migrated `pl.BlockSpec` uses in JAX to the new argument order
See #22209.

PiperOrigin-RevId: 648681171
2024-07-02 05:16:54 -07:00
George Necula
de0fd722f0 [pallas] Make pallas_test run on CPU and TPU also.
pallas_test was only running on GPU, but it is useful to run this test on all platform, in both interpret mode and the native mode. Added `skipTest` and `TODO` for the tests that fail, and in some cases configured numerical comparison tolerances.

All tests now have a "Interpreter" version, e.g., for `CallTest` we also define a `CallInterpreterTest` that runs the same tests but in interpreter
mode. This was not done systematically before, and in some cases the
interpreter test was missing, or was empty.

Some of the tests in pallas_test perhaps make sense only for GPU. I will
split them out in a separate CL.

PiperOrigin-RevId: 648619580
2024-07-02 00:40:59 -07:00
Peter Hawkins
2350a73f87 Use a class with __slots__ instead of a NamedTuple in JaxprEqn and SourceInfo, which are two tuples we build frequently.
Surprisingly this is faster. With Python 3.12:

```
In [1]: from typing import NamedTuple

In [2]: class C(NamedTuple):
   ...:     a: int
   ...:     b: int
   ...:     c: int
   ...:     d: int
   ...:     e: int
   ...:     f: int
   ...:     g: int
   ...:

In [3]: class D:
   ...:     __slots__ = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
   ...:     def __init__(self, a, b, c, d, e, f, g):
   ...:         self.a = a
   ...:         self.b = b
   ...:         self.c = c
   ...:         self.d = d
   ...:         self.e = e
   ...:         self.f = f
   ...:         self.g = g
   ...:

In [4]: %timeit D(1, 2, 3, 4, 5, 6, 7)
158 ns ± 0.458 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit C(1, 2, 3, 4, 5, 6, 7)
236 ns ± 0.498 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: %timeit D(1, 2, 3, 4, 5, 6, 7)
159 ns ± 2.13 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit C(1, 2, 3, 4, 5, 6, 7)
235 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```

No behavioral changes intended.

PiperOrigin-RevId: 648556436
2024-07-01 19:18:58 -07:00
Peter Hawkins
063581a374 Small optimization when forming IR constant.
It's faster and simpler to write `0 in x.strides` than `np.any(np.equal(0, x.strides))`.

No behavior changes intended.

PiperOrigin-RevId: 648539396
2024-07-01 17:53:36 -07:00
jax authors
db13e6fc0e Merge pull request #22119 from dfm:cond-linear
PiperOrigin-RevId: 648535400
2024-07-01 17:36:59 -07:00
jax authors
68e9683040 Merge pull request #22167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 648535148
2024-07-01 17:33:37 -07:00
jax authors
b669ab7bb1 Merge pull request #21925 from dfm:ffi-call
PiperOrigin-RevId: 648532673
2024-07-01 17:24:10 -07:00
jax authors
1eaaa10ad4 Add tests for some Pallas/Mosaic issues filed recently.
PiperOrigin-RevId: 648515147
2024-07-01 16:12:21 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Yash Katariya
94ba6c3f98 Run single controller device_put via efficient reshard if the device_set of input and the sharding is the same. The transfer_guard in the test fails before this CL.
PiperOrigin-RevId: 648464214
2024-07-01 13:14:43 -07:00
George Necula
9def0f1c00 [pallas] Add limited support for shape polymorphism for TPU
The main change is to pass the `result_shapes` to the
hlo.CustomCallOp when the output shapes contain dimension
variables. Everything else is already handled by the
support for dynamic bounds sizes for TPU.

Note that this CL only adds limited support for shape
polymorphism: only on TPU, and only when the block
sizes are static.
PiperOrigin-RevId: 648409699
2024-07-01 10:18:06 -07:00
Peter Hawkins
55589fbf41 Don't use the trace context in the prune_closed_jaxpr_outputs cache.
This code only manipulates jaxprs, and does not trace anything.

PiperOrigin-RevId: 648398046
2024-07-01 09:41:37 -07:00
Justin Fu
9653f58fa2 [Pallas] Add vmap/batched key support to to_pallas_key. This is helpful for workflows where the key is split before being passed into the kernel.
PiperOrigin-RevId: 648381795
2024-07-01 08:46:18 -07:00
jax authors
7cd5b5c65e Merge pull request #22132 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 648362572
2024-07-01 07:30:10 -07:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
rajasekharporeddy
6713e9dc8e Improve docs for jnp.var and std 2024-07-01 19:42:08 +05:30
jax authors
32a7071882 Merge pull request #22208 from superbobry:docs
PiperOrigin-RevId: 648353511
2024-07-01 06:52:02 -07:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Sergei Lebedev
e80632e6fd Revived the workaround for not-expanding type aliases
The version here only works for modules with
``from __future__ import annotations``, but we can safely add that import
to all modules now, since the minimal Python version JAX supports is 3.10.

The worakround was previously removed in #3485.
2024-07-01 14:31:53 +01:00
Andreas Steiner
a386af446c Avoid different return value when running on single host.
This PR reverts #22076.

Reverts 817eb7a9ee5343ba05df1faefbf41d2c0de2d31f

PiperOrigin-RevId: 648339674
2024-07-01 05:48:25 -07:00
Adam Paszke
727d120401 Bump up the shard_count for GPU FFT tests
They seem to be timing out with ASAN and no sharding.

PiperOrigin-RevId: 648301571
2024-07-01 02:58:23 -07:00
rajasekharporeddy
9623dcfa95 Update jnp.mean doc 2024-07-01 12:07:52 +05:30
jax authors
1949691daa Update XLA dependency to use revision
6fd83ac447.

PiperOrigin-RevId: 648165885
2024-06-30 11:59:08 -07:00
jax authors
5fac179f2f Merge pull request #22134 from gnecula:pallas_doc
PiperOrigin-RevId: 648147118
2024-06-30 09:15:16 -07:00
Yash Katariya
89c404e703 Improve error message when a global jax.Array is closed over a jitted function in McJAX.
PiperOrigin-RevId: 648010704
2024-06-29 14:36:44 -07:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
jax authors
2e83c32437 Update XLA dependency to use revision
46b9412435.

PiperOrigin-RevId: 647991056
2024-06-29 11:57:52 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00
jax authors
071da567fb Merge pull request #22187 from selamw1:array_equal_doc
PiperOrigin-RevId: 647803461
2024-06-28 14:47:00 -07:00
selamw1
74d59019b5 array_equal_doscstring_added 2024-06-28 14:04:26 -07:00
Yash Katariya
75e7172e23 Simplify the ignore key when trace_context_in_key is False for util.cache
PiperOrigin-RevId: 647787828
2024-06-28 13:50:25 -07:00
jax authors
167f68bcfc Merge pull request #22180 from jakevdp:api-compat
PiperOrigin-RevId: 647776272
2024-06-28 13:10:24 -07:00
Junwhan Ahn
1a91fe7677 Explicitly disallow duplicated devices during array construction
`jax.make_array_from_single_device_arrays` should not allow passing more than one array on the same device as that would lead to an invalid array. While some of this case is already detected by later checks (e.g., `ArrayImpl._check_and_rearrange`), this CL explicitly checks the device list before calling IFRT so that we don't create an invalid IFRT array to begin with.

PiperOrigin-RevId: 647772472
2024-06-28 12:56:43 -07:00
Eugene Zhulenev
de6339569d [jax] Add a test that runs reduction on host
Check that nested computations generated by offloaded computation are correctly outlined into the host module.

PiperOrigin-RevId: 647771541
2024-06-28 12:53:10 -07:00
Jake VanderPlas
671db54f44 doc: remove references to submodules that no longer exist 2024-06-28 12:39:14 -07:00
Jake VanderPlas
251dfcad3c Deprecate jax.interpreters xb, xc, xe abbreviations.
Instead, import directly as jax.lib.xla_bridge, jax.lib.xla_client, jax.lib.xla_extension.
2024-06-28 12:38:43 -07:00
jax authors
8c889b50c0 Reverts 56526b46aad726ec4632fc3b18c79d26a8f399ef
PiperOrigin-RevId: 647759974
2024-06-28 12:13:01 -07:00
jax authors
a256b44efb Update XLA dependency to use revision
bb661128c9.

PiperOrigin-RevId: 647758753
2024-06-28 12:08:39 -07:00
Dan Foreman-Mackey
9b33df6438 Update C++ registration of cu_lu_pivots_to_permutation to use XLA_FFI_REGISTER_HANDLER
PiperOrigin-RevId: 647734115
2024-06-28 10:53:33 -07:00
Yash Katariya
061f4df82a Make device_put work with inputs which are host local and the sharding is global sharding i.e. sharding spanning across multiple hosts.
Use `multihost_utils.assert_equal` to check if the input is the same across all hosts.

Do some formatting fixes too ;)

PiperOrigin-RevId: 647711853
2024-06-28 09:46:49 -07:00
Yash Katariya
ba88601b9c remove the cloud TPU disable of memories_test.py because everything should work now
PiperOrigin-RevId: 647711611
2024-06-28 09:43:25 -07:00
Jake VanderPlas
56526b46aa Rollback of #22157 due to internal breakage
Reverts d577e29998c8b9a3db5d835bbf9af8c13f453a01

PiperOrigin-RevId: 647709501
2024-06-28 09:36:06 -07:00