This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.
Prior to this change, there was conditional logic to handle the case where
importlib_metadata wasn't installed at all. However, it did not handle
the case where importlib_metadata is installed but at too low a version
to provide the Python 3.10-compatible API. This change removes that
logic and instead just ensures the right version is installed.
All of this logic can be removed if/when jax requires Python version
>= 3.10.
This also removes an unnecessary `requests` dep for the [tpu] install.
In the following case of a nested call:
```
import numpy as np
import jax
from jax.experimental import jax2tf

inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
  return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
# TF -> JAX
call_tf_fn = jax2tf.call_tf(tf_fn)
# JAX -> TF again (the nested conversion)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)
tf_fn_too(inputs)  # FAIL
```
Without the fix, it fails with the following error:
```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
_thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```
This happens because we call `_restore_context` twice when executing `jax2tf.convert`ed functions.
The first call sets `call_tf_concrete_function_list` to `None` instead of restoring it to its
previous state, so the second call raises the above error when it invokes
`call_tf_concrete_function_list.clear()` on `None`.
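An illustrative sketch (hypothetical names, not the actual jax2tf internals) of the intended save/restore pattern: the previous value of the thread-local list is restored on exit instead of being clobbered with `None`, so a nested restore never sees `None`.
```
import contextlib
import threading

_state = threading.local()

@contextlib.contextmanager
def _function_list_scope():
    # Save whatever value was there before (possibly None at the outermost level).
    prev = getattr(_state, "call_tf_concrete_function_list", None)
    _state.call_tf_concrete_function_list = []
    try:
        yield _state.call_tf_concrete_function_list
    finally:
        # Restore the previous value rather than resetting to None, so a second,
        # nested restore can still call methods like .clear() safely.
        _state.call_tf_concrete_function_list = prev
```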
PiperOrigin-RevId: 536650377
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl, so standardize it via a helper function.
In the test, check for re-compilation, which is what that test was doing before cl/535630905.
PiperOrigin-RevId: 536575987
Previously we had one function, `shape_poly.unify_avals_with_args`, that both
solved for the dimension variables and generated the code to compute them.
Now we separate the solving part, which uses only symbolic expressions
(`shape_poly.solve_dim_vars`), from the code generator for the dimension
variables (`shape_poly.compute_dim_vars_from_arg_shapes`).
We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.
For now we implement only static checking of the shape constraints, e.g., when
the dimension expressions are constants or TF EagerTensors. We do not yet
have compile-time checking of the constraints; this matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
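A toy sketch of the solve-then-check split described above, using hypothetical helper names rather than the real `shape_poly` API: solving finds values for the dimension variables from the actual argument shapes, and the leftover equalities become shape constraints that can be checked statically when the actual dimensions are constants.
```
def toy_solve_and_check(actual_shape):
    # Declared (polymorphic) shape: (b, b + 4), with dimension variable `b`.
    b = actual_shape[0]                        # solving step: b := actual dim 0
    constraints = [("b + 4 == dim 1", b + 4, actual_shape[1])]
    # Static checking is possible here because the actual dims are constants.
    for name, expected, actual in constraints:
        if expected != actual:
            raise ValueError(f"shape constraint failed: {name}")
    return {"b": b}

print(toy_solve_and_check((3, 7)))   # {'b': 3}
```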
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but rather go via `pjit_p.bind` directly, which calls into `_pjit_call_impl`.
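A brief sketch of the code path in question, assuming (as is the case in current JAX) that tracing a jitted function yields a jaxpr with an embedded pjit equation:
```
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 1.0

closed_jaxpr = jax.make_jaxpr(jax.jit(f))(1.0)
# Evaluating the jaxpr directly bypasses the top-level jit dispatch; the
# embedded pjit equation is executed via pjit_p.bind -> _pjit_call_impl.
outs = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, 2.0)
print(outs)
```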
PiperOrigin-RevId: 535630905
Until now `jax_export.call_exported` did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.
The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported functions that include shape polymorphism.
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().
PiperOrigin-RevId: 535408671
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.
PiperOrigin-RevId: 535091927
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```
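A hedged usage sketch (not the actual test that produced the StableHLO above) of the kind of program that generates such ordered calls; the keyword names `ordered` and `call_tf_graph` are as described above, while the rest of the details are assumptions:
```
import jax
import tensorflow as tf
from jax.experimental import jax2tf

@tf.function
def tf_log(x):
    tf.print("step value:", x)   # side-effecting TF computation, no result

ordered_log = jax2tf.call_tf(tf_log, call_tf_graph=True, ordered=True)

def f_jax(x):
    def body(i, v):
        ordered_log(v)           # ordered side effect inside the loop
        return v + 1.0
    return jax.lax.fori_loop(0, 4, body, x)

f_tf = jax2tf.convert(f_jax, native_serialization=True)
```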
PiperOrigin-RevId: 534926215
Because this is best effort, users writing code to handle GSPMDSharding
should be able to deal only with the GSPMDSharding type.
PiperOrigin-RevId: 534612265
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.
Fixes #15068
Co-authored-by: Matthew Johnson <mattjj@google.com>
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.
The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.
However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
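A small sketch of the relaxed behavior, assuming mixed input dtypes are accepted as described above:
```
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3), dtype=jnp.bfloat16)
rhs = jnp.ones((3, 4), dtype=jnp.float32)       # different input dtype
out = lax.dot_general(
    lhs, rhs,
    dimension_numbers=(((1,), (0,)), ((), ())),  # contract lhs dim 1 with rhs dim 0
    preferred_element_type=jnp.float32)          # request the output/accumulation dtype
print(out.shape, out.dtype)                      # (2, 4) float32
```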
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()).
A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:
```
entry_points={
    "jax_plugins": [
        "openxla-cpu = jax_plugins.openxla_cpu",
    ],
}
```
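A hedged sketch of the two discovery mechanisms combined (illustrative only, not the exact jax implementation): scan the `jax_plugins` namespace package for modules on the path, then load any advertised `jax_plugins` entry points, which also covers the editable-install case mentioned above.
```
import importlib
import pkgutil
import importlib_metadata

def discover_plugin_modules():
    modules = set()
    # Option 2: namespace/path-based iteration over jax_plugins.* submodules.
    try:
        import jax_plugins
    except ImportError:
        jax_plugins = None
    if jax_plugins is not None:
        for _, name, _ in pkgutil.iter_modules(jax_plugins.__path__, "jax_plugins."):
            modules.add(importlib.import_module(name))
    # Option 3: advertised entry points (work even when iter_modules() cannot
    # see the install, e.g. editable installs with a custom finder).
    for entry_point in importlib_metadata.entry_points(group="jax_plugins"):
        modules.add(entry_point.load())
    return modules
```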