5570 Commits

Author SHA1 Message Date
Yash Katariya
05e1ddd4ea Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1 Add support for __signature__ to PjitFunction.
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
Yash Katariya
53fceab17c pjit allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test(). This is because the inner backend will take priority.
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
85654ceeab Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
George Necula
7e0041c903 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
Sharad Vikram
a58e59d98f Add in effects_barrier for the pmap unordered callback test
PiperOrigin-RevId: 502434258
2023-01-16 14:44:44 -08:00
Matthew Johnson
e516d41180 cond transpose, use UndefinedPrimal not linear for transpose inputs 2023-01-16 10:39:19 -08:00
Yash Katariya
4601928277 Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Yash Katariya
38f91bdaa5 Skip core tests which have nested pjits and DShapedArray.
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Qiao Zhang
d203926c16 Expose fp8 in jax dtypes and mlir builder.
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Pankaj Kanwar
8fcb5180b2 disable flaky tests on certain targets.
PiperOrigin-RevId: 501974439
2023-01-13 17:23:36 -08:00
Yash Katariya
7e8fe13c6a jit was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit) so that nested pjits will show up as pjit in the name stack.
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Yash Katariya
5eb23a7615 Fix name_stack usage of pjit. Now all the metadata of transformations in hlo are correct.
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
jax authors
86ba62d05e Merge pull request #13991 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
649ee1be34 Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Jake VanderPlas
7a8781db1c [sparse] add higher-level version of bcoo_extract & improve tests 2023-01-13 07:13:13 -08:00
jax authors
44a044f936 Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
PiperOrigin-RevId: 501827615
2023-01-13 06:06:52 -08:00
jax authors
9c418e9399 Merge pull request #13862 from jakevdp:bcoo-extract
PiperOrigin-RevId: 501826918
2023-01-13 05:59:08 -08:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5 Fix debugging primitives for pjit. This came up during jit/pjit merge
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Parker Schuh
30d64f38f1 Add 'hard xmap' support for pure_callback.
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
Jake VanderPlas
e37e3a9b0f [sparse] bcoo_extract: add assume_unique keyword 2023-01-12 15:21:11 -08:00
jax authors
34e10e3495 Merge pull request #13892 from jakevdp:bcoo-sum-duplicates-batch
PiperOrigin-RevId: 501676606
2023-01-12 15:02:53 -08:00
Skye Wanderman-Milne
c0577f70f9 Migrate pytestmark usage to new @jtu.pytest_mark_if_available decorator.
See discussion in https://github.com/google/jax/pull/13977. Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
jax authors
15ec37c581 Merge pull request #13977 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501671140
2023-01-12 14:41:43 -08:00
Jake VanderPlas
f314f2d504 [sparse] generalize batch rule for bcoo_sum_duplicates & improve tests 2023-01-12 14:28:12 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Jake VanderPlas
a6a3b59748 [sparse] generalize batch rule for bcoo_dot_general 2023-01-12 12:03:21 -08:00
Jake VanderPlas
eaf6179594 [sparse] generalize batch rule for bcoo_spdot_general 2023-01-12 10:59:28 -08:00
Gordian Edenhofer
6b8125b320 mgrid: Fix zero-length meshgrid 2023-01-12 19:43:28 +01:00
Sharad Vikram
f729da4a36 Add shards for checkify_test on GPU
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
jax authors
f0cca20a67 Merge pull request #13929 from nvcastet:add_ompi_init
PiperOrigin-RevId: 501429401
2023-01-11 18:20:51 -08:00
Yash Katariya
acd8dadc74 Make some api_test pass with jit/pjit merge
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Nicolas Castet
b86030d86f Add Open MPI automatic distributed initialization 2023-01-11 17:08:09 -06:00
jax authors
45197ed8db Merge pull request #13545 from LenaMartens:check-initial
PiperOrigin-RevId: 501248830
2023-01-11 04:52:02 -08:00
Yash Katariya
66aafb6e16 Don't take the cpp dispatch path for pjit if it contains ordered effects just like jit.
PiperOrigin-RevId: 501141750
2023-01-10 18:07:23 -08:00
Yash Katariya
c447e987e1 Skip custom_object_test and dynamic_api_test for pjit/jit merge since it doesn't work with jax.Array's too.
PiperOrigin-RevId: 501129056
2023-01-10 16:55:51 -08:00
Yash Katariya
e02c1da4c7 Fix debug nans test after merging jit and pjit codepaths
PiperOrigin-RevId: 501122848
2023-01-10 16:27:00 -08:00
Jake VanderPlas
264bc9531e Fix finfo test for older numpy versions 2023-01-10 12:17:31 -08:00
Jake VanderPlas
31fd81f2d5 Add tests of dtypes.finfo properties 2023-01-10 09:02:00 -08:00
Jake VanderPlas
7788f0cf6b Re-land https://github.com/google/jax/pull/11996
Fixes #11965

PiperOrigin-RevId: 500841698
2023-01-09 16:57:32 -08:00
Yash Katariya
849af498d1 Make jaxpr_util_test work with jit/pjit merge
PiperOrigin-RevId: 500841015
2023-01-09 16:50:04 -08:00
Jake VanderPlas
f317943f56 Warn rather than fail when reloading JAX
Fixes https://github.com/google/jax/issues/13857

PiperOrigin-RevId: 500727768
2023-01-09 09:11:50 -08:00
Yash Katariya
74601e59e1 Fix the error message of different devices when jit/pjit are merged
PiperOrigin-RevId: 500727596
2023-01-09 09:03:55 -08:00