13373 Commits

Author SHA1 Message Date
jax authors
58d516c49e Merge pull request #12764 from midjourney:gda_sharding_property
PiperOrigin-RevId: 480628909
jax-v0.3.23 jax-v0.3.23-rc 0.3.23
2022-10-12 08:23:43 -07:00
Jack Gallagher
bd054f8197
add sharding property to GDA
should improve forward compatibility with Array
2022-10-11 21:45:15 -07:00
Peter Hawkins
9bb2c999d6 Reenable some tests disabled in the past because of an LLVM bug.
The issue no longer reproduces at head.

PiperOrigin-RevId: 480505525
2022-10-11 18:59:37 -07:00
jax authors
012398bdc7 Merge pull request #12761 from skye:colab_tpu_driver
PiperOrigin-RevId: 480499024
2022-10-11 18:13:01 -07:00
Skye Wanderman-Milne
e2aa939147 Update Colab TPU driver version 2022-10-11 17:49:10 -07:00
jax authors
6b459f5ebf Merge pull request #12758 from skye:version
PiperOrigin-RevId: 480463351
2022-10-11 15:09:30 -07:00
Skye Wanderman-Milne
a6e0e77624 Update version.py, setup.py and CHANGELOG post jax 0.3.22 release 2022-10-11 21:37:53 +00:00
jax authors
1fb886ce43 Merge pull request #12749 from hawkinsp:tests
PiperOrigin-RevId: 480452347
2022-10-11 14:23:16 -07:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
jax authors
7a323b7b9d Merge pull request #12748 from ROCmSoftwarePlatform:fixedROCm_builds_typo
PiperOrigin-RevId: 480445232
2022-10-11 13:54:30 -07:00
jax authors
be3addf71e Merge pull request #12738 from mattjj:issue12643
PiperOrigin-RevId: 480441842
2022-10-11 13:40:54 -07:00
Yash Katariya
335f45ebb2 Use _rewriting_take and _chunk_iter path during __getitem__ and __iter__ respectively when the Array is fully replicated
For example:

```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```

Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.

PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Matthew Johnson
b27acedf1f add more info to pytree prefix key errors
fixes #12643
2022-10-11 12:34:03 -07:00
jax authors
41417eed6f Merge pull request #12752 from skye:workspace
PiperOrigin-RevId: 480405935
jaxlib-v0.3.22 jax-v0.3.22 jax-v0.3.22-rc 0.3.22
2022-10-11 11:12:26 -07:00
Skye Wanderman-Milne
63be2201aa Update WORKSPACE and setup.py for jaxlib 0.3.22 release 2022-10-11 17:55:05 +00:00
jax authors
af98d8b00f Merge pull request #12746 from jakevdp:fix-changelog
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d Remove deprecated functionality from jax.test_util
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Chao Chen
8c13142ae6 fixed build instructions typo 2022-10-11 07:30:49 -07:00
Adam Paszke
9c994985a4 Support MANUAL collectives in top-level xmaps
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.

PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
jax authors
bf7869707a Merge pull request #12727 from jakevdp:scipy-linalg-types
PiperOrigin-RevId: 480340339
2022-10-11 06:34:13 -07:00
Jake VanderPlas
943129419c changelog: add missing github commit links 2022-10-11 06:24:38 -07:00
Rishabh Kabra
e45df46a51 Clarify docs for fori_loop, noting that negative or custom increments are not supported.
PiperOrigin-RevId: 480317277
2022-10-11 04:41:10 -07:00
jax authors
eade3683af Merge pull request #12324 from gnecula:tf_pjit
PiperOrigin-RevId: 480310826
2022-10-11 04:01:12 -07:00
George Necula
9c879adb73 [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
This implementation is for the case jax2tf.convert(pjit(f_jax)),
that is, the `pjit` appears at the top-level of the function to
be lowered.
2022-10-11 09:45:07 +02:00
Yash Katariya
ff17d3d9fe Add support for calculating the device_assignment when there are no inputs to jit and pjit.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).

Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731 Add weak_type attribute to Array since it exists on DA (but doesn't exist on SDA).
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Jake VanderPlas
afe74b4710 [typing] add type annotations to jax.scipy.linalg 2022-10-10 16:54:29 -07:00
Kuangyuan Chen
ec5b1c93d7 Turn on cpp pjit py default
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Yash Katariya
76d8c08317 Fix the type annotation of return type of device_buffer and device_buffers which return ArrayImpl instead of DeviceArray.
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Peter Hawkins
a3a2206d49 Fix compilation failure in lapack kernel under msan.
a_size wasn't defined, but it would only be caught under memory sanitizer.

PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Yash Katariya
752c3ffcd9 Lift lambda x: x to the top level so that we don't recompile on every invocation of process_allgather.
PiperOrigin-RevId: 480155482
2022-10-10 12:51:17 -07:00
jax authors
90e9abe278 Merge pull request #12722 from hawkinsp:tests
PiperOrigin-RevId: 480149233
2022-10-10 12:22:52 -07:00
Peter Hawkins
2ba0396ddb Add changes accidentally omitted from
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Tianjian Lu
34eb6ce36b [sparse] BCSR fromdense and todense.
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Peter Hawkins
22cd50535b Reapply: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
jax authors
707b07c1e9 Merge pull request #12697 from jakevdp:lax-slicing-types
PiperOrigin-RevId: 480131675
2022-10-10 11:22:35 -07:00
jax authors
ad4dcc4817 Merge pull request #12700 from jakevdp:scipy-types
PiperOrigin-RevId: 480131410
2022-10-10 11:15:46 -07:00
Jake VanderPlas
124021d720 [typing] add annotations to jax.scipy.ndimage 2022-10-10 09:11:13 -07:00
Jake VanderPlas
76e4a1d40f [typing] add annotations to jax.scipy.fft 2022-10-09 05:18:45 -07:00
Jake VanderPlas
ae9f8eeb0c [typing] annotate lax.slicing 2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9 Merge pull request #12707 from mattjj:djax-slice-sick4
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989 Make is_fully_replicated and is_fully_addressble a property rather than a method.
Why?

1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)

2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.

PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47 Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d make device_put(prngkeyarray, sharding) for Array
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00