Previously, we kept the `dim_vars` in the `mlir.ModuleContext`. Now we
replace that with a mutable `ShapePolyLoweringState` that also tracks
whether we encounter shape polymorphism anywhere in the lowering.
For this purpose, we also add `shape_poly_state` to `lowering.compile_args`.
We need to keep track of whether a module contains dimension variables
because such modules need shape refinement before they can be converted
to MHLO and compiled. For now, we just test that we set
`Exported.module_uses_dim_vars` correctly.
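A minimal sketch of the state object, with attribute names assumed from this description (the real class lives in JAX's lowering internals and may differ):
```python
# Hypothetical sketch of the mutable lowering state; only dim_vars and
# the uses-tracking flag are named in this change.
class ShapePolyLoweringState:
  def __init__(self, dim_vars=()):
    # Dimension variable names, in the order they appear in the module.
    self.dim_vars = tuple(dim_vars)
    # Flipped to True the first time shape polymorphism is encountered
    # during lowering; later surfaced as Exported.module_uses_dim_vars
    # to decide whether shape refinement is needed before compilation.
    self.uses_dim_vars = False
```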
There was an inconsistency between how the global cache was used at the top level and in `pjit_call_impl`, so we standardize it via a helper function. This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but rather go via `pjit_p.bind` directly, which calls into `_pjit_call_impl`.
In the test, we check for re-compilation, which is what that test was doing before cl/535630905.
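A hypothetical sketch of the helper-function pattern (the name and signature are illustrative, not JAX's actual internals):
```python
# Both the top-level jit path and _pjit_call_impl funnel through one
# helper, so the global cache is keyed and populated the same way in
# both code paths.
_compilation_cache = {}

def _cached_compilation(key, compile_fn):
  executable = _compilation_cache.get(key)
  if executable is None:
    executable = compile_fn()       # compile once per distinct key
    _compilation_cache[key] = executable
  return executable
```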
PiperOrigin-RevId: 536575987
Previously we had one function, `shape_poly.unify_avals_with_args`, that
both solved for the dimension variables and generated the code to
compute them. Now we separate the solving part, which uses only
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).
We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.
For now we implement only static checking of the shape constraints, e.g., when
the dimension expressions are constants or TF EagerTensors. We do not yet
have compile-time checking of the constraints; this matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
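A self-contained toy illustration of the split, for a single argument declared as `f32[b, 2*b]` (hypothetical names and data layout; not the real `shape_poly` API):
```python
# Solving is purely symbolic: it yields expressions for the dimension
# variables in terms of argument shapes, plus the constraints under
# which that solution is valid.
def solve_dim_vars(declared_shape):
  solution = {"b": ("arg0", 0)}  # b is read from arg0.shape[0]
  # The solution is valid only if the remaining declared dims agree:
  constraints = ["arg0.shape[1] == 2 * arg0.shape[0]"]
  return solution, constraints

# The separate code-generation step reads the actual shapes at run time
# (in JAX this emits ops; here we just index into a concrete shape).
def compute_dim_vars_from_arg_shapes(solution, actual_shape):
  return {var: actual_shape[axis] for var, (_, axis) in solution.items()}

solution, constraints = solve_dim_vars(("b", "2*b"))
print(compute_dim_vars_from_arg_shapes(solution, (3, 6)))  # {'b': 3}
```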
PiperOrigin-RevId: 535630905
Until now, `jax_export.call_exported` did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.
The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it directly.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
to invoke exported functions that include shape polymorphism.
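A hypothetical usage sketch (the `jax_export` helpers and module path are assumed from this description; exact signatures may differ):
```python
import jax.numpy as jnp
from jax.experimental.jax2tf import jax_export  # module path assumed

def f(x):  # f32[b] -> f32[b], for any batch size b
  return jnp.sin(x)

# Export f with a polymorphic first dimension "b".
exp = jax_export.export(f)(jax_export.poly_spec((4,), jnp.float32, "b"))

# call_exported resolves b from the call-site shape (the call site may
# itself use dimension variables) and computes the output shape.
def g(x):
  return jax_export.call_exported(exp)(x) * 2.0
```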
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().
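For illustration, using the typed PRNG keys of recent JAX versions as an example of an opaque dtype whose abstract and physical shapes differ:
```python
import jax

key = jax.random.key(0)  # requires a recent JAX with typed key arrays
print(key.shape)                       # ()   -- the aval's shape
print(jax.random.key_data(key).shape)  # (2,) -- the physical shape
```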
PiperOrigin-RevId: 535408671
Rather than enumerating a list of types that don't work with the buffer protocol, call the format descriptor function and fail if that call fails.
Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.
PiperOrigin-RevId: 535315975
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
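A brief usage sketch (CPU arrays only; the buffer protocol does not apply to accelerator memory):
```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)  # a jax.Array on CPU
view = memoryview(x)                  # zero-copy, via the buffer protocol
print(np.frombuffer(view, dtype=np.float32))  # [0. 1. 2. 3.]
```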
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has recently been replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.
PiperOrigin-RevId: 535091927
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
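A hypothetical usage sketch mirroring the behavior described above:
```python
import tensorflow as tf
from jax.experimental import jax2tf

def print_x(x):
  tf.print("x =", x)  # a TF side effect whose ordering must be preserved
  return x

def f(x):
  # ordered=True threads a token through the call; combined with
  # call_tf_graph=True this lowers to the custom call shown below.
  return jax2tf.call_tf(print_x, ordered=True, call_tf_graph=True)(x)
```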
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
    cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```
PiperOrigin-RevId: 534926215
Referring to them as simply `map` or `zip` will create recursive
reimplementations (with no base case!) if the cell is reevaluated in
the same runtime.
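A hypothetical notebook cell showing the failure mode:
```python
def safe_map(f, *args):
  # On the first evaluation, `map` below is the Python builtin.
  return list(map(f, *args))

map = safe_map  # rebinds the global name

# Re-evaluating this cell redefines safe_map, whose body now looks up
# the global `map` -- which points at the previous safe_map. Every call
# then recurses with no base case until RecursionError.
```
Keeping the `safe_` prefixes avoids the shadowing entirely.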