Jake VanderPlas
81e627d5bd
DOC: make API doc titles more uniform
2023-01-18 10:59:42 -08:00
Yash Katariya
a37121e195
Don't depend on flatten_axis_resources
which will error because flatten_axes
passes a dummy object()
which doesn't work with checks in user pytrees.
...
Only do this if the original {in|out}_shardings are _UNSPECIFIED.
PiperOrigin-RevId: 502792305
2023-01-18 00:13:04 -08:00
jax authors
74df2f9927
Merge pull request #14031 from gnecula:tf_jit_pjit
...
PiperOrigin-RevId: 502790279
2023-01-18 00:01:34 -08:00
George Necula
97b58bfae7
[jax2tf] Adjust jax2tf for the pjit==jit API migration.
...
jax2tf treats jit and pjit differently: jit was inlined while
pjit resulted in a recursive call to _interpret_jaxpr. This
resulted in differences of handling of constant sharing.
This PR actually makes the constant sharing more aggressive.
This should be Ok, because we are only sharing non-scalars
which JAX has already lifted to the top-level of the Jaxpr.
2023-01-18 09:04:21 +02:00
jax authors
b0e30eb067
Merge pull request #13911 from 8bitmp3:main
...
PiperOrigin-RevId: 502743593
2023-01-17 18:50:27 -08:00
Yash Katariya
05e1ddd4ea
Make error_test
a jax_test so that we can test other configs and fix it with jit
/pjit
merge.
...
PiperOrigin-RevId: 502743523
2023-01-17 18:43:05 -08:00
Parker Schuh
b58dd3cbe1
Add support for __signature__ to PjitFunction.
...
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
jax authors
6d80e5909b
Merge pull request #13423 from nvlcambier:lcambier/on_demand_ci_jax_gpu
...
PiperOrigin-RevId: 502727655
2023-01-17 17:11:18 -08:00
jax authors
e6e8350389
Merge pull request #14047 from jakevdp:doc-installation
...
PiperOrigin-RevId: 502727042
2023-01-17 17:03:43 -08:00
Yash Katariya
53fceab17c
pjit
allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test()
. This is because the inner backend will take priority.
...
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
Leopold Cambier
056702c1cb
Multinodes CICD on GPUs using on-demand cluster and e2e tests using T5X
2023-01-17 16:29:30 -08:00
Jake VanderPlas
5ebd1e79ff
DOC: add GPU & TPU installation tabs on index page
2023-01-17 16:19:57 -08:00
8bitmp3
7e5eba6173
Update headings in JAX Custom operations for GPUs and Building from source guides
2023-01-17 23:44:49 +00:00
jax authors
8da6c89c7b
Merge pull request #13759 from sharadmv:io-callback
...
PiperOrigin-RevId: 502694690
2023-01-17 14:48:50 -08:00
Sharad Vikram
3de5c2b716
Add IO callback
2023-01-17 13:55:05 -08:00
Yash Katariya
cb9a9952fe
Check if the sharding input to ShapeDtypeStruct is an instance of Sharding
...
PiperOrigin-RevId: 502652848
2023-01-17 12:08:51 -08:00
Yash Katariya
85654ceeab
Default dynamic_api_test and custom_object_test to take the old jit path and not the merged path since there is no pjit support for it yet.
...
PiperOrigin-RevId: 502620662
2023-01-17 10:19:39 -08:00
George Necula
7e0041c903
Fix scatter in CLIP mode with uint32 and uint64 indices
...
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.
This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.
PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
jax authors
7ce9fa2f87
Merge pull request #14030 from gnecula:poly_vmap_error
...
PiperOrigin-RevId: 502546564
2023-01-17 04:19:16 -08:00
jax authors
74d6c219b5
Merge pull request #13312 from nouiz:jax_array_doc
...
PiperOrigin-RevId: 502534154
2023-01-17 03:00:19 -08:00
George Necula
cf4e568e21
[shape_poly] Improve error message from vmap axis size inconsistency
...
vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
2023-01-17 10:45:12 +02:00
John QiangZhang
469a8eb520
Change args_tf_flat to tf.TensorSpec in jax2tf.call_tf.
...
PiperOrigin-RevId: 502492762
2023-01-16 22:43:13 -08:00
Sharad Vikram
a58e59d98f
Add in effects_barrier for the pmap unordered callback test
...
PiperOrigin-RevId: 502434258
2023-01-16 14:44:44 -08:00
jax authors
432b909ef8
Merge pull request #14027 from mattjj:issue14026
...
PiperOrigin-RevId: 502408481
2023-01-16 10:59:46 -08:00
Matthew Johnson
e516d41180
cond transpose, use UndefinedPrimal not linear
for transpose inputs
2023-01-16 10:39:19 -08:00
Yash Katariya
c4d21f97ea
Make xmap use dispatch.sharded_lowering as dispatch.lower_xla_callable is deprecated.
...
PiperOrigin-RevId: 502398366
2023-01-16 09:42:40 -08:00
jax authors
35820ef1e4
Merge pull request #13980 from gnecula:call_tf_effects2
...
PiperOrigin-RevId: 502346453
2023-01-16 03:54:11 -08:00
jax authors
78c9dd8104
Merge pull request #14020 from mattjj:issue13983
...
PiperOrigin-RevId: 502147420
2023-01-14 21:36:41 -08:00
Matthew Johnson
1da24b61fc
fix transcription error for initial_step_size (thanks @AlbertMitjans)
...
fixes #13983
2023-01-14 21:10:44 -08:00
Yash Katariya
8f538f95dc
Pass the proper api_name to debug_info
...
PiperOrigin-RevId: 502141425
2023-01-14 20:41:01 -08:00
Yash Katariya
1209ab17e4
Add abstracted axes to pjit to make jax2tf tests pass. abstracted_axes and dynamic_shapes is not supported by pjit yet.
...
PiperOrigin-RevId: 502138836
2023-01-14 20:17:30 -08:00
Yash Katariya
4601928277
Enable jit_pjit_api_merge by default "in tests" and disable the current failing tests.
...
PiperOrigin-RevId: 502088044
2023-01-14 11:15:03 -08:00
Yash Katariya
4c58ef3840
Add in_positional_semantics to new_params_known and new_params_staged otherwise it leads to length mismatch error down the stack. It is similar to donated_invars and in_shardings.
...
PiperOrigin-RevId: 502082828
2023-01-14 10:19:00 -08:00
George Necula
ade5691630
[call_tf] Add has_side_effects parameter
...
The CallTfEffect was added recently as an internal workaround for
DCE removing instances of call_tf. Here we add a parameter to
`call_tf` to be able to declare if the called computation is
effectful and should not be removed by DCE.
2023-01-14 08:12:29 +01:00
Yash Katariya
38f91bdaa5
Skip core tests which have nested pjits and DShapedArray.
...
PiperOrigin-RevId: 502013080
2023-01-13 22:39:31 -08:00
Qiao Zhang
d203926c16
Expose fp8 in jax dtypes and mlir builder.
...
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Pankaj Kanwar
8fcb5180b2
disable flaky tests on certain targets.
...
PiperOrigin-RevId: 501974439
2023-01-13 17:23:36 -08:00
jax authors
17785a4ded
Merge pull request #13285 from canyon289:draft_focused_landing_page
...
PiperOrigin-RevId: 501954451
2023-01-13 15:35:34 -08:00
Yash Katariya
7e8fe13c6a
jit
was the default name in name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to jit
) so that nested pjits will show up as pjit
in the name stack.
...
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
Jake VanderPlas
9aa8fe520f
Stricter input validation for shaped_abstractify
...
Fixes https://github.com/google/jax/issues/13976
PiperOrigin-RevId: 501944946
2023-01-13 14:53:08 -08:00
Yash Katariya
5eb23a7615
Fix name_stack
usage of pjit. Now all the metadata of transformations in hlo are correct.
...
PiperOrigin-RevId: 501918212
2023-01-13 12:54:12 -08:00
Ravin Kumar
2257e2075d
Update doc landing page
...
Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2023-01-13 12:45:28 -08:00
jax authors
86ba62d05e
Merge pull request #13991 from skye:pjrt_c_api_marker
...
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
649ee1be34
Make pickle_test.py pass with jit/pjit api merge. Also rename and move some functions around
...
PiperOrigin-RevId: 501878555
2023-01-13 10:16:01 -08:00
Yash Katariya
eb887f42c4
Add type hints to some functions in pxla.py. Adding typehints to pjit
is a bit difficult because of functions like _is_unspecified
, etc. TypeGuard
is a nice solution for it but it doesn't work in the negative case.
...
PiperOrigin-RevId: 501858300
2023-01-13 08:45:49 -08:00
jax authors
d1593289a0
Merge pull request #13869 from jakevdp:bcoo-extract-api
...
PiperOrigin-RevId: 501842123
2023-01-13 07:27:19 -08:00
Jake VanderPlas
7a8781db1c
[sparse] add higher-level version of bcoo_extract & improve tests
2023-01-13 07:13:13 -08:00
jax authors
44a044f936
Merge pull request #13982 from Edenhofer:fix_zero_length_meshgrid
...
PiperOrigin-RevId: 501827615
2023-01-13 06:06:52 -08:00
jax authors
9c418e9399
Merge pull request #13862 from jakevdp:bcoo-extract
...
PiperOrigin-RevId: 501826918
2023-01-13 05:59:08 -08:00
Sharad Vikram
c9a57e1b44
Delete jax.experimental.callback
...
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00