Yash Katariya
05e1ddd4ea
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
...
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1
Add support for __signature__ to PjitFunction.
...
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
Yash Katariya
53fceab17c
pjit
allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test()
. This is because the inner backend will take priority.
...
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
jax authors
8da6c89c7b
Merge pull request #13759 from sharadmv:io-callback
...
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716
Add IO callback
2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe
Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
...
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
85654ceeab
Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
...
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
George Necula
7e0041c903
Fix scatter in CLIP mode with uint32 and uint64 indices
...
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
Sharad Vikram
a58e59d98f
Add in effects_barrier for the pmap unordered callback test
...
PiperOrigin-RevId: 502434258
2023-01-16 14:44:44 -08:00
Matthew Johnson
e516d41180
cond transpose, use UndefinedPrimal not linear
for transpose inputs
2023-01-16 10:39:19 -08:00
Yash Katariya
4601928277
Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
...
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Yash Katariya
38f91bdaa5
Skip core tests which have nested pjits and DShapedArray.
...
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Qiao Zhang
d203926c16
Expose fp8 in jax dtypes and mlir builder.
...
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Pankaj Kanwar
8fcb5180b2
disable flaky tests on certain targets.
...
PiperOrigin-RevId: 501974439
2023-01-13 17:23:36 -08:00
Yash Katariya
7e8fe13c6a
jit
was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit
) so that nested pjits will show up as pjit
in the name stack.
...
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
jax authors
86ba62d05e
Merge pull request #13991 from skye:pjrt_c_api_marker
...
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Jake VanderPlas
7a8781db1c
[sparse] add higher-level version of bcoo_extract & improve tests
2023-01-13 07:13:13 -08:00
jax authors
44a044f936
Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
...
PiperOrigin-RevId: 501827615
2023-01-13 06:06:52 -08:00
jax authors
9c418e9399
Merge pull request #13862 from jakevdp:bcoo-extract
...
PiperOrigin-RevId: 501826918
2023-01-13 05:59:08 -08:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d
Add batch_jaxpr2 which tells the caller where batch dims are.
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
...
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5
Fix debugging primitives for pjit. This came up during jit/pjit merge
...
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358
Make jit
a thin wrapper around pjit
which ignores the mesh context manager (just like how it is today)
...
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Parker Schuh
30d64f38f1
Add 'hard xmap' support for pure_callback.
...
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
Jake VanderPlas
e37e3a9b0f
[sparse] bcoo_extract: add assume_unique keyword
2023-01-12 15:21:11 -08:00
jax authors
34e10e3495
Merge pull request #13892 from jakevdp:bcoo-sum-duplicates-batch
...
PiperOrigin-RevId: 501676606
2023-01-12 15:02:53 -08:00
Skye Wanderman-Milne
c0577f70f9
Migrate pytestmark usage to new @jtu.pytest_mark_if_available
decorator.
...
See discussion in https://github.com/google/jax/pull/13977 . Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.
I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
jax authors
15ec37c581
Merge pull request #13977 from skye:pjrt_c_api_marker
...
PiperOrigin-RevId: 501671140
2023-01-12 14:41:43 -08:00
Jake VanderPlas
f314f2d504
[sparse] generalize batch rule for bcoo_sum_duplicates & improve tests
2023-01-12 14:28:12 -08:00
Skye Wanderman-Milne
f90b5eed52
Add pjrt_c_api_unimplemented
pytest marker to skip unsupported tests.
...
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Jake VanderPlas
a6a3b59748
[sparse] generalize batch rule for bcoo_dot_general
2023-01-12 12:03:21 -08:00
Jake VanderPlas
eaf6179594
[sparse] generalize batch rule for bcoo_spdot_general
2023-01-12 10:59:28 -08:00
Gordian Edenhofer
6b8125b320
mgrid: Fix zero-length meshgrid
2023-01-12 19:43:28 +01:00
Sharad Vikram
f729da4a36
Add shards for checkify_test on GPU
...
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
jax authors
f0cca20a67
Merge pull request #13929 from nvcastet:add_ompi_init
...
PiperOrigin-RevId: 501429401
2023-01-11 18:20:51 -08:00
Yash Katariya
acd8dadc74
Make some api_test pass with jit/pjit merge
...
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Nicolas Castet
b86030d86f
Add Open MPI automatic distributed initialization
2023-01-11 17:08:09 -06:00
jax authors
45197ed8db
Merge pull request #13545 from LenaMartens:check-initial
...
PiperOrigin-RevId: 501248830
2023-01-11 04:52:02 -08:00
Yash Katariya
66aafb6e16
Don't take the cpp dispatch path for pjit
if it contains ordered effects just like jit
.
...
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1
Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
...
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7
Fix debug nans test after merging jit
and pjit
codepaths
...
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Jake VanderPlas
264bc9531e
Fix finfo test for older numpy versions
2023-01-10 12:17:31 -08:00
Jake VanderPlas
31fd81f2d5
Add tests of dtypes.finfo properties
2023-01-10 09:02:00 -08:00
Jake VanderPlas
7788f0cf6b
Re-land https://github.com/google/jax/pull/11996
...
Fixes #11965
PiperOrigin-RevId: 500841698
2023-01-09 16:57:32 -08:00
Yash Katariya
849af498d1
Make jaxpr_util_test work with jit/pjit merge
...
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Jake VanderPlas
f317943f56
Warn rather than fail when reloading JAX
...
Fixes https://github.com/google/jax/issues/13857
PiperOrigin-RevId: 500727768
2023-01-09 09:11:50 -08:00
Yash Katariya
74601e59e1
Fix the error message of different devices when jit/pjit are merged
...
PiperOrigin-RevId: 500727596
2023-01-09 09:03:55 -08:00