20245 Commits

Author SHA1 Message Date
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
jax authors
51352fa05c fix matrix dimension and block shape.
PiperOrigin-RevId: 624988654
2024-04-15 09:39:31 -07:00
Yash Katariya
90401d51e9 Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
2024-04-15 09:19:40 -07:00
jax authors
b7293005af Merge pull request #20762 from j-towns:scatter-doc-correction
PiperOrigin-RevId: 624971136
2024-04-15 08:38:57 -07:00
Junwhan Ahn
ac1a53d8e4 Optimize _create_copy_plan by iterating over only the shards that are needed for materialization
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.

The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
2024-04-15 08:29:47 -07:00
jax authors
3a09404426 Merge pull request #20586 from superbobry:jaxlib
PiperOrigin-RevId: 624941598
2024-04-15 06:40:07 -07:00
Jamie Townsend
b2783120c0 Correct a name in ScatterDimensionNumbers docstring 2024-04-15 10:36:24 +00:00
jax authors
78c056f41d Update XLA dependency to use revision
6805a043c6.

PiperOrigin-RevId: 624809937
2024-04-14 19:58:38 -07:00
Yash Katariya
2c85ca6fec If callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
2024-04-14 14:35:57 -07:00
jax authors
4a6ee78f4f [XLA] Clear derived instruction's sharding only if shapes are incompatible.
When AlgebraicSimplifier calls `dot->SetupDerivedInstruction(new_lhs);` in HandleDot, lhs sharding was cleared when dot didn't have a sharding. With this CL, lhs preserves its sharding because the condition for clearing the sharding is narrowed down to only when shapes are incompatible.

Fixes https://github.com/google/jax/issues/20710

PiperOrigin-RevId: 624731930
2024-04-14 10:30:53 -07:00
jax authors
7d0ba76ede Update XLA dependency to use revision
2b8f670cf5.

PiperOrigin-RevId: 624612863
2024-04-13 20:57:14 -07:00
Yash Katariya
5ce7dca969 Add support for loading checkpoints with a given layout to the array serialization library
PiperOrigin-RevId: 624596358
2024-04-13 19:35:50 -07:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Yash Katariya
70dca30395 Remove the dead code now that jax.Array is the only array we have
PiperOrigin-RevId: 624390245
2024-04-12 21:41:42 -07:00
jax authors
ee8ce0f81f Update XLA dependency to use revision
9ca54f324c.

PiperOrigin-RevId: 624386564
2024-04-12 21:17:23 -07:00
Yash Katariya
9e989321f1 Make sure we don't return GSPMDSharding in compiled.input_shardings
PiperOrigin-RevId: 624343180
2024-04-12 17:52:44 -07:00
Roy Frostig
09415607bb fix up extend:core build rule
We want `pytype_strict_library` here.

PiperOrigin-RevId: 624337356
2024-04-12 17:31:10 -07:00
jax authors
82504194f7 Merge pull request #20718 from kkiningh:patch-3
PiperOrigin-RevId: 624336436
2024-04-12 17:22:24 -07:00
jax authors
2948a80c78 Merge pull request #20217 from froystig:jex-primitives
PiperOrigin-RevId: 624323399
2024-04-12 16:26:09 -07:00
Yash Katariya
001732086b Use _internal_device_list in _get_device so that all places accessing _get_device get a speedup.
PiperOrigin-RevId: 624320655
2024-04-12 16:17:34 -07:00
jax authors
198ce52bf7 Merge pull request #20736 from jakevdp:doc-strict-promotion
PiperOrigin-RevId: 624319591
2024-04-12 16:08:45 -07:00
jax authors
1c5808a22f Merge pull request #20595 from google:refs-in-vjps
PiperOrigin-RevId: 624302083
2024-04-12 15:08:58 -07:00
jax authors
be44d83764 [Pallas TPU] Raise clearer NotImplementedError on vector -> scalar reductions.
Also adds tests with examples of jnp.sum() and jnp.max(), which should help provide a reference.

PiperOrigin-RevId: 624300956
2024-04-12 15:00:07 -07:00
jax authors
79dd75c6b7 Merge pull request #20727 from mattjj:make-jaxpr-docstring-fix
PiperOrigin-RevId: 624298977
2024-04-12 14:51:10 -07:00
Matthew Johnson
c33126f45a fix 2024-04-12 14:25:38 -07:00
Junwhan Ahn
3245455900 Optimize _create_copy_plan in array.py
* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible.
* `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop.

PiperOrigin-RevId: 624288637
2024-04-12 14:12:31 -07:00
Jake VanderPlas
2be17dc778 DOC: document strict dtype promotion mode 2024-04-12 14:05:05 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
jax authors
f581c6500b Merge pull request #20726 from mattjj:io-callback-while-batching-fix
PiperOrigin-RevId: 624274125
2024-04-12 13:21:24 -07:00
Jake VanderPlas
462e5c603a Finalize deprecation of invalid JIT argument names & numbers
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.

PiperOrigin-RevId: 624270701
2024-04-12 13:09:17 -07:00
Chi Zeng
9a89a0cee8 Let caller switch implementation of reduction after import
Thank you to gnecula@ for adding the jax2tf_associative_scan_reductions flag and context: 5bfe1852a4
For GPU, the specific implementation of `cumsum` can make the whopping difference between a latency in microseconds versus milliseconds!

Before this change, adjusting the method of lowering `cumsum` via this scope has no effect:

```py
with jax.jax2tf_associative_scan_reductions(True):
  ...
