16227 Commits

Author SHA1 Message Date
Yash Katariya
48ad9a6f3e Start jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536860076
jax-v0.4.11 jaxlib-v0.4.11 jax-v0.4.11-rc
2023-05-31 16:48:52 -07:00
jax authors
525ba49ba7 Merge pull request #16204 from skye:importlib_metadata_version
PiperOrigin-RevId: 536823622
2023-05-31 14:27:53 -07:00
Skye Wanderman-Milne
968237080f Add importlib_metadata to project requirements.
This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.

Prior to this change, there was conditional logic to handle if
importlib_metadata wasn't installed at all. However, it doesn't handle
the case where importlib_metadata is installed by not high enough
version to support Python 3.10 compat. This change gets rid of that
logic and just ensures the right version is installed.

All of this logic can be removed if/when jax requires Python version
>= 3.10

This also removes an unnecessary `requests` dep for the [tpu] install.
2023-05-31 21:03:12 +00:00
Jieying Luo
b35c20ce5d Use xla_extension_version and remove some dead version check in xla_bridge_test.py.
Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
2023-05-31 13:50:07 -07:00
jax authors
727c121169 Merge pull request #16188 from nouiz:ci_jestimator
PiperOrigin-RevId: 536810121
2023-05-31 13:41:29 -07:00
jax authors
c587dac134 Merge pull request #16203 from skye:tpu_py_version2
PiperOrigin-RevId: 536776189
2023-05-31 11:39:06 -07:00
Skye Wanderman-Milne
131d28ba0f Use default Python version on Cloud TPU CI 2023-05-31 18:04:39 +00:00
Yash Katariya
6d6ba70c78 Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8
PiperOrigin-RevId: 536693061
2023-05-31 06:21:01 -07:00
jax authors
758d68df13 Restore call_tf_concrete_function_list to previous state
In the following case of nested call:

```
inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
	return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)

tf_fn_too(inputs)  # FAIL
```

Without the fix, it fails with the following error:

```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
    _thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```

because we call `_restore_context` twice when executing `jax2tf.convert`ed functions,
the first time we call `_restore_context`, `call_tf_concrete_function_list` is set to `None`
instead of restoring it to the previous state, so the second time we call `_restore_context`,
`call_tf_concrete_function_list.clear()` throws the above error since `call_tf_concrete_function_list` is `None`.

PiperOrigin-RevId: 536650377
2023-05-31 02:23:14 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
jax authors
3ad756f7e0 Merge pull request #16176 from gnecula:poly_constraints
PiperOrigin-RevId: 536571493
2023-05-30 19:16:52 -07:00
George Necula
9ad8c3b9f1 [shape_poly] Add static constraint checking to the computation of dim vars
Previously we had one function `shape_poly.unify_avals_with_args` that was
solving the dimension variables and was also used for generating the code
to compute them. Now we separate the solving part, which is now using just
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).

We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.

For now we implement the static checking of the shape constraints, e.g., when
the dimension expressions are constant or TF EagerTensor. We do not yet
have compile-time checking of the constraints. This matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
2023-05-31 04:48:44 +03:00
Frederic Bastien
cdced240f5 WAR the bug in t5x dependency. It currently need the dev version of jestimator. 2023-05-30 12:55:21 -07:00
jax authors
acfeb9bb13 Merge pull request #16169 from ZacCranko:data_parallel_example
PiperOrigin-RevId: 536260245
2023-05-29 18:39:44 -07:00
Zac Cranko
a192b5e541 improve data parallel example
fix example

fix example

fix example

fix example

fix example

fix example
2023-05-30 01:25:17 +00:00
jax authors
ae9160a4e9 Merge pull request #16159 from jakevdp:deprecations
PiperOrigin-RevId: 536003451
2023-05-28 07:27:45 -07:00
Jake VanderPlas
7a87995ecd Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p 2023-05-28 07:15:34 -07:00
Sharad Vikram
1279418ce5 Link in CUDA runtime for triton in jaxlib
PiperOrigin-RevId: 535708416
2023-05-26 14:02:16 -07:00
Jieying Luo
cb3b7ec93a [PJRT PLUGIN] Add num_processes to distributed.global_state.
The number of processes is needed for multi-process GPU when plugin is used.

