This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.
Prior to this change, there was conditional logic to handle if
importlib_metadata wasn't installed at all. However, it doesn't handle
the case where importlib_metadata is installed by not high enough
version to support Python 3.10 compat. This change gets rid of that
logic and just ensures the right version is installed.
All of this logic can be removed if/when jax requires Python version
>= 3.10
This also removes an unnecessary `requests` dep for the [tpu] install.
In the following case of nested call:
```
inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)
@jax.jit
def forward(x):
return x + 1
# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)
tf_fn_too(inputs) # FAIL
```
Without the fix, it fails with the following error:
```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
_thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```
because we call `_restore_context` twice when executing `jax2tf.convert`ed functions,
the first time we call `_restore_context`, `call_tf_concrete_function_list` is set to `None`
instead of restoring it to the previous state, so the second time we call `_restore_context`,
`call_tf_concrete_function_list.clear()` throws the above error since `call_tf_concrete_function_list` is `None`.
PiperOrigin-RevId: 536650377
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.
In the test, check for re-compilation which is what that test was doing before cl/535630905
PiperOrigin-RevId: 536575987
Previously we had one function `shape_poly.unify_avals_with_args` that was
solving the dimension variables and was also used for generating the code
to compute them. Now we separate the solving part, which is now using just
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).
We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.
For now we implement the static checking of the shape constraints, e.g., when
the dimension expressions are constant or TF EagerTensor. We do not yet
have compile-time checking of the constraints. This matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.
PiperOrigin-RevId: 535630905
Until now the jax_export.call_exported did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.
The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported that includes shape polymorphism.
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().
PiperOrigin-RevId: 535408671
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.
Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.
PiperOrigin-RevId: 535315975
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.
PiperOrigin-RevId: 535091927