jax authors
ec15e83018
- Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
...
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Yash Katariya
7c85ca38f4
Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run.
...
PiperOrigin-RevId: 476497929
2022-09-23 17:31:33 -07:00
Yash Katariya
1fa0dda760
Return single device Arrays from .device_buffer
and .device_buffers
.
...
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
43bbce0cc6
Merge pull request #12486 from hawkinsp:debugging
...
PiperOrigin-RevId: 476445041
2022-09-23 13:09:26 -07:00
jax authors
737327a42d
Merge pull request #12490 from mattjj:improve-leak-checker
...
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd
fix leak checker internal error
...
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).
But after #7345 , the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Yash Katariya
ecb27a9b24
Update the _check_special
code to not use xla_shape since its deprecated and does not work with Array.
...
PiperOrigin-RevId: 476422732
2022-09-23 11:40:32 -07:00
jax authors
d078f3f5fc
Merge pull request #12478 from sharadmv:sharding-docs
...
PiperOrigin-RevId: 476420315
2022-09-23 11:31:37 -07:00
Ke Wu
c823151771
Allow transpose axes to be negative to match (undocumented) NumPy behavior
2022-09-23 10:18:23 -07:00
Peter Hawkins
38fb8ed22f
Fix copyright attribution for some newly added files.
...
PiperOrigin-RevId: 476390902
2022-09-23 09:32:47 -07:00
Yash Katariya
c8f55414fc
Convert the devices in the Mesh
constructor to a numpy array if its a list, tuple, etc.
...
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
a88c5ad789
Fix xla extension version test in debugging.py
...
The custom call partitioner callback was not present in version 94 but is present in version 95.
2022-09-23 10:53:50 -04:00
jax authors
254dc24a8b
Merge pull request #11961 from jakeh-gc:plugin_device
...
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
eed327914e
Improve documentation for unique_indices.
2022-09-23 09:11:15 -04:00
Tianjian Lu
67b7ae259f
[sparse] Move _bcoo_nse
to sparse util.
...
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
99d4d8b89a
Update debugging docs to have sharding visualization
2022-09-22 19:42:36 -07:00
Sharad Vikram
805073f36a
Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
...
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Jake VanderPlas
3d23592cf6
[array] full_like: only match sharding if shape==None
2022-09-22 15:28:59 -07:00
jax authors
dfdf00c2eb
Merge pull request #12472 from google:sharadmv-patch-2
...
PiperOrigin-RevId: 476190259
2022-09-22 14:00:07 -07:00
jax authors
bc08381da3
Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
...
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Sharad Vikram
1a8a8a5586
Fix example in pjit
docstring
2022-09-22 12:55:55 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1
Add generic interface for auto initialization of distributed JAX service
...
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Yash Katariya
a157982e8c
Make jit(f).lower(*args)
go via lower_sharding_computation when jax_array
is enabled.
...
PiperOrigin-RevId: 476148608
2022-09-22 11:13:33 -07:00
Tres Popp
640e15fe07
Don't tuple arguments passed to XLA:CPU
...
This is not needed and tuples are being avoided when possible for new code.
This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu
PiperOrigin-RevId: 476032228
2022-09-22 01:29:14 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
52476d1ab5
Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
...
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
jax authors
d41fa296f9
Merge pull request #12370 from jakevdp:lax-sort-overload
...
PiperOrigin-RevId: 475907384
2022-09-21 13:26:01 -07:00
Jake VanderPlas
2dde63334c
[typing] add class-level declarations of Array members.
...
This fixes some pytype errors associated with the changes in #12421
2022-09-21 12:51:32 -07:00
George Karpenkov
541aadcfe8
[XLA:GPU] Allow simplifying lowering-precision-conversions by default
...
This might lead to the output having higher precision than specified by HLO.
PiperOrigin-RevId: 475889141
2022-09-21 12:04:45 -07:00
jax authors
c7f2712e74
Flip default value of jax_unique_mhlo_module_names to False.
...
This should help avoid unnecessary cache misses.
PiperOrigin-RevId: 475852954
2022-09-21 09:48:01 -07:00
jax authors
310bcd57a2
Merge pull request #12389 from LenaMartens:check-while-2
...
PiperOrigin-RevId: 475606527
2022-09-20 11:23:42 -07:00
lenamartens
018e700ead
Checkify: support batched while.
2022-09-20 17:59:46 +01:00
jax authors
e855a9c458
Merge pull request #12428 from jakevdp:tracer-methods
...
PiperOrigin-RevId: 475580916
2022-09-20 09:52:56 -07:00
Jake VanderPlas
74698048f3
Tracer: add missing __round__ and __reversed__ methods
2022-09-20 09:09:23 -07:00
Yash Katariya
fc2902c6ac
Make the gda and xmap sharding check work generally by checking the OpSharding protos.
...
PiperOrigin-RevId: 475560097
2022-09-20 08:24:47 -07:00
Sharad Vikram
0276a6e77c
Add support for pmap sharding
2022-09-19 19:29:44 -07:00
Sharad Vikram
f825a3c8c0
Limit console width for visualize_sharding
2022-09-19 18:41:45 -07:00
Yash Katariya
e41e8d9a8f
Only copy_to_device if the indices match. Otherwise reshard the array if its uncommitted. This is important where you have 1 process per device.
...
PiperOrigin-RevId: 475418561
2022-09-19 16:59:14 -07:00
jax authors
441f400358
Merge pull request #12386 from sharadmv:viz_sharding
...
PiperOrigin-RevId: 475387460
2022-09-19 14:36:21 -07:00
Jake VanderPlas
7ffe16b9f9
[typing] overloaded type declaration for variadic lax.sort
2022-09-19 13:40:28 -07:00
Sharad Vikram
2d8b228706
Add function to visualize Sharding
s
2022-09-19 13:27:08 -07:00
Yash Katariya
a24726d57c
Remove fast_path_args from Array and add id
checks to Sharding's __eq__
method as a fast shortcut.
...
Also the C++ pjit path should help optimize the dispatch path.
PiperOrigin-RevId: 475163903
2022-09-18 15:35:49 -07:00
Yash Katariya
9d8363a5d6
Fix the bug where the indices returned from _get_input_metadata
were of length equal to the length of global devices but should have been the length of local devices instead. shard_args
only deals with local devices and indices.
...
Also, enable multihost pjit_test.py with Array.
PiperOrigin-RevId: 475044692
2022-09-17 13:29:40 -07:00
Yash Katariya
590b5b5d7f
Add Array
counterparts to the serialization_test.py and disable the GDA tests if jax_array is enabled.
...
PiperOrigin-RevId: 474944400
2022-09-16 18:37:50 -07:00
Rebecca Chen
dce93e45bb
Silence some pytype errors.
...
PiperOrigin-RevId: 474898625
2022-09-16 14:11:58 -07:00
Yash Katariya
eec1b4a017
Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).
...
PiperOrigin-RevId: 474858389
2022-09-16 11:16:27 -07:00
Yash Katariya
e010ae7845
Pass device_assignment to ShardingContext instead of first_sharding which contains partitioning of an input too.
...
It does not make sense to pass how an input is partitioned to ShardingContext because you can have `n` inputs all partitioned in a different way but all of them should have the same device_assignment. This follows SPMDAxisContext too.
PiperOrigin-RevId: 474808207
2022-09-16 07:18:50 -07:00
Jake VanderPlas
a423dc7cd7
tests: fix is_valid_shape() function
2022-09-15 15:28:55 -07:00
jax authors
a7f4cb028c
Merge pull request #12143 from wonhyeongseo:multinomial
...
PiperOrigin-RevId: 474659934
2022-09-15 14:37:38 -07:00