8437 Commits

Author SHA1 Message Date
jax authors
e034432872 Merge pull request #12513 from inoryy:patch-4
PiperOrigin-RevId: 476923412
2022-09-26 10:04:14 -07:00
lenamartens
27e3981d52 lowerable errors behind a config flag. 2022-09-26 17:34:27 +01:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import. 2022-09-26 17:14:09 +01:00
lenamartens
78ecc1442c Lowerable checks!! 2022-09-26 16:54:18 +01:00
jax authors
9c66569514 Merge pull request #12468 from LenaMartens:checkify-but-better
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
jax authors
ec15e83018 - Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Yash Katariya
7c85ca38f4 Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run.
PiperOrigin-RevId: 476497929
2022-09-23 17:31:33 -07:00
Yash Katariya
1fa0dda760 Return single device Arrays from .device_buffer and .device_buffers.
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
43bbce0cc6 Merge pull request #12486 from hawkinsp:debugging
PiperOrigin-RevId: 476445041
2022-09-23 13:09:26 -07:00
jax authors
737327a42d Merge pull request #12490 from mattjj:improve-leak-checker
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).

But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Yash Katariya
ecb27a9b24 Update the _check_special code to not use xla_shape since its deprecated and does not work with Array.
PiperOrigin-RevId: 476422732
2022-09-23 11:40:32 -07:00
jax authors
d078f3f5fc Merge pull request #12478 from sharadmv:sharding-docs
PiperOrigin-RevId: 476420315
2022-09-23 11:31:37 -07:00
Ke Wu
c823151771 Allow transpose axes to be negative to match (undocumented) NumPy behavior 2022-09-23 10:18:23 -07:00
Peter Hawkins
38fb8ed22f Fix copyright attribution for some newly added files.
PiperOrigin-RevId: 476390902
2022-09-23 09:32:47 -07:00
Yash Katariya
c8f55414fc Convert the devices in the Mesh constructor to a numpy array if its a list, tuple, etc.
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
a88c5ad789 Fix xla extension version test in debugging.py
The custom call partitioner callback was not present in version 94 but is present in version 95.
2022-09-23 10:53:50 -04:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
lenamartens
7078f81dd0 Checkify: misc improvements.
- err.throw == check_error(err) -> meaning they have the same behavior
  under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Peter Hawkins
eed327914e Improve documentation for unique_indices. 2022-09-23 09:11:15 -04:00
Tianjian Lu
67b7ae259f [sparse] Move _bcoo_nse to sparse util.
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
99d4d8b89a Update debugging docs to have sharding visualization 2022-09-22 19:42:36 -07:00
Sharad Vikram
805073f36a Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Jake VanderPlas
3d23592cf6 [array] full_like: only match sharding if shape==None 2022-09-22 15:28:59 -07:00
jax authors
dfdf00c2eb Merge pull request #12472 from google:sharadmv-patch-2
PiperOrigin-RevId: 476190259
2022-09-22 14:00:07 -07:00
jax authors
bc08381da3 Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Sharad Vikram
1a8a8a5586
Fix example in pjit docstring 2022-09-22 12:55:55 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1 Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Yash Katariya
a157982e8c Make jit(f).lower(*args) go via lower_sharding_computation when jax_array is enabled.
PiperOrigin-RevId: 476148608
2022-09-22 11:13:33 -07:00
Tres Popp
640e15fe07 Don't tuple arguments passed to XLA:CPU
This is not needed and tuples are being avoided when possible for new code.

This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu

PiperOrigin-RevId: 476032228
2022-09-22 01:29:14 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
52476d1ab5 Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
jax authors
d41fa296f9 Merge pull request #12370 from jakevdp:lax-sort-overload
PiperOrigin-RevId: 475907384
2022-09-21 13:26:01 -07:00
Jake VanderPlas
2dde63334c [typing] add class-level declarations of Array members.
This fixes some pytype errors associated with the changes in #12421
2022-09-21 12:51:32 -07:00
George Karpenkov
541aadcfe8 [XLA:GPU] Allow simplifying lowering-precision-conversions by default
This might lead to the output having higher precision than specified by HLO.

PiperOrigin-RevId: 475889141
2022-09-21 12:04:45 -07:00
jax authors
c7f2712e74 Flip default value of jax_unique_mhlo_module_names to False.
This should help avoid unnecessary cache misses.

PiperOrigin-RevId: 475852954
2022-09-21 09:48:01 -07:00
jax authors
310bcd57a2 Merge pull request #12389 from LenaMartens:check-while-2
PiperOrigin-RevId: 475606527
2022-09-20 11:23:42 -07:00
lenamartens
018e700ead Checkify: support batched while. 2022-09-20 17:59:46 +01:00
jax authors
e855a9c458 Merge pull request #12428 from jakevdp:tracer-methods
PiperOrigin-RevId: 475580916
2022-09-20 09:52:56 -07:00
Jake VanderPlas
74698048f3 Tracer: add missing __round__ and __reversed__ methods 2022-09-20 09:09:23 -07:00
Yash Katariya
fc2902c6ac Make the gda and xmap sharding check work generally by checking the OpSharding protos.
PiperOrigin-RevId: 475560097
2022-09-20 08:24:47 -07:00
Sharad Vikram
0276a6e77c Add support for pmap sharding 2022-09-19 19:29:44 -07:00
Sharad Vikram
f825a3c8c0 Limit console width for visualize_sharding 2022-09-19 18:41:45 -07:00
Yash Katariya
e41e8d9a8f Only copy_to_device if the indices match. Otherwise reshard the array if its uncommitted. This is important where you have 1 process per device.
PiperOrigin-RevId: 475418561
2022-09-19 16:59:14 -07:00
jax authors
441f400358 Merge pull request #12386 from sharadmv:viz_sharding
PiperOrigin-RevId: 475387460
2022-09-19 14:36:21 -07:00
Jake VanderPlas
7ffe16b9f9 [typing] overloaded type declaration for variadic lax.sort 2022-09-19 13:40:28 -07:00
Sharad Vikram
2d8b228706 Add function to visualize Shardings 2022-09-19 13:27:08 -07:00
Yash Katariya
a24726d57c Remove fast_path_args from Array and add id checks to Sharding's __eq__ method as a fast shortcut.
Also the C++ pjit path should help optimize the dispatch path.

PiperOrigin-RevId: 475163903
2022-09-18 15:35:49 -07:00