jax authors
58d516c49e
Merge pull request #12764 from midjourney:gda_sharding_property
...
PiperOrigin-RevId: 480628909
2022-10-12 08:23:43 -07:00
Jack Gallagher
bd054f8197
add sharding property to GDA
...
should improve forward compatibility with Array
2022-10-11 21:45:15 -07:00
Skye Wanderman-Milne
e2aa939147
Update Colab TPU driver version
2022-10-11 17:49:10 -07:00
Skye Wanderman-Milne
a6e0e77624
Update version.py, setup.py and CHANGELOG post jax 0.3.22 release
2022-10-11 21:37:53 +00:00
jax authors
1fb886ce43
Merge pull request #12749 from hawkinsp:tests
...
PiperOrigin-RevId: 480452347
2022-10-11 14:23:16 -07:00
Peter Hawkins
0d3277b5c3
Port more tests from jtu.cases_from_list to jtu.sample_product.
2022-10-11 21:06:08 +00:00
jax authors
be3addf71e
Merge pull request #12738 from mattjj:issue12643
...
PiperOrigin-RevId: 480441842
2022-10-11 13:40:54 -07:00
Yash Katariya
335f45ebb2
Use _rewriting_take
and _chunk_iter
path during __getitem__
and __iter__
respectively when the Array is fully replicated
...
For example:
```
k1, k2 = jax.random.split(key, 2) # where key is fully replicated on 8 devices
```
Then `k1` and `k2` should also maintain the sharding of `key` since `key` is fully replicated.
PiperOrigin-RevId: 480434272
2022-10-11 13:09:33 -07:00
Matthew Johnson
b27acedf1f
add more info to pytree prefix key errors
...
fixes #12643
2022-10-11 12:34:03 -07:00
Jake VanderPlas
05faf0f40d
Remove deprecated functionality from jax.test_util
...
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Adam Paszke
9c994985a4
Support MANUAL collectives in top-level xmaps
...
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.
PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
jax authors
bf7869707a
Merge pull request #12727 from jakevdp:scipy-linalg-types
...
PiperOrigin-RevId: 480340339
2022-10-11 06:34:13 -07:00
Rishabh Kabra
e45df46a51
Clarify docs for fori_loop
, noting that negative or custom increments are not supported.
...
PiperOrigin-RevId: 480317277
2022-10-11 04:41:10 -07:00
George Necula
9c879adb73
[jax2tf] Implement jax2tf(pjit) for experimental_native_lowering
...
This implementation is for the case jax2tf.convert(pjit(f_jax)),
that is, the `pjit` appears at the top-level of the function to
be lowered.
2022-10-11 09:45:07 +02:00
Yash Katariya
ff17d3d9fe
Add support for calculating the device_assignment when there are no inputs to jit
and pjit
.
...
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3
Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
...
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731
Add weak_type attribute to Array
since it exists on DA (but doesn't exist on SDA).
...
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Jake VanderPlas
afe74b4710
[typing] add type annotations to jax.scipy.linalg
2022-10-10 16:54:29 -07:00
Kuangyuan Chen
ec5b1c93d7
Turn on cpp pjit py default
...
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Yash Katariya
76d8c08317
Fix the type annotation of return type of device_buffer
and device_buffers
which return ArrayImpl
instead of DeviceArray.
...
PiperOrigin-RevId: 480181798
2022-10-10 14:45:12 -07:00
Yash Katariya
752c3ffcd9
Lift lambda x: x
to the top level so that we don't recompile on every invocation of process_allgather
.
...
PiperOrigin-RevId: 480155482
2022-10-10 12:51:17 -07:00
Tianjian Lu
34eb6ce36b
[sparse] BCSR fromdense and todense.
...
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
jax authors
707b07c1e9
Merge pull request #12697 from jakevdp:lax-slicing-types
...
PiperOrigin-RevId: 480131675
2022-10-10 11:22:35 -07:00
Jake VanderPlas
124021d720
[typing] add annotations to jax.scipy.ndimage
2022-10-10 09:11:13 -07:00
Jake VanderPlas
76e4a1d40f
[typing] add annotations to jax.scipy.fft
2022-10-09 05:18:45 -07:00
Jake VanderPlas
ae9f8eeb0c
[typing] annotate lax.slicing
2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7
Copybara import of the project:
...
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:
implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9
Merge pull request #12707 from mattjj:djax-slice-sick4
...
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245
implement bint arrays (opaque dtypes), add padding rules
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989
Make is_fully_replicated
and is_fully_addressble
a property rather than a method.
...
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
...
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d
make device_put(prngkeyarray, sharding) for Array
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
jax authors
e8ba61d82b
Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
...
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
jax authors
58cd8376ee
Merge pull request #12675 from mattjj:device-put2
...
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
Jake VanderPlas
730f9930aa
Update scipy version in jax.scipy.fft
2022-10-07 09:57:39 -07:00
George Necula
fb2141fc3b
[jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations
...
Prior to this the user had to explicitly call core.dimension_as_value whenever
using a potentially polymorphic shape in the computation, e.g., x +
core.dimension_as_value(x.shape[0]). Furthermore, jnp.array(x.shape[0])
would fail.
Now, these operations are allowed implicitly,
and the user can call `jnp.array(x.shape[0])`.
This uses an internal extensibility mechanism called __jax_array__
that is experimental and probably not fully implemented.
2022-10-07 17:58:41 +03:00
George Necula
7c7c94c8dd
Expand support for __jax_array__ in jnp.array.
...
This relates to the long discussion in #4725 and #10065 .
2022-10-07 14:25:07 +03:00
jax authors
6c70e4dcaa
Merge pull request #12691 from mattjj:issue12688
...
PiperOrigin-RevId: 479507232
2022-10-07 00:10:18 -07:00
Matthew Johnson
076a7348d0
fix -O / PYTHONOPTIMIZE bug
...
fixes #12688
I'm not sure how to write test cases for PYTHONOPTIMIZE=1 (without growing our
whole test matrix), so I'm leaving this untested...
2022-10-06 23:15:22 -07:00
Tianjian Lu
7a825362fa
[sparse] Bug fix in _validate_bcsr.
...
PiperOrigin-RevId: 479452053
2022-10-06 17:28:52 -07:00
Matthew Johnson
bcca6fb57a
add test, small fixes
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Jake VanderPlas
9104536a98
[typing] add type annotations to lax.linalg functions
2022-10-06 16:19:00 -07:00
Matthew Johnson
ce95ebad94
make device_put work with Sharding 2nd arg
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:14:15 -07:00
Parker Schuh
f49d3d441d
Rename Executable to LoadedExecutable within jax.
...
PiperOrigin-RevId: 479423951
2022-10-06 15:14:33 -07:00
jax authors
df4f399b33
Merge pull request #12678 from jakevdp:average-axis-tuple
...
PiperOrigin-RevId: 479414002
2022-10-06 14:38:13 -07:00
jax authors
ba2d8035dc
Merge pull request #12674 from jakevdp:svd-type
...
PiperOrigin-RevId: 479413950
2022-10-06 14:31:33 -07:00
Yash Katariya
d174b3dce3
Take shardings
as a parameter to deserialize
and run_deserialization
instead of mesh
and pspecs
.
...
PiperOrigin-RevId: 479346552
2022-10-06 10:20:49 -07:00
Jake VanderPlas
32ef3ba37b
jnp.average: support tuple axis
2022-10-06 10:20:46 -07:00
Jake VanderPlas
d94327c9e9
Move promote_like_jnp to jax.test_util
2022-10-06 10:20:26 -07:00
Roy Frostig
1b5e9e45ae
add input/output sharding to executable protocol
2022-10-05 17:30:56 -07:00