5429 Commits

Author SHA1 Message Date
jax authors
30a0df2b37 Merge pull request #14066 from rmlarsen:patch-1
PiperOrigin-RevId: 502980887
2023-01-18 14:33:44 -08:00
Peter Hawkins
1929d34cb6 Fix test failures under NumPy 1.24.
NumPy 1.24 release notes: https://numpy.org/devdocs/release/1.24.0-notes.html

The fixes vary, but there are three particularly common changes:
* NumPy 1.24 removes a number of deprecated NumPy type aliases references (np.bool, np.int, np.float, np.complex, np.object, np.str, np.unicode, np.long). This change replaces them with their recommended replacements (bool, int, float, complex, object, str, str, int).
* Under NumPy 1.24 no longer automatically infers dtype=object when ragged sequences are passed to np.array(). See https://numpy.org/neps/nep-0034-infer-dtype-is-object.html . In most cases the fix is to pass dtype=object explicitly, but in some cases where the raggedness seems accidental other fixes were used.
* NumPy 1.24 is pickier about the dtype= option passed to comparison ufuncs.

PiperOrigin-RevId: 502979152
2023-01-18 14:25:33 -08:00
Qiao Zhang
d58266eac7 Store sorted flattened dict keys in PyTree as a c++ vector instead of py::list to avoid creating new python object on every single dict flatten. For deeply nested dict, this avoids excessive gc pressure and avoids the slowdown whenever gc needs to sweep too many live python objects.
PiperOrigin-RevId: 502967020
2023-01-18 13:40:43 -08:00
Rasmus Munk Larsen
2ee33a0728
Update lax_control_flow_test.py
Fix brittle scan test. Adding tan(randn) is numerically brittle because evaluating tan() near its singularities is ill-conditioned.
2023-01-18 13:11:56 -08:00
Yash Katariya
4add3b8cee Make pjit an AxisPrimitive so that it can run the batching rules even if the argument is not batched but there is a axis_index/named shapes inside the pjitted function.
PiperOrigin-RevId: 502955369
2023-01-18 12:56:07 -08:00
George Necula
30cf057bf3 [host_callback] Add device_index to hcb.call and add tests
The device_index feature works only with outfeed, add an
error message.

PiperOrigin-RevId: 502951721
2023-01-18 12:41:11 -08:00
Skye Wanderman-Milne
6d0e22eaf9 Don't run FP8 dtype test on TPU.
This change makes dtypes_test.py pass even when not using Bazel (e.g. with
pytest). It also improves TPU coverage when using Bazel.

PiperOrigin-RevId: 502930531
2023-01-18 11:22:17 -08:00
Jake VanderPlas
6376dc9616 Fix excessive recompiles in lax.cond 2023-01-18 10:17:01 -08:00
Yash Katariya
a37121e195 Don't depend on flatten_axis_resources which will error because flatten_axes passes a dummy object() which doesn't work with checks in user pytrees.
Only do this if the original {in|out}_shardings are _UNSPECIFIED.

PiperOrigin-RevId: 502792305
2023-01-18 00:13:04 -08:00
Yash Katariya
05e1ddd4ea Make error_test a jax_test so that we can test other configs and fix it with jit/pjit merge.
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1 Add support for __signature__ to PjitFunction.
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
Yash Katariya
53fceab17c pjit allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test(). This is because the inner backend will take priority.
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
jax authors
8da6c89c7b Merge pull request #13759 from sharadmv:io-callback
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716 Add IO callback 2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
85654ceeab Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
George Necula
7e0041c903 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
Sharad Vikram
a58e59d98f Add in effects_barrier for the pmap unordered callback test
PiperOrigin-RevId: 502434258
2023-01-16 14:44:44 -08:00
Matthew Johnson
e516d41180 cond transpose, use UndefinedPrimal not linear for transpose inputs 2023-01-16 10:39:19 -08:00
Yash Katariya
4601928277 Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Yash Katariya
38f91bdaa5 Skip core tests which have nested pjits and DShapedArray.
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Qiao Zhang
d203926c16 Expose fp8 in jax dtypes and mlir builder.
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Pankaj Kanwar
8fcb5180b2 disable flaky tests on certain targets.
PiperOrigin-RevId: 501974439
2023-01-13 17:23:36 -08:00
Yash Katariya
7e8fe13c6a jit was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit) so that nested pjits will show up as pjit in the name stack.
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Yash Katariya
5eb23a7615 Fix name_stack usage of pjit. Now all the metadata of transformations in hlo are correct.
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
jax authors
86ba62d05e Merge pull request #13991 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
649ee1be34 Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Jake VanderPlas
7a8781db1c [sparse] add higher-level version of bcoo_extract & improve tests 2023-01-13 07:13:13 -08:00
jax authors
44a044f936 Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
PiperOrigin-RevId: 501827615
2023-01-13 06:06:52 -08:00
jax authors
9c418e9399 Merge pull request #13862 from jakevdp:bcoo-extract
PiperOrigin-RevId: 501826918
2023-01-13 05:59:08 -08:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Yash Katariya
94f0ccc54a Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
2023-01-12 18:00:33 -08:00
Yash Katariya
936247a7e5 Fix debugging primitives for pjit. This came up during jit/pjit merge
PiperOrigin-RevId: 501710198
2023-01-12 17:40:35 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Parker Schuh
30d64f38f1 Add 'hard xmap' support for pure_callback.
PiperOrigin-RevId: 501689068
2023-01-12 15:56:50 -08:00
Jake VanderPlas
e37e3a9b0f [sparse] bcoo_extract: add assume_unique keyword 2023-01-12 15:21:11 -08:00
jax authors
34e10e3495 Merge pull request #13892 from jakevdp:bcoo-sum-duplicates-batch
PiperOrigin-RevId: 501676606
2023-01-12 15:02:53 -08:00
Skye Wanderman-Milne
c0577f70f9 Migrate pytestmark usage to new @jtu.pytest_mark_if_available decorator.
See discussion in https://github.com/google/jax/pull/13977. Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
jax authors
15ec37c581 Merge pull request #13977 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501671140
2023-01-12 14:41:43 -08:00
Jake VanderPlas
f314f2d504 [sparse] generalize batch rule for bcoo_sum_duplicates & improve tests 2023-01-12 14:28:12 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Jake VanderPlas
a6a3b59748 [sparse] generalize batch rule for bcoo_dot_general 2023-01-12 12:03:21 -08:00
Jake VanderPlas
eaf6179594 [sparse] generalize batch rule for bcoo_spdot_general 2023-01-12 10:59:28 -08:00
Gordian Edenhofer
6b8125b320 mgrid: Fix zero-length meshgrid 2023-01-12 19:43:28 +01:00
Sharad Vikram
f729da4a36 Add shards for checkify_test on GPU
PiperOrigin-RevId: 501430172
2023-01-11 18:28:37 -08:00
jax authors
f0cca20a67 Merge pull request #13929 from nvcastet:add_ompi_init
PiperOrigin-RevId: 501429401
2023-01-11 18:20:51 -08:00
Yash Katariya
acd8dadc74 Make some api_test pass with jit/pjit merge
PiperOrigin-RevId: 501392938
2023-01-11 15:21:16 -08:00
Nicolas Castet
b86030d86f Add Open MPI automatic distributed initialization 2023-01-11 17:08:09 -06:00
jax authors
45197ed8db Merge pull request #13545 from LenaMartens:check-initial
PiperOrigin-RevId: 501248830
2023-01-11 04:52:02 -08:00