13165 Commits

Author SHA1 Message Date
jax authors
e034432872 Merge pull request #12513 from inoryy:patch-4
PiperOrigin-RevId: 476923412
jax-v0.3.18
2022-09-26 10:04:14 -07:00
jax authors
7962b01f5d Merge pull request #12485 from LenaMartens:checkify-lower
PiperOrigin-RevId: 476922387
2022-09-26 09:53:40 -07:00
lenamartens
27e3981d52 lowerable errors behind a config flag. 2022-09-26 17:34:27 +01:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import. 2022-09-26 17:14:09 +01:00
lenamartens
78ecc1442c Lowerable checks!! 2022-09-26 16:54:18 +01:00
jax authors
28672cca0e Merge pull request #12496 from mattjj:improve-leak-checker-2
PiperOrigin-RevId: 476907407
2022-09-26 08:50:13 -07:00
jax authors
9c66569514 Merge pull request #12468 from LenaMartens:checkify-but-better
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
jax authors
2df61b1aa1 Merge pull request #12421 from jakevdp:jax-array
PiperOrigin-RevId: 476898184
2022-09-26 08:07:11 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
jax authors
2180710c2a Merge pull request #12511 from hawkinsp:release
PiperOrigin-RevId: 476889960
jax-v0.3.18-rc
2022-09-26 07:24:44 -07:00
Peter Hawkins
bcd36d8eb2 Jax and jaxlib 0.3.18 release candidate. 2022-09-26 14:10:57 +00:00
jax authors
53de057748 Merge pull request #12510 from hawkinsp:context
PiperOrigin-RevId: 476884674
2022-09-26 06:58:46 -07:00
Peter Hawkins
f4bc663c31 Wrap multiprocess test popen() uses in a context manager.
Ensures that resources from popen() are cleaned up.
2022-09-26 13:48:56 +00:00
jax authors
ec15e83018 - Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Yash Katariya
7c85ca38f4 Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run.
PiperOrigin-RevId: 476497929
2022-09-23 17:31:33 -07:00
Peter Hawkins
8ee7129874 Fix jnp.unwrap() test failures on GPU.
A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16.

In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function.

PiperOrigin-RevId: 476494989
2022-09-23 17:11:51 -07:00
jax authors
d2fcfb6b83 Merge pull request #12407 from hirwa-nshuti:docs-fix
PiperOrigin-RevId: 476467728
2022-09-23 14:51:11 -07:00
Matthew Johnson
03abcc7c5c fix typo in test 2022-09-23 14:43:24 -07:00
jax authors
e76aa77895 Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests
PiperOrigin-RevId: 476451469
2022-09-23 13:38:59 -07:00
Yash Katariya
1fa0dda760 Return single device Arrays from .device_buffer and .device_buffers.
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
43bbce0cc6 Merge pull request #12486 from hawkinsp:debugging
PiperOrigin-RevId: 476445041
2022-09-23 13:09:26 -07:00
jax authors
737327a42d Merge pull request #12490 from mattjj:improve-leak-checker
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd fix leak checker internal error
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).

But after #7345, the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Sudhakar
4dd0d85139 add multihost pjit tests 2022-09-23 12:11:56 -07:00
Jake VanderPlas
a6b24b379c Add regression test for lax.rev simplification error
PiperOrigin-RevId: 476430486
2022-09-23 12:07:15 -07:00
Yash Katariya
ecb27a9b24 Update the _check_special code to not use xla_shape since its deprecated and does not work with Array.
PiperOrigin-RevId: 476422732
2022-09-23 11:40:32 -07:00
jax authors
d078f3f5fc Merge pull request #12478 from sharadmv:sharding-docs
PiperOrigin-RevId: 476420315
2022-09-23 11:31:37 -07:00
jax authors
e8865c8264 Merge pull request #12481 from kho:changelist/476272494
PiperOrigin-RevId: 476411483
2022-09-23 10:55:10 -07:00
Ke Wu
c823151771 Allow transpose axes to be negative to match (undocumented) NumPy behavior 2022-09-23 10:18:23 -07:00
Tres Popp
0c085471c7 Modify CorrCoef test to not rely on floating poing representation of 1/3
The operation computed an average while using the dimension of size 3. This is then changed to multiplying by 1/3 with compilers, but 1/3 cannot be represented perfectly. That made this test case rely on a very precise result from an unrepresentable calculation.

PiperOrigin-RevId: 476391389
2022-09-23 09:39:01 -07:00
Peter Hawkins
38fb8ed22f Fix copyright attribution for some newly added files.
PiperOrigin-RevId: 476390902
2022-09-23 09:32:47 -07:00
jax authors
6c47dc51cb Merge pull request #12471 from ROCmSoftwarePlatform:rocm-dockerfile-update
PiperOrigin-RevId: 476387200
2022-09-23 09:16:38 -07:00
Yash Katariya
da50bdd75a Fix the asan failure in pjit_test_cpu build target
PiperOrigin-RevId: 476382929
2022-09-23 08:59:57 -07:00
Yash Katariya
c8f55414fc Convert the devices in the Mesh constructor to a numpy array if its a list, tuple, etc.
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
Peter Hawkins
a88c5ad789 Fix xla extension version test in debugging.py
The custom call partitioner callback was not present in version 94 but is present in version 95.
2022-09-23 10:53:50 -04:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
jax authors
342f896032 Merge pull request #12484 from hawkinsp:doc
PiperOrigin-RevId: 476361203
2022-09-23 07:13:50 -07:00
lenamartens
7078f81dd0 Checkify: misc improvements.
- err.throw == check_error(err) -> meaning they have the same behavior
  under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Peter Hawkins
eed327914e Improve documentation for unique_indices. 2022-09-23 09:11:15 -04:00
Felix Hirwa Nshuti
820efab6fa removed repeated nan_to_num in docs 2022-09-23 06:23:09 +00:00
Tianjian Lu
67b7ae259f [sparse] Move _bcoo_nse to sparse util.
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
99d4d8b89a Update debugging docs to have sharding visualization 2022-09-22 19:42:36 -07:00
Sharad Vikram
805073f36a Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
jax authors
11a6fd9e79 Merge pull request #12476 from jakevdp:match-sharding
PiperOrigin-RevId: 476220059
2022-09-22 16:04:20 -07:00
Jake VanderPlas
3d23592cf6 [array] full_like: only match sharding if shape==None 2022-09-22 15:28:59 -07:00
jax authors
dfdf00c2eb Merge pull request #12472 from google:sharadmv-patch-2
PiperOrigin-RevId: 476190259
2022-09-22 14:00:07 -07:00
jax authors
bc08381da3 Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Sharad Vikram
1a8a8a5586
Fix example in pjit docstring 2022-09-22 12:55:55 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1 Add generic interface for auto initialization of distributed JAX service
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00