Skye Wanderman-Milne
1511439c01
Update version.py and CHANGELOG for jax 0.3.23 release
2022-10-12 11:09:15 -07:00
jax authors
58d516c49e
Merge pull request #12764 from midjourney:gda_sharding_property
...
PiperOrigin-RevId: 480628909
jax-v0.3.23
jax-v0.3.23-rc
0.3.23
2022-10-12 08:23:43 -07:00
Jack Gallagher
bd054f8197
add sharding property to GDA
...
should improve forward compatibility with Array
2022-10-11 21:45:15 -07:00
Peter Hawkins
9bb2c999d6
Reenable some tests disabled in the past because of an LLVM bug.
...
The issue no longer reproduces at head.
PiperOrigin-RevId: 480505525
2022-10-11 18:59:37 -07:00
jax authors
012398bdc7
Merge pull request #12761 from skye:colab_tpu_driver
...
PiperOrigin-RevId: 480499024
2022-10-11 18:13:01 -07:00
Skye Wanderman-Milne
e2aa939147
Update Colab TPU driver version
2022-10-11 17:49:10 -07:00
jax authors
6b459f5ebf
Merge pull request #12758 from skye:version
...
PiperOrigin-RevId: 480463351
2022-10-11 15:09:30 -07:00
Skye Wanderman-Milne
a6e0e77624
Update version.py, setup.py and CHANGELOG post jax 0.3.22 release
2022-10-11 21:37:53 +00:00
jax authors
1fb886ce43
Merge pull request #12749 from hawkinsp:tests
...
PiperOrigin-RevId: 480452347
2022-10-11 14:23:16 -07:00
Peter Hawkins
0d3277b5c3
Port more tests from jtu.cases_from_list to jtu.sample_product.
2022-10-11 21:06:08 +00:00
jax authors
7a323b7b9d
Merge pull request #12748 from ROCmSoftwarePlatform:fixedROCm_builds_typo
...
PiperOrigin-RevId: 480445232
2022-10-11 13:54:30 -07:00
jax authors
be3addf71e
Merge pull request #12738 from mattjj:issue12643
...
PiperOrigin-RevId: 480441842
2022-10-11 13:40:54 -07:00
Yash Katariya
335f45ebb2
Use _rewriting_take
and _chunk_iter
path during __getitem__
and __iter__
respectively when the Array is fully replicated
...
For example:
```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```
Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.
PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Matthew Johnson
b27acedf1f
add more info to pytree prefix key errors
...
fixes #12643
2022-10-11 12:34:03 -07:00
jax authors
41417eed6f
Merge pull request #12752 from skye:workspace
...
PiperOrigin-RevId: 480405935
jaxlib-v0.3.22
jax-v0.3.22
jax-v0.3.22-rc
0.3.22
2022-10-11 11:12:26 -07:00
Skye Wanderman-Milne
63be2201aa
Update WORKSPACE and setup.py for jaxlib 0.3.22 release
2022-10-11 17:55:05 +00:00
jax authors
af98d8b00f
Merge pull request #12746 from jakevdp:fix-changelog
...
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d
Remove deprecated functionality from jax.test_util
...
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Chao Chen
8c13142ae6
fixed build instructions typo
2022-10-11 07:30:49 -07:00
Adam Paszke
9c994985a4
Support MANUAL collectives in top-level xmaps
...
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.
PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
jax authors
bf7869707a
Merge pull request #12727 from jakevdp:scipy-linalg-types
...
PiperOrigin-RevId: 480340339
2022-10-11 06:34:13 -07:00
Jake VanderPlas
943129419c
changelog: add missing github commit links
2022-10-11 06:24:38 -07:00
Rishabh Kabra
e45df46a51
Clarify docs for fori_loop
, noting that negative or custom increments are not supported.
...
PiperOrigin-RevId: 480317277
2022-10-11 04:41:10 -07:00
jax authors
eade3683af
Merge pull request #12324 from gnecula:tf_pjit
...
PiperOrigin-RevId: 480310826
2022-10-11 04:01:12 -07:00
George Necula
9c879adb73
[jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
...
This implementation is for the case jax2tf.convert(pjit(f_jax)),
that is, the `pjit` appears at the top-level of the function to
be lowered.
2022-10-11 09:45:07 +02:00
Yash Katariya
ff17d3d9fe
Add support for calculating the device_assignment when there are no inputs to jit
and pjit
.
...
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3
Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
...
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731
Add weak_type attribute to Array
since it exists on DA (but doesn't exist on SDA).
...
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Jake VanderPlas
afe74b4710
[typing] add type annotations to jax.scipy.linalg
2022-10-10 16:54:29 -07:00
Kuangyuan Chen
ec5b1c93d7
Turn on cpp pjit py default
...
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Yash Katariya
76d8c08317
Fix the type annotation of return type of device_buffer
and device_buffers
which return ArrayImpl
instead of DeviceArray.
...
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Peter Hawkins
a3a2206d49
Fix compilation failure in lapack kernel under msan.
...
a_size wasn't defined, but it would only be caught under memory sanitizer.
PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b
Add input-output aliasing annotations for LAPACK calls on CPU.
...
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Yash Katariya
752c3ffcd9
Lift lambda x: x
to the top level so that we don't recompile on every invocation of process_allgather
.
...
PiperOrigin-RevId: 480155482
2022-10-10 12:51:17 -07:00
jax authors
90e9abe278
Merge pull request #12722 from hawkinsp:tests
...
PiperOrigin-RevId: 480149233
2022-10-10 12:22:52 -07:00
Peter Hawkins
2ba0396ddb
Add changes accidentally omitted from
...
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Tianjian Lu
34eb6ce36b
[sparse] BCSR fromdense and todense.
...
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
Peter Hawkins
c657449528
Copybara import of the project:
...
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Peter Hawkins
22cd50535b
Reapply: Use input-output aliasing for jaxlib GPU custom calls.
...
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.
PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
jax authors
707b07c1e9
Merge pull request #12697 from jakevdp:lax-slicing-types
...
PiperOrigin-RevId: 480131675
2022-10-10 11:22:35 -07:00
jax authors
ad4dcc4817
Merge pull request #12700 from jakevdp:scipy-types
...
PiperOrigin-RevId: 480131410
2022-10-10 11:15:46 -07:00
Jake VanderPlas
124021d720
[typing] add annotations to jax.scipy.ndimage
2022-10-10 09:11:13 -07:00
Jake VanderPlas
76e4a1d40f
[typing] add annotations to jax.scipy.fft
2022-10-09 05:18:45 -07:00
Jake VanderPlas
ae9f8eeb0c
[typing] annotate lax.slicing
2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7
Copybara import of the project:
...
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:
implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9
Merge pull request #12707 from mattjj:djax-slice-sick4
...
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245
implement bint arrays (opaque dtypes), add padding rules
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989
Make is_fully_replicated
and is_fully_addressble
a property rather than a method.
...
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
...
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d
make device_put(prngkeyarray, sharding) for Array
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00