jax authors
e034432872
Merge pull request #12513 from inoryy:patch-4
...
PiperOrigin-RevId: 476923412
2022-09-26 10:04:14 -07:00
lenamartens
27e3981d52
lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import.
2022-09-26 17:14:09 +01:00
lenamartens
78ecc1442c
Lowerable checks!!
2022-09-26 16:54:18 +01:00
jax authors
9c66569514
Merge pull request #12468 from LenaMartens:checkify-but-better
...
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
Jake VanderPlas
0cb233eec9
Add initial jax.Array base class for instance checks & annotation
2022-09-26 07:48:43 -07:00
jax authors
ec15e83018
- Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
...
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Yash Katariya
7c85ca38f4
Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run.
...
PiperOrigin-RevId: 476497929
2022-09-23 17:31:33 -07:00
Yash Katariya
1fa0dda760
Return single device Arrays from .device_buffer
and .device_buffers
.
...
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
43bbce0cc6
Merge pull request #12486 from hawkinsp:debugging
...
PiperOrigin-RevId: 476445041
2022-09-23 13:09:26 -07:00
jax authors
737327a42d
Merge pull request #12490 from mattjj:improve-leak-checker
...
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd
fix leak checker internal error
...
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).
But after #7345 , the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Yash Katariya
ecb27a9b24
Update the _check_special
code to not use xla_shape since its deprecated and does not work with Array.
...
PiperOrigin-RevId: 476422732
2022-09-23 11:40:32 -07:00
jax authors
d078f3f5fc
Merge pull request #12478 from sharadmv:sharding-docs
...
PiperOrigin-RevId: 476420315
2022-09-23 11:31:37 -07:00
Ke Wu
c823151771
Allow transpose axes to be negative to match (undocumented) NumPy behavior
2022-09-23 10:18:23 -07:00
Peter Hawkins
38fb8ed22f
Fix copyright attribution for some newly added files.
...
PiperOrigin-RevId: 476390902
2022-09-23 09:32:47 -07:00
Yash Katariya
c8f55414fc
Convert the devices in the Mesh
constructor to a numpy array if its a list, tuple, etc.
...
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
a88c5ad789
Fix xla extension version test in debugging.py
...
The custom call partitioner callback was not present in version 94 but is present in version 95.
2022-09-23 10:53:50 -04:00
jax authors
254dc24a8b
Merge pull request #11961 from jakeh-gc:plugin_device
...
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
lenamartens
7078f81dd0
Checkify: misc improvements.
...
- err.throw == check_error(err) -> meaning they have the same behavior
under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Peter Hawkins
eed327914e
Improve documentation for unique_indices.
2022-09-23 09:11:15 -04:00
Tianjian Lu
67b7ae259f
[sparse] Move _bcoo_nse
to sparse util.
...
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
99d4d8b89a
Update debugging docs to have sharding visualization
2022-09-22 19:42:36 -07:00
Sharad Vikram
805073f36a
Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
...
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Jake VanderPlas
3d23592cf6
[array] full_like: only match sharding if shape==None
2022-09-22 15:28:59 -07:00
jax authors
dfdf00c2eb
Merge pull request #12472 from google:sharadmv-patch-2
...
PiperOrigin-RevId: 476190259
2022-09-22 14:00:07 -07:00
jax authors
bc08381da3
Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
...
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Sharad Vikram
1a8a8a5586
Fix example in pjit
docstring
2022-09-22 12:55:55 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1
Add generic interface for auto initialization of distributed JAX service
...
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Yash Katariya
a157982e8c
Make jit(f).lower(*args)
go via lower_sharding_computation when jax_array
is enabled.
...
PiperOrigin-RevId: 476148608
2022-09-22 11:13:33 -07:00
Tres Popp
640e15fe07
Don't tuple arguments passed to XLA:CPU
...
This is not needed and tuples are being avoided when possible for new code.
This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu
PiperOrigin-RevId: 476032228
2022-09-22 01:29:14 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
52476d1ab5
Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
...
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
jax authors
d41fa296f9
Merge pull request #12370 from jakevdp:lax-sort-overload
...
PiperOrigin-RevId: 475907384
2022-09-21 13:26:01 -07:00
Jake VanderPlas
2dde63334c
[typing] add class-level declarations of Array members.
...
This fixes some pytype errors associated with the changes in #12421
2022-09-21 12:51:32 -07:00
George Karpenkov
541aadcfe8
[XLA:GPU] Allow simplifying lowering-precision-conversions by default
...
This might lead to the output having higher precision than specified by HLO.
PiperOrigin-RevId: 475889141
2022-09-21 12:04:45 -07:00
jax authors
c7f2712e74
Flip default value of jax_unique_mhlo_module_names to False.
...
This should help avoid unnecessary cache misses.
PiperOrigin-RevId: 475852954
2022-09-21 09:48:01 -07:00
jax authors
310bcd57a2
Merge pull request #12389 from LenaMartens:check-while-2
...
PiperOrigin-RevId: 475606527
2022-09-20 11:23:42 -07:00
lenamartens
018e700ead
Checkify: support batched while.
2022-09-20 17:59:46 +01:00
jax authors
e855a9c458
Merge pull request #12428 from jakevdp:tracer-methods
...
PiperOrigin-RevId: 475580916
2022-09-20 09:52:56 -07:00
Jake VanderPlas
74698048f3
Tracer: add missing __round__ and __reversed__ methods
2022-09-20 09:09:23 -07:00
Yash Katariya
fc2902c6ac
Make the gda and xmap sharding check work generally by checking the OpSharding protos.
...
PiperOrigin-RevId: 475560097
2022-09-20 08:24:47 -07:00
Sharad Vikram
0276a6e77c
Add support for pmap sharding
2022-09-19 19:29:44 -07:00
Sharad Vikram
f825a3c8c0
Limit console width for visualize_sharding
2022-09-19 18:41:45 -07:00
Yash Katariya
e41e8d9a8f
Only copy_to_device if the indices match. Otherwise reshard the array if its uncommitted. This is important where you have 1 process per device.
...
PiperOrigin-RevId: 475418561
2022-09-19 16:59:14 -07:00
jax authors
441f400358
Merge pull request #12386 from sharadmv:viz_sharding
...
PiperOrigin-RevId: 475387460
2022-09-19 14:36:21 -07:00
Jake VanderPlas
7ffe16b9f9
[typing] overloaded type declaration for variadic lax.sort
2022-09-19 13:40:28 -07:00
Sharad Vikram
2d8b228706
Add function to visualize Sharding
s
2022-09-19 13:27:08 -07:00
Yash Katariya
a24726d57c
Remove fast_path_args from Array and add id
checks to Sharding's __eq__
method as a fast shortcut.
...
Also the C++ pjit path should help optimize the dispatch path.
PiperOrigin-RevId: 475163903
2022-09-18 15:35:49 -07:00