6034 Commits

Author SHA1 Message Date
Jieying Luo
b35c20ce5d Use xla_extension_version and remove some dead version check in xla_bridge_test.py.
Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
2023-05-31 13:50:07 -07:00
Yash Katariya
6d6ba70c78 Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8
PiperOrigin-RevId: 536693061
2023-05-31 06:21:01 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Peter Hawkins
16368bc672 [XLA:Python] Clean up handling of unsupported types in buffer protocol.
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.

Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.

PiperOrigin-RevId: 535315975
2023-05-25 11:10:19 -07:00
Peter Hawkins
32026ad18b Disable random_test_with_custom_prng on CPU under msan.
This test flakily times out in CI.

PiperOrigin-RevId: 535293997
2023-05-25 10:10:01 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Ce Zheng
8e397f7f08 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors.
PiperOrigin-RevId: 535079635
2023-05-24 20:26:28 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Ce Zheng
4f1f5e4516 [XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding,
PiperOrigin-RevId: 534600886
2023-05-23 16:37:42 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Jake VanderPlas
62fb0cd8a2 explicitly convert jnp.var scalar normalizer to float (from int)
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.

Fixes #15068

Co-authored by: Matthew Johnson <mattjj@google.com>
2023-05-23 09:44:08 -07:00
jax authors
13f5090c4c Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf
PiperOrigin-RevId: 534165099
2023-05-22 13:31:04 -07:00
Matthew Johnson
61b106ec8f allow lax.dot_general to accept different input dtypes
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.

The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.

However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
2023-05-22 10:33:42 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Sholto Douglas
e0b5003880 Allow unconstrained dimensions when using NamedShardings.
PiperOrigin-RevId: 533752415
2023-05-20 16:28:12 -07:00
Parker Schuh
56ca8af9bb Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
2023-05-19 16:58:54 -07:00
Peter Hawkins
1d20d2f301 Increase sharding of host_callback_test on TPU to fix CI flakiness.
PiperOrigin-RevId: 533451822
2023-05-19 07:44:53 -07:00
Jieying Luo
9da52e8905 [PJRT PLUGIN] Provide a register_plugin method that plugin can use to register their backend factory.
The plugin is expected to calls jax._src.xla_bridge.register_plugin with its plugin_name, priority (default to be 400), path to .so file, and optional create options in their initialize() method.

Logics to register a plugin from ENV is not deleted to facilitate development with ENV.

PiperOrigin-RevId: 533280115
2023-05-18 16:13:02 -07:00
Yash Katariya
4a5c6f8200 For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Parker Schuh
08169291a4 Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate
arguments for shape and sharding.

PiperOrigin-RevId: 533257532
2023-05-18 14:48:07 -07:00
Peter Hawkins
39097df02e Add some preliminary support for int4/uint4 types to JAX.
PiperOrigin-RevId: 533251630
2023-05-18 14:27:33 -07:00
Roy Frostig
717d3c88fc inline and remove eq_mlir and ne_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
f18bff5371 inline and remove scatter_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
cc54b6e6ad inline and remove select_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
301d058b3d inline and remove gather_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
071c77e5bb inline and remove transpose_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
06132ac764 inline and remove broadcast_in_dim_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
0ac792f4ed inline and remove dynamic_update_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
2dbdf1a6c1 inline and remove dynamic_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
aed77c5031 inline and remove slice_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
129a4a5f35 inline and remove empty_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
180e26dafb remove physical_avals rule in favor of physical_element_aval 2023-05-17 20:07:58 -07:00
jax authors
f3cecd07c7 Merge pull request #16047 from jakevdp:prngkey-replicated
PiperOrigin-RevId: 532959862
2023-05-17 17:21:54 -07:00
Jake VanderPlas
48abe7c684 PRNGKeyArray: add several missing attributes & methods 2023-05-17 14:47:22 -07:00
Peter Hawkins
389564551b Second attempt at fixing warnings from jax.dtypes.issubdtype.
PiperOrigin-RevId: 532902714
2023-05-17 14:10:22 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc is the same as within a jit i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing and was decided that it is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
jax authors
79be3482ca Merge pull request #16041 from LenaMartens:checkify-leaks
PiperOrigin-RevId: 532843752
2023-05-17 11:04:45 -07:00
lenamartens
ee6cbafa85 Checkify: Fix closing over Tracer in while_loop cond_f.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-05-17 18:43:23 +01:00
Jake VanderPlas
6ef4e5f01a Custom PRNG: make KeyArray compatible with custom_jvp 2023-05-17 10:31:09 -07:00
Jake VanderPlas
0e483223c6 Custom PRNG: support lax.full() and related constructors 2023-05-17 09:04:50 -07:00
Peter Hawkins
e6628e2e72 Disable tests that time out in CI.
PiperOrigin-RevId: 532792740
2023-05-17 08:16:07 -07:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
jax authors
bac1298338 Merge pull request #16032 from sharadmv:scan-dce-effects
PiperOrigin-RevId: 532612724
2023-05-16 16:51:20 -07:00
Sharad Vikram
7b3cea62d8 Fix scan DCE rule to update effects 2023-05-16 23:17:02 +00:00
jax authors
050b243d01 Merge pull request #16028 from jakevdp:prngkey-error
PiperOrigin-RevId: 532604063
2023-05-16 16:15:56 -07:00
jax authors
e85b9619a4 Merge pull request #15998 from patrick-kidger:dce-linear-transpose
PiperOrigin-RevId: 532596793
2023-05-16 15:49:48 -07:00
Matthew Johnson
42b2a80df2 add a test for tree_reduce with is_leaf argument 2023-05-16 15:37:52 -07:00