```

... because the cumsum method (and other reduce methods) have their implementations set when the `jax2tf` library is imported, ie when this line is called:

```py
from jax.experimental import jax2tf
```

Thus, any future switches of the implementation (to, say associative scanning), even if they happen before the `jax2tf.convert` method executes, had no effect because methods such as `cumsum` had already been curried at import time.

This change fixes that by varying the implementation based on the current value of `config.jax2tf_associative_scan_reductions`.

We use existing tests to verify the continued correctness of this CL that affects latency. We add TPU to the list of devices to apply some limitations - One TPU unit test had suddenly failed because the scope now works: Even though TPUs use a different path to lower by default, the context above explicitly sets to associative scanning.

PiperOrigin-RevId: 624264567
2024-04-12 12:47:59 -07:00
Matthew Johnson
8b691d15a8 fix cache key typo np.ndarray -> np.arange(...).reshape
still untested

issue introduced in cl/617093247 aka 0b28a4b

hopefully addresses google/jax#20681

PiperOrigin-RevId: 624259947
2024-04-12 12:30:31 -07:00
Dougal Maclaurin
f313a46916
Merge branch 'main' into refs-in-vjps 2024-04-12 15:25:37 -04:00
Dougal
29368e6a8e Add a zeros rule for mutable arrays and test it using a custom vjp.
add jit compatibility (have pjit jvp instantiate all ref tangents)

Co-authored-by: Matt Johnson <mattjj@google.com>
2024-04-12 15:22:07 -04:00
jax authors
a155c5a999 [Pallas] Global Barrier bug fix.
Each of left and right neighbors increment a core's sync flag
by one not two.

PiperOrigin-RevId: 624245970
2024-04-12 11:47:52 -07:00
jax authors
f6061cab89 Merge pull request #19605 from Micky774:cuda_errors
PiperOrigin-RevId: 624244248
2024-04-12 11:38:42 -07:00
jax authors
7fb27a4cce Merge pull request #19604 from Micky774:faq_cuda
PiperOrigin-RevId: 624234866
2024-04-12 11:09:12 -07:00
jax authors
2e5243605a Merge pull request #20730 from albanie:patch-1
PiperOrigin-RevId: 624221912
2024-04-12 10:39:55 -07:00
jax authors
4331abecff Merge pull request #20603 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 624221601
2024-04-12 10:30:01 -07:00
Samuel
959ecca182
Minor typo fix in docstring jax.lax.psum
Fix code formatting inconsistency in `psum` docstring

Currently, "device2" and "device3" are rendered incorrectly in the JAX documentation (see second example [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.psum.html))
2024-04-12 13:12:00 +01:00
Sergei Lebedev
386be2d307 Pallas TPU now accepts compiler parameters only via mosaic=...
This mirrors a similar change in Pallas GPU, which uses triton=.

PiperOrigin-RevId: 624131518
2024-04-12 04:34:21 -07:00
Matthew Johnson
83a200a42f simple fix to make_jaxpr docstring
maybe it was accidentally copied from xla_computation before?
2024-04-11 21:50:51 -07:00
Matthew Johnson
4f90d365b2 [callbacks] allow unordered effects in batched while_loop if predicate is not batched 2024-04-11 21:43:59 -07:00
jax authors
3146c2a3f6 Merge pull request #20725 from mattjj:io-callback-batching-fix
PiperOrigin-RevId: 624039690
2024-04-11 21:32:18 -07:00
jax authors
e2139c7082 Update XLA dependency to use revision
c42c5e1f7f.

PiperOrigin-RevId: 624036912
2024-04-11 21:14:55 -07:00
Matthew Johnson
8037e7b08f [callbacks] io_callback batching rule accidentally called pure_callback 2024-04-11 20:45:46 -07:00
jax authors
9758348043 Merge pull request #20690 from jakevdp:initial-dep
PiperOrigin-RevId: 623999461
2024-04-11 18:17:59 -07:00
jax authors
f68b3b1d2e Merge pull request #20712 from jakevdp:key-reuse-impl-location
PiperOrigin-RevId: 623989088
2024-04-11 17:34:52 -07:00
jax authors
c64ddb74d0 Merge pull request #20723 from mattjj:jax-namedsharding-import
PiperOrigin-RevId: 623984896
2024-04-11 17:15:58 -07:00
Jake VanderPlas
dc2d8c13d0 [key reuse] call key reuse logic directly in dispatch 2024-04-11 17:08:32 -07:00