This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but instead go through `pjit_p.bind` directly, which calls into `_pjit_call_impl`.
PiperOrigin-RevId: 535630905
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
    cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```
PiperOrigin-RevId: 534926215
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
This way we don't pass a potentially large (Python builtin) int value to an int32 JAX computation parameter and get an error.
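For illustration, a minimal plain-Python sketch of the failure mode being avoided (not JAX's actual check; the names here are made up):

```python
# A Python builtin int is arbitrary-precision, but an int32 computation
# parameter can only represent values in [-2**31, 2**31 - 1].
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1

step_count = 2**40  # perfectly fine as a Python int

# Handing `step_count` directly to an int32 parameter would overflow,
# so its representability must be handled before it reaches the computation.
fits_in_int32 = INT32_MIN <= step_count <= INT32_MAX
print(fits_in_int32)  # False
```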
Fixes #15068
Co-authored-by: Matthew Johnson <mattjj@google.com>
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.
The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix#10818.
However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
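For illustration, a NumPy sketch of the dtype behavior (NumPy stands in for `lax.dot_general` here; the explicit cast mirrors the ConvertElementType that the lowering inserts on platforms that can't codegen the mixed-dtype case directly):

```python
import numpy as np

# Inputs with different dtypes, as now accepted by dot_general.
a = np.arange(6, dtype=np.int8).reshape(2, 3)
b = np.ones((3, 2), dtype=np.float32)

# Where the backend can't handle mixed input dtypes, the lowering
# converts the inputs (ConvertElementType) before the dot.
out = np.dot(a.astype(np.float32), b)
print(out.dtype)  # float32
```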
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path-based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (notably editable installs, which use a custom finder that does not implement `iter_modules()`).
A plugin entry point can be advertised in `setup.py` (or the equivalent `pyproject.toml`) with something like:
```
entry_points={
    "jax_plugins": [
        "openxla-cpu = jax_plugins.openxla_cpu",
    ],
}
```
The plugin is expected to call `jax._src.xla_bridge.register_plugin` with its plugin name, priority (defaulting to 400), the path to its .so file, and optional create options in its `initialize()` method.
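The entry-point side of the discovery can be sketched with the stdlib; `discover_jax_plugins` is a hypothetical helper name, not JAX's actual function:

```python
import importlib.metadata

def discover_jax_plugins():
    """Return a mapping of advertised plugin name -> entry point."""
    eps = importlib.metadata.entry_points()
    if hasattr(eps, "select"):  # Python 3.10+ selectable interface
        group = eps.select(group="jax_plugins")
    else:  # older Pythons return a dict keyed by group name
        group = eps.get("jax_plugins", [])
    return {ep.name: ep for ep in group}

# Each discovered entry point can then be loaded and initialized, e.g.
# discover_jax_plugins()["openxla-cpu"].load().initialize()
print(sorted(discover_jax_plugins()))  # names of plugins installed, if any
```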
The logic to register a plugin from environment variables is not deleted, to facilitate development via the environment.
PiperOrigin-RevId: 533280115
The plugins in the namespace package `jax_plugins` will be imported. The plugins need to (1) be placed in a root folder `jax_plugins` and follow other namespace package requirements, and (2) implement an initialize() method which appends `plugin_name:file_path` to env var `PJRT_NAMES_AND_LIBRARY_PATHS`.
Appending to `PJRT_NAMES_AND_LIBRARY_PATHS` is a short-term solution; what `initialize()` should ultimately do is still under discussion.
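The namespace-package side of the discovery can be sketched as follows (a hypothetical helper, assuming each plugin module exposes an `initialize()` hook as described above):

```python
import importlib
import pkgutil

def import_namespace_plugins():
    # Import every module found under the `jax_plugins` namespace package
    # and run its initialize() hook. Returns the imported modules.
    try:
        import jax_plugins  # namespace package; may not exist on this path
    except ImportError:
        return []
    loaded = []
    for mod_info in pkgutil.iter_modules(jax_plugins.__path__, "jax_plugins."):
        module = importlib.import_module(mod_info.name)
        if hasattr(module, "initialize"):
            module.initialize()
        loaded.append(module)
    return loaded
```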
PiperOrigin-RevId: 532897890
The semantics of eager wsc (`with_sharding_constraint`) are the same as within a jit, i.e. it will reshard to the given sharding only if the devices are the same and in the same order.
Eager wsc won't work as expected with AD transpose, because there is no `src` argument to reverse the shardings when transposing; we decided that this is fine for now. `jax.device_put` should be the API to use for that case.
PiperOrigin-RevId: 532858670