At the moment, a StableHLO module lowered by jax2tf
with polymorphic shapes can be run only through jax2tf, because only the tf.XlaCallModule op has the
shape refinement logic needed to legalize
a StableHLO module with dynamic shapes to MHLO. Here we
expose that shape refinement MLIR transformation to JAX Python.
For now this is used only in a test in jax_export_test.py.
PiperOrigin-RevId: 537485288
Fixes the docstring of `jax.scipy.special.gamma`, which was mistakenly wrapping the docstring of `scipy.special.gammaln`. Also adds a note that the function currently accepts only real inputs.
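A quick check of the distinction (both functions accept only real inputs at the moment):
```
import jax.numpy as jnp
from jax.scipy.special import gamma, gammaln

# gamma is the gamma function itself, not its log:
# gamma(5) = 4! = 24, while gammaln(5) = log(24).
x = jnp.float32(5.0)
print(gamma(x))             # ~24.0
print(jnp.exp(gammaln(x)))  # ~24.0
```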
A `tf.call_tf_function` op arises from `jax2tf.call_tf(tf_fun, call_tf_graph)`. However, a function that contains this op can be lowered and executed only with `jax2tf.convert`, and it must be serialized as a `tf.Graph` because the serialization includes a tf.function as well.
To support this, we add some code to back_compat_test.py to serialize the code and run it as a tf.Graph.
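A minimal sketch of the round trip described above; the `call_tf_graph=True` spelling, the input signature, and the save path are illustrative assumptions:
```
import tensorflow as tf
from jax.experimental import jax2tf

def tf_fun(x):
    return tf.math.sin(x)

# The wrapped TF function shows up as a tf.call_tf_function op in the lowering.
jax_fn = jax2tf.call_tf(tf_fun, call_tf_graph=True)

# A function containing that op can only be lowered/executed via jax2tf.convert ...
converted = jax2tf.convert(jax_fn, native_serialization=True)

# ... and must be serialized as a tf.Graph (e.g. inside a SavedModel),
# because the serialization has to carry the wrapped tf.function as well.
module = tf.Module()
module.f = tf.function(
    converted, autograph=False,
    input_signature=[tf.TensorSpec([3], tf.float32)])
tf.saved_model.save(module, "/tmp/call_tf_graph_saved_model")
```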
PiperOrigin-RevId: 537062963
This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.
Prior to this change, there was conditional logic to handle the case where
importlib_metadata wasn't installed at all. However, it did not handle
the case where importlib_metadata is installed but at a version too old
to provide Python 3.10-compatible behavior. This change gets rid of that
logic and simply ensures that the right version is installed.
All of this logic can be removed if/when jax requires Python version
>= 3.10
This also removes an unnecessary `requests` dep for the [tpu] install.
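For reference, the Python-3.10-style discovery of PJRT plugins via entry points looks roughly like this (the `jax_plugins` group name is an assumption):
```
# Requires a recent importlib_metadata (the "group" keyword is the
# Python 3.10-compatible selectable-entry-points interface).
import importlib_metadata

for entry_point in importlib_metadata.entry_points(group="jax_plugins"):
    # Each entry point names a PJRT plugin that JAX can load.
    print(entry_point.name, entry_point.value)
```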
In the following case of nested calls:
```
import jax
import numpy as np
from jax.experimental import jax2tf

inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
    return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
# TF -> JAX
call_tf_fn = jax2tf.call_tf(tf_fn)
# JAX -> TF again (nested)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)
tf_fn_too(inputs)  # FAIL
```
Without the fix, it fails with the following error:
```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
_thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```
This happens because `_restore_context` is called twice when executing `jax2tf.convert`ed functions.
The first call sets `call_tf_concrete_function_list` to `None`
instead of restoring it to its previous state, so on the second call
`call_tf_concrete_function_list.clear()` raises the above error, since `call_tf_concrete_function_list` is `None`.
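A sketch of the fix pattern (not the actual jax2tf code): the restore step should put back the previous value rather than resetting it to `None`:
```
import contextlib
import threading

_thread_local_state = threading.local()

@contextlib.contextmanager
def _with_concrete_function_list(new_list):
    # Save the previous value so nested jax2tf.convert calls restore it
    # instead of clobbering it with None.
    prev = getattr(_thread_local_state, "call_tf_concrete_function_list", None)
    _thread_local_state.call_tf_concrete_function_list = new_list
    try:
        yield
    finally:
        _thread_local_state.call_tf_concrete_function_list = prev
```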
PiperOrigin-RevId: 536650377
Previously, we kept the `dim_vars` in the `mlir.ModuleContext`. Now we
replace that with a mutable `ShapePolyLoweringState` that also tracks
whether we encounter shape polymorphism anywhere in the lowering.
For this purpose, we also add `shape_poly_state` to the lowering.compile_args.
We need to keep track of whether a module contains dimension variables
because such modules need shape refinement before they can be converted
to MHLO and compiled. For now, we just test that we set
`Exported.module_uses_dim_vars` correctly.
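For context, a lowering picks up dimension variables whenever a polymorphic shape spec is used, e.g.:
```
import numpy as np
from jax.experimental import jax2tf

def f(x):
    return x + 1.0

# The symbolic batch dimension "b" becomes a dimension variable in the
# lowered StableHLO module, so the module needs shape refinement before
# it can be converted to MHLO and compiled.
tf_fn = jax2tf.convert(f, polymorphic_shapes=["(b, 4)"], native_serialization=True)
```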
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl, so standardize it via a helper function.
In the test, check for re-compilation, which is what that test was doing before cl/535630905.
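One simple way to observe re-compilation in a test (an illustration, not necessarily the mechanism that test uses) is to count traces:
```
import jax

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # incremented only when f is (re-)traced, not on cache hits
    return x + 1

f(1.0)
f(2.0)  # same shape/dtype: served from the cache, no re-trace
assert trace_count == 1
```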
PiperOrigin-RevId: 536575987
Previously we had a single function, `shape_poly.unify_avals_with_args`, that both
solved for the dimension variables and generated the code
to compute them. Now we separate the solving part, which works purely on
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).
We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.
For now we implement static checking of the shape constraints, e.g., when
the dimension expressions are constants or TF EagerTensors. We do not yet
have compile-time checking of the constraints; this matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
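A plain-Python illustration of the split (this is not the `shape_poly` API): solving recovers the dimension variables from concrete argument shapes, and the solution is valid only under the accompanying shape constraints.
```
# Symbolic argument shape ("b", "b + 4") against concrete shape (3, 7):
# solving picks b from the first dimension (the solve_dim_vars step),
# and the second dimension yields the constraint b + 4 == 7, which here
# can be checked statically because everything is a constant.
concrete_shape = (3, 7)
solution = {"b": concrete_shape[0]}
assert solution["b"] + 4 == concrete_shape[1]
```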
* Mention that Tensorboard profiling supports device memory usage
* Recommend TB profiling instead of the pprof-based device memory profiling
* Minor updates to GCP instructions
Inspired by https://github.com/google/jax/issues/1491
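For reference, collecting a trace that TensorBoard can display (including the device memory usage view) looks like this; the log directory is illustrative:
```
import jax
import jax.numpy as jnp

jax.profiler.start_trace("/tmp/jax-trace")  # view with: tensorboard --logdir=/tmp/jax-trace
x = jnp.ones((1024, 1024))
(x @ x).block_until_ready()
jax.profiler.stop_trace()
```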