jax authors
aafc70d293
Merge pull request #12556 from hawkinsp:rocm
...
PiperOrigin-RevId: 477440001
2022-09-28 06:50:19 -07:00
Peter Hawkins
f7bafb3d4c
Disable multiprocess_gpu_test that fails on ROCm.
2022-09-28 13:40:57 +00:00
Peter Hawkins
eabb91e53f
Fix test failure in GPU CI if NCCL_DEBUG is enabled.
...
If NCCL_DEBUG is enabled, NCCL prints extra status information. Make
test accept this.
2022-09-28 13:06:04 +00:00
jax authors
96abd9ac75
Merge pull request #12540 from sharadmv:cond-lowering-fix
...
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
jax authors
948906885d
Merge pull request #12546 from mattjj:issue12542
...
PiperOrigin-RevId: 477356925
2022-09-27 22:16:02 -07:00
Sharad Vikram
ddeaa8dbbc
Fix lowering bug in effectful batched cond and add tests
2022-09-27 22:12:13 -07:00
Yash Katariya
b4e1d0af8a
Propagate name
through ExecuteReplicated for dispatch.check_special
...
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Matthew Johnson
b175e11731
[c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
...
fixes #12542
Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
jax authors
0e116888ea
Merge pull request #12382 from jakevdp:reduction-dtype
...
PiperOrigin-RevId: 477179725
2022-09-27 08:38:46 -07:00
jax authors
1bcf8d646d
Merge pull request #12497 from mattjj:djax-dag-fix1
...
PiperOrigin-RevId: 477038279
2022-09-26 18:14:56 -07:00
Matthew Johnson
1e7ca8f77a
fix bug in djax type signature inference logic
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-26 17:48:25 -07:00
Yash Katariya
cbf34cb609
Rename the concrete class Array
to ArrayImpl
...
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Tianjian Lu
71bcabe499
[sparse] Add BCSR format template.
...
PiperOrigin-RevId: 477013899
2022-09-26 16:02:16 -07:00
Peter Hawkins
d63a9442bb
Change jax_jit_test to be a jax_test() under Bazel that works across backends.
...
Make it pass under TPU if x64 types are enabled.
PiperOrigin-RevId: 476994286
2022-09-26 14:38:35 -07:00
Jake VanderPlas
1860f6d839
[x64] add promote_integers argument to jnp.prod & jnp.sum
2022-09-26 13:31:43 -07:00
lenamartens
27e3981d52
lowerable errors behind a config flag.
2022-09-26 17:34:27 +01:00
jax authors
28672cca0e
Merge pull request #12496 from mattjj:improve-leak-checker-2
...
PiperOrigin-RevId: 476907407
2022-09-26 08:50:13 -07:00
jax authors
9c66569514
Merge pull request #12468 from LenaMartens:checkify-but-better
...
PiperOrigin-RevId: 476901601
2022-09-26 08:23:02 -07:00
Jake VanderPlas
0cb233eec9
Add initial jax.Array base class for instance checks & annotation
2022-09-26 07:48:43 -07:00
Peter Hawkins
f4bc663c31
Wrap multiprocess test popen() uses in a context manager.
...
Ensures that resources from popen() are cleaned up.
2022-09-26 13:48:56 +00:00
Peter Hawkins
8ee7129874
Fix jnp.unwrap() test failures on GPU.
...
A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16.
In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function.
PiperOrigin-RevId: 476494989
2022-09-23 17:11:51 -07:00
Matthew Johnson
03abcc7c5c
fix typo in test
2022-09-23 14:43:24 -07:00
jax authors
e76aa77895
Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests
...
PiperOrigin-RevId: 476451469
2022-09-23 13:38:59 -07:00
Yash Katariya
1fa0dda760
Return single device Arrays from .device_buffer
and .device_buffers
.
...
PiperOrigin-RevId: 476449591
2022-09-23 13:30:26 -07:00
jax authors
737327a42d
Merge pull request #12490 from mattjj:improve-leak-checker
...
PiperOrigin-RevId: 476442352
2022-09-23 12:58:03 -07:00
Matthew Johnson
b6ef90ffdd
fix leak checker internal error
...
The issue was that partial_eval.py's _memoize, used in custom_jvp, was made
into an identity function by enabling config.jax_check_tracer_leaks (from
references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger
the leak checker (which would see if any references to the main trace persisted
after finishing tracing of the user function).
But after #7345 , the leak checker should only trigger when actual Tracers are
leaked. So disabling the memoization when jax_check_tracer_leaks is no longer
active shouldn't be necessary. (These PR numbers seem out of order! We're not
sure why.)
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-23 12:33:45 -07:00
Sudhakar
4dd0d85139
add multihost pjit tests
2022-09-23 12:11:56 -07:00
Jake VanderPlas
a6b24b379c
Add regression test for lax.rev simplification error
...
PiperOrigin-RevId: 476430486
2022-09-23 12:07:15 -07:00
Ke Wu
c823151771
Allow transpose axes to be negative to match (undocumented) NumPy behavior
2022-09-23 10:18:23 -07:00
Tres Popp
0c085471c7
Modify CorrCoef test to not rely on floating poing representation of 1/3
...
The operation computed an average while using the dimension of size 3. This is then changed to multiplying by 1/3 with compilers, but 1/3 cannot be represented perfectly. That made this test case rely on a very precise result from an unrepresentable calculation.
PiperOrigin-RevId: 476391389
2022-09-23 09:39:01 -07:00
Yash Katariya
da50bdd75a
Fix the asan failure in pjit_test_cpu build target
...
PiperOrigin-RevId: 476382929
2022-09-23 08:59:57 -07:00
Yash Katariya
c8f55414fc
Convert the devices in the Mesh
constructor to a numpy array if its a list, tuple, etc.
...
PiperOrigin-RevId: 476380496
2022-09-23 08:48:31 -07:00
lenamartens
7078f81dd0
Checkify: misc improvements.
...
- err.throw == check_error(err) -> meaning they have the same behavior
under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
2022-09-23 14:33:06 +01:00
Tianjian Lu
67b7ae259f
[sparse] Move _bcoo_nse
to sparse util.
...
PiperOrigin-RevId: 476263483
2022-09-22 20:22:06 -07:00
Sharad Vikram
805073f36a
Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
...
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
jax authors
bc08381da3
Merge pull request #12152 from nvcastet:add_slurm_orchestrator_support
...
PiperOrigin-RevId: 476179963
2022-09-22 13:18:25 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Nicolas Castet
412a5379c1
Add generic interface for auto initialization of distributed JAX service
...
* Also add slurm cluster support
2022-09-22 14:15:38 -05:00
Tyler Augustine
d52de206cb
Disable tests that timeout in debug mode in CI
...
PiperOrigin-RevId: 476157051
2022-09-22 11:44:56 -07:00
Yash Katariya
a157982e8c
Make jit(f).lower(*args)
go via lower_sharding_computation when jax_array
is enabled.
...
PiperOrigin-RevId: 476148608
2022-09-22 11:13:33 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Yash Katariya
52476d1ab5
Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding.
...
PiperOrigin-RevId: 475972534
2022-09-21 18:19:35 -07:00
Kuangyuan Chen
a09ef8a6a6
Temporarily skip LaxBackedNumpyTests.testUnwrap on gpu to unblock jaxlib build
...
PiperOrigin-RevId: 475970440
2022-09-21 18:06:02 -07:00
George Karpenkov
541aadcfe8
[XLA:GPU] Allow simplifying lowering-precision-conversions by default
...
This might lead to the output having higher precision than specified by HLO.
PiperOrigin-RevId: 475889141
2022-09-21 12:04:45 -07:00
Peter Hawkins
d0e1c3e684
Disable tests under sanitizers that are timing out in CI.
...
PiperOrigin-RevId: 475839926
2022-09-21 08:50:55 -07:00
Yash Katariya
6183727acc
Update pjit_test to skip GDA tests with Array is enabled.
...
PiperOrigin-RevId: 475684445
2022-09-20 16:38:43 -07:00
jax authors
310bcd57a2
Merge pull request #12389 from LenaMartens:check-while-2
...
PiperOrigin-RevId: 475606527
2022-09-20 11:23:42 -07:00
lenamartens
018e700ead
Checkify: support batched while.
2022-09-20 17:59:46 +01:00
jax authors
e855a9c458
Merge pull request #12428 from jakevdp:tracer-methods
...
PiperOrigin-RevId: 475580916
2022-09-20 09:52:56 -07:00