PiperOrigin-RevId: 535696950
2023-05-26 13:14:40 -07:00
Yash Katariya
d62bc0f795 Fix the jax2tf failure in mypy: https://github.com/google/jax/actions/runs/5094063162/jobs/9157426652?pr=16155
PiperOrigin-RevId: 535692853
2023-05-26 12:57:37 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
jax authors
25a9a978fb Merge pull request #16151 from hawkinsp:cudnn
PiperOrigin-RevId: 535642800
2023-05-26 09:48:40 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
jax authors
9508f3ad9d Merge pull request #16148 from gnecula:export_poly
PiperOrigin-RevId: 535628086
2023-05-26 08:44:41 -07:00
George Necula
46a258ba17 [shape_poly] Add partial support for call_exported with polymorphic shapes
Until now the jax_export.call_exported did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.

The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported that includes shape polymorphism.
2023-05-26 17:27:44 +02:00
Yash Katariya
2858df24ff Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX.
Currently, no function producing HloSharding is being used. I will do that in follow up CLs.

PiperOrigin-RevId: 535622806
2023-05-26 08:19:14 -07:00
Peter Hawkins
69cf67f252 Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. 2023-05-26 10:04:34 -04:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
jax authors
7833528765 Merge pull request #16143 from jakevdp:fix-shape-poly
PiperOrigin-RevId: 535427698
2023-05-25 16:31:09 -07:00
John QiangZhang
ed10293f9c Add new called_index to custom_call tf.backend_config DictAttr.
Here, `called_index` indicates the tf concrete function index in the `function_list` of the parent XLACallModule.

PiperOrigin-RevId: 535417558
2023-05-25 15:58:50 -07:00
Jake VanderPlas
bbae2edd12 jax2tf: correctly handle opaque dtype in jax2tf pure()
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().

PiperOrigin-RevId: 535408671
2023-05-25 15:32:47 -07:00
jax authors
8534f0bfc3 Merge pull request #16142 from froystig:outline-random-functions
PiperOrigin-RevId: 535406588
2023-05-25 15:25:18 -07:00
Jake VanderPlas
b853ce9967 jax2tf: make shape_poly_test pass with custom PRNG 2023-05-25 15:16:46 -07:00
Roy Frostig
3238b627a1 outline jitted jax.random functions
We may want to continue to inline these in Jaxpr somehow, but it's
useful to outline them in HLO for visualization and debugging.
2023-05-25 15:01:04 -07:00
jax authors
14089fb2f8 Merge pull request #16138 from hawkinsp:cudnn
PiperOrigin-RevId: 535367462
2023-05-25 13:40:34 -07:00
Peter Hawkins
2b7790290b Bump minimum CUDNN version in pip installation to 8.8.
There are known wrong output bugs observed in JAX for earlier versions, in particular related to RNNs.
2023-05-25 14:46:39 -04:00
Peter Hawkins
16368bc672 [XLA:Python] Clean up handling of unsupported types in buffer protocol.
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.

Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.

PiperOrigin-RevId: 535315975
2023-05-25 11:10:19 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Peter Hawkins
32026ad18b Disable random_test_with_custom_prng on CPU under msan.
This test flakily times out in CI.

PiperOrigin-RevId: 535293997
2023-05-25 10:10:01 -07:00
jax authors
24928a507b Merge pull request #16117 from jakevdp:matrix-transpose
PiperOrigin-RevId: 535292507
2023-05-25 10:02:26 -07:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code 2023-05-25 09:32:14 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Eugene Burmako
e25052c6f8 Use stablehlo.get_minimum_version in jax_export.py
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.

PiperOrigin-RevId: 535091927
2023-05-24 21:15:16 -07:00
Ce Zheng
8e397f7f08 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors.
PiperOrigin-RevId: 535079635
2023-05-24 20:26:28 -07:00
John QiangZhang
5e82d6b5d5 Fix jax2tf_test regression failure.
PiperOrigin-RevId: 535002015
2023-05-24 15:27:57 -07:00
Sharad Vikram
4fb834b351 Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
2023-05-24 14:06:45 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00