Adam Paszke
9c994985a4
Support MANUAL collectives in top-level xmaps
...
It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that
are nested inside pjits, but it doesn't hurt us to support it.
PiperOrigin-RevId: 480342531
2022-10-11 06:47:30 -07:00
Yash Katariya
ff17d3d9fe
Add support for calculating the device_assignment when there are no inputs to jit
and pjit
.
...
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796
2022-10-10 22:08:42 -07:00
Matthew Johnson
df5f7cb8d3
Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
...
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Yash Katariya
9b3e864731
Add weak_type attribute to Array
since it exists on DA (but doesn't exist on SDA).
...
PiperOrigin-RevId: 480223116
2022-10-10 18:11:11 -07:00
Kuangyuan Chen
ec5b1c93d7
Turn on cpp pjit py default
...
PiperOrigin-RevId: 480185387
2022-10-10 15:01:04 -07:00
Peter Hawkins
2ba0396ddb
Add changes accidentally omitted from
...
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Tianjian Lu
34eb6ce36b
[sparse] BCSR fromdense and todense.
...
PiperOrigin-RevId: 480141918
2022-10-10 11:54:22 -07:00
Peter Hawkins
c657449528
Copybara import of the project:
...
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
jax authors
9cabd227d7
Copybara import of the project:
...
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:
implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
jax authors
25c6ef7ff9
Merge pull request #12707 from mattjj:djax-slice-sick4
...
PiperOrigin-RevId: 479876971
2022-10-09 00:23:39 -07:00
Matthew Johnson
6d2aaac245
implement bint arrays (opaque dtypes), add padding rules
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Yash Katariya
75b2d05989
Make is_fully_replicated
and is_fully_addressble
a property rather than a method.
...
Why?
1. Because it's easy to cache a property than a method with only the `self` argument. (See below for article)
2. There's no harm in making them a property because both of them return a bool without any side-effects and are cached (so its fast). Why cache `is_fully_addressable`? Because its very expensive to calculate when you have 1000s of devices.
PiperOrigin-RevId: 479850850
2022-10-08 19:24:12 -07:00
jax authors
674038ca47
Merge pull request #12705 from mattjj:fix-prng-key-array-device-put
...
PiperOrigin-RevId: 479813689
2022-10-08 11:39:05 -07:00
Matthew Johnson
0a0f492a3d
make device_put(prngkeyarray, sharding) for Array
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-07 16:50:16 -07:00
jax authors
e8ba61d82b
Merge pull request #12677 from mattjj:jit-pjit-lower-sharding
...
PiperOrigin-RevId: 479669125
2022-10-07 14:28:51 -07:00
jax authors
58cd8376ee
Merge pull request #12675 from mattjj:device-put2
...
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
Jake VanderPlas
730f9930aa
Update scipy version in jax.scipy.fft
2022-10-07 09:57:39 -07:00
jax authors
fde444e735
Merge pull request #12695 from gnecula:tf_dimension
...
PiperOrigin-RevId: 479602588
2022-10-07 09:38:45 -07:00
jax authors
363cc124e3
Merge pull request #12197 from ROCmSoftwarePlatform:fixedRocmUnitTestsSkip
...
PiperOrigin-RevId: 479566021
2022-10-07 06:36:11 -07:00
George Necula
7c7c94c8dd
Expand support for __jax_array__ in jnp.array.
...
This relates to the long discussion in #4725 and #10065 .
2022-10-07 14:25:07 +03:00
Matthew Johnson
bcca6fb57a
add test, small fixes
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:45:34 -07:00
Matthew Johnson
ce95ebad94
make device_put work with Sharding 2nd arg
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:14:15 -07:00
jax authors
df4f399b33
Merge pull request #12678 from jakevdp:average-axis-tuple
...
PiperOrigin-RevId: 479414002
2022-10-06 14:38:13 -07:00
Peter Hawkins
8107e3600e
Switch lax_numpy_indexing_test to use jtu.sample_product.
2022-10-06 17:44:17 +00:00
Jake VanderPlas
32ef3ba37b
jnp.average: support tuple axis
2022-10-06 10:20:46 -07:00
Jake VanderPlas
d94327c9e9
Move promote_like_jnp to jax.test_util
2022-10-06 10:20:26 -07:00
Jake VanderPlas
ff0810998e
test: fix LaxNumpyTest:testConcatenate
2022-10-05 15:29:15 -07:00
Yash Katariya
9d12e216d5
Add addressable_shards
to SDA and DA as a compatibility API to match with jax.Array
. This will aid in transition to jax.Array
.
...
PiperOrigin-RevId: 479115126
2022-10-05 12:35:17 -07:00
jax authors
3ec5e456b8
Merge pull request #12660 from hawkinsp:testing
...
PiperOrigin-RevId: 479098519
2022-10-05 11:33:55 -07:00
Yash Katariya
dd0a455a78
Add sharding to DeviceArray
and ShardedDeviceArray
as a compatibility change to rollout jax.Array
.
...
Also expose `device_replica_id_map` since that is important API for checkpointing to find all the unique shards of an Array. You can also use this to calculate the unique indices in a sharding (which is what `sda.one_replica_buffer_indices` does)
PiperOrigin-RevId: 479072520
2022-10-05 09:58:41 -07:00
Peter Hawkins
2c946b3b56
Migrate api_test, lax_numpy_test, and lax_vmap_test to
...
jtu.sample_product.
Gives a ~2x improvement in pytest --collect-only timing for
lax_numpy_test.
2022-10-05 13:46:19 +00:00
Sharad Vikram
a60ca9f051
Test that array layout is preserved in Python callbacks
...
PiperOrigin-RevId: 478852392
2022-10-04 12:14:47 -07:00
Jake VanderPlas
0d9367972b
jax.jacobian: propagate function signature to transformed function
2022-10-04 10:21:54 -07:00
Tianjian Lu
ae49d2e033
[sparse] Add conversions between BCSR and BCOO.
...
PiperOrigin-RevId: 478816413
2022-10-04 10:00:16 -07:00
Yash Katariya
37f9db77f7
Create Array
s from __getitem__
and __iter__
. This is done by device_put
ting from the host to default device which is suboptimal. But there is a TODO to fix this!
...
PiperOrigin-RevId: 478691051
2022-10-03 22:29:03 -07:00
jax authors
682e86cd63
Merge pull request #12628 from hawkinsp:testing
...
PiperOrigin-RevId: 478650729
2022-10-03 17:50:38 -07:00
Peter Hawkins
c7e5d3dc95
Add an internal jtu.sample_product test decorator.
...
This decorator samples from a cartesian product of parameterized tests
without materializing the full product explicitly.
Update lax_test.py to use the new decorator.
On my desktop machine, this improves the timing for `pytest
--collect-only tests/lax_test.py` from 6.8s to 1.9s.
2022-10-04 00:39:22 +00:00
jax authors
4fb5da4daf
Merge pull request #12612 from ROCmSoftwarePlatform:rocm_53_enhancements
...
PiperOrigin-RevId: 478603111
2022-10-03 14:14:56 -07:00
jax authors
6318fdc17b
Merge pull request #12611 from mattjj:custom-vjp-improve-type-error-checking
...
PiperOrigin-RevId: 478570245
2022-10-03 12:12:01 -07:00
jax authors
09720b9bcb
Merge pull request #12607 from jakevdp:fix-pure-callback-batch
...
PiperOrigin-RevId: 478190633
2022-10-01 05:06:44 -07:00
Rohit Santhanam
b815ac9d8e
[ROCm] Upgrade to ROCm 5.3 and associated enhancements
2022-10-01 04:45:26 -07:00
Matthew Johnson
b8c87bc9de
improve custom_jvp/vjp error messages
...
In particular:
* add function names so it's clear what decorated functions and rules
are causing the error;
* when possible (because the functions were run), check for agreement of pytree
structure and leaf shapes/dtypes between the primal function and rules
context: https://github.com/lucidrains/flash-attention-jax/issues/7
2022-09-30 22:41:43 -07:00
Jake VanderPlas
439217644a
Split parts of lax_numpy_test.py into separate test files.
...
Why? The main test file is getting too big and this hinders iteration on individual tests
PiperOrigin-RevId: 478130215
2022-09-30 19:38:11 -07:00
Jake VanderPlas
1c55f265dd
pure_callback: fix batching rule for multiple arguments
2022-09-30 15:35:42 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
jax authors
8a1e0ed13f
Merge pull request #12594 from skye:cache_warnings
...
PiperOrigin-RevId: 478063392
2022-09-30 13:11:34 -07:00
Skye Wanderman-Milne
15e5f38a16
Make persistent compilation cache warn instead of raise an error on cache read/write failures
...
Fixes #12582 . Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
Yash Katariya
aafc77d3c0
Improve the checks done in Array
and apply them to all Sharding
s rather than just XLACompatibleSharding
.
...
Also check the symmetric difference of sharding and `_arrays` devices.
PiperOrigin-RevId: 478017409
2022-09-30 09:56:16 -07:00
Yash Katariya
9ff570e6c3
Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
...
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
3c7d927a2c
Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
...
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00