This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.
If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
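As a rough illustration of moving class-level setup into `setUp()`, the sketch below uses only the standard `unittest` module. The test class and its `table` attribute are hypothetical and are not taken from the JAX test suite.

```python
import unittest


class SquareTableTest(unittest.TestCase):
  """Hypothetical test case; all names here are illustrative only."""

  def setUp(self):
    # Work that might previously have lived in setUpClass() is done per test,
    # so each test is self-contained regardless of execution order or thread.
    super().setUp()
    self.table = {n: n * n for n in range(5)}

  def test_lookup(self):
    self.assertEqual(self.table[3], 9)

  def test_size(self):
    self.assertEqual(len(self.table), 5)


def run_suite() -> unittest.TestResult:
  """Loads and runs the two tests above, returning the result object."""
  suite = unittest.defaultTestLoader.loadTestsFromTestCase(SquareTableTest)
  return unittest.TextTestRunner(verbosity=0).run(suite)
```

Because no state is shared through the class, the two tests can run in either order, or concurrently, without depending on a class-level hook having run first.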
PiperOrigin-RevId: 713296722
With jax.experimental.export gone we can now do some cleanup in the export module.
In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.
PiperOrigin-RevId: 692398132
Using jax2tf with `native_serialization=False` or with `enable_xla=False` has been deprecated since July 2024.
This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.
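A minimal sketch of what turning a deprecated option into a hard error can look like. The function `convert_sketch`, the `ValueError` type, and the messages are assumptions made for illustration; this is not the jax2tf implementation.

```python
def convert_sketch(fun, *, native_serialization=True, enable_xla=True):
  """Hypothetical stand-in for a converter entry point (not jax2tf code)."""
  # The exception type and wording are assumptions for this sketch.
  if not native_serialization:
    raise ValueError("native_serialization=False is no longer supported.")
  if not enable_xla:
    raise ValueError("enable_xla=False is no longer supported.")
  # A real converter would return a converted function; the sketch returns
  # the input unchanged.
  return fun
```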
PiperOrigin-RevId: 689708392
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).
The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowed in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
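The message above does not name the environment variable, so the following sketch configures the filters in-process with the standard `warnings` module instead. It shows the general pattern of escalating warnings to errors while suppressing one known warning. The function `noisy` and its message are hypothetical.

```python
import warnings


def noisy() -> int:
  # Hypothetical function that emits a warning a test run should catch.
  warnings.warn("old API used", DeprecationWarning)
  return 1


def run_strict(fn):
  """Runs fn with every warning escalated to an exception."""
  with warnings.catch_warnings():
    warnings.simplefilter("error")
    return fn()


def run_with_suppression(fn):
  """Like run_strict, but one known warning is explicitly ignored."""
  with warnings.catch_warnings():
    warnings.simplefilter("error")
    # Filters added later are consulted first, so this "ignore" entry takes
    # precedence over the blanket "error" filter above.
    warnings.filterwarnings("ignore", message="old API used",
                            category=DeprecationWarning)
    return fn()
```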
PiperOrigin-RevId: 678352286
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" to "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_platform`.
The functionality comes from the jax.experimental.export
module, which will be deprecated.
The following APIs are introduced:
```
import jax
from jax import export
def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)
blob: bytearray = ex.serialize()
rehydrated: export.Exported = export.deserialize(blob)
def caller(...):
  ... rehydrated.call(*args, **kwargs)
```
Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.
Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:
* Instead of `jax.experimental.export.call(exp)` we now write
`exp.call`
* The `jax.experimental.export.export` allowed the function
argument to be any Python callable and it would wrap it with
a `jax.jit`. This is not supported anymore by export, and instead
the user must use `jax.jit`.
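The two differences listed above can be seen in a small round trip. This is a sketch that assumes a JAX version providing the `jax.export` package; the function `f` and the input `x` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
  return jnp.sin(x) * 2.0


x = jnp.arange(4, dtype=jnp.float32)

# The function is wrapped in jax.jit explicitly before being exported.
exp = export.export(jax.jit(f))(x)

# Round-trip through serialization, then invoke through the `call` method
# rather than through a module-level `call` function.
blob = exp.serialize()
rehydrated = export.deserialize(blob)
result = rehydrated.call(x)
```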
GetDefaultLayout now has a fallback for the GPU backend, so it is no longer blocked by the fact that the PJRT C API does not support GetDefaultLayout yet.
PiperOrigin-RevId: 632555239
- The root cause of the bug is that dtype lookups are incorrect because hashes behave differently between dtype instances and their types. Added comments to `jax.dlpack.SUPPORTED_DTYPES` about this.
- Added unit test coverage.
- Fixing this bug revealed a limitation: a "host-to-device" copy occurs in the following two situations (see the unit test comments for details):
  - When the dtype is 'int32'.
  - When using the PJRT C API runtime.
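The hash mismatch described in the first item can be reproduced in plain Python without any array library. `FakeDType` below is a hypothetical stand-in for a dtype instance that compares equal to its scalar type but hashes differently; it is not how any real dtype class is implemented.

```python
class FakeDType:
  """Hypothetical dtype-like object that compares equal to a scalar type."""

  def __init__(self, scalar_type):
    self.scalar_type = scalar_type

  def __eq__(self, other):
    if isinstance(other, FakeDType):
      return self.scalar_type is other.scalar_type
    return self.scalar_type is other

  def __hash__(self):
    # Deliberately unrelated to hash(self.scalar_type).
    return hash(("FakeDType", self.scalar_type.__name__))


supported = frozenset({int, float})
d = FakeDType(int)

equal_to_type = (d == int)                     # equality holds
found_in_set = d in supported                  # set lookup compares hashes first
normalized_found = d.scalar_type in supported  # normalizing the key fixes it
```

Set and dict lookups only fall back to `==` when the hashes agree, so an object that is "equal" to a member can still be reported as absent. Normalizing keys to one representation before the lookup avoids the problem.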
PiperOrigin-RevId: 610799558
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)(*args)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```
This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.
In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.
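One way to keep the old `export.export(fun)` spelling working after `export` becomes a function is to attach attributes to the function object itself. The sketch below is a simplified illustration with a hypothetical `export` body; it is not the code from this change.

```python
def export(fun):
  """Hypothetical new-style entry point, called as `export(fun)`."""
  def exported(*args):
    # Placeholder result; a real implementation would lower and package fun.
    return ("exported", fun.__name__, args)
  return exported


# Backward-compatibility shim: callers written against the old layout did
# `from ... import export` and then called `export.export(fun)`. Attaching
# the function to itself as an attribute keeps that spelling working.
export.export = export
```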
PiperOrigin-RevId: 596563481
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)(*args)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialize(ser)
export.call(exp1)
```
This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
call_tf has per-platform lowering because the lowering
of the called TF function may depend on the platform. When
doing multi-platform lowering this means that we lower
call_tf several times and wrap the lowerings with a
conditional. This results in an assertion failure
in add_to_call_tf_concrete_function_list, because we
are attempting to add the same function multiple times.
Here we remove the assertion (afaik, it is Ok to add
multiple functions with the same name, because all
we care about is the index of the called function in
the list). We also reuse the existing function if
we are adding an identical one.
We add tests for call_tf with multi-platform lowering.
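The reuse behavior described above can be sketched as follows. `add_to_function_list` is a hypothetical, simplified stand-in for `add_to_call_tf_concrete_function_list`, and treating "identical" as object identity is an assumption of this sketch.

```python
def add_to_function_list(fn, fn_list: list) -> int:
  """Returns the index of fn in fn_list, appending it only if absent."""
  for i, existing in enumerate(fn_list):
    if existing is fn:
      # Reuse the existing entry instead of asserting that it is new.
      return i
  fn_list.append(fn)
  return len(fn_list) - 1
```

With this shape, lowering the same called function once per platform yields the same index each time, which is all the caller needs.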
In presence of ordered effects JAX lowering produces a main
function that takes token
inputs and returns token outputs. Previously, when exporting
such a module, we would wrap the main function with a function
that does not use tokens on inputs and outputs. With this
change we actually leave the token inputs and outputs and
rely on consumers of the exported function to know how to
invoke a function with tokens.
Due to the fact that PJRT does not support passing tokens
as input and output to the top-level function, JAX native
lowering uses dummy bool[0] arrays in lieu of tokens for
the top-level function, and uses stablehlo tokens for the
inner functions. When we export a function for serialization
we want to use stablehlo tokens even at top-level, to enable
calling that function from a larger JAX computation later.
See more details about the calling convention in the
docstring for `export.export`.
We also fix and test multi-platform lowering in presence
of effects.
This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
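A minimal sketch of the more flexible tag matching that this design allows. The helper `device_matches_sketch`, its explicit `device_type` argument, and the alias table are hypothetical; the real `test_device_matches()` may differ in signature and behavior.

```python
# Hypothetical alias table: a generic tag expands to concrete device types.
_TAG_ALIASES = {"gpu": ("cuda", "rocm")}


def device_matches_sketch(device_type: str, tags) -> bool:
  """Returns True if device_type matches any tag, directly or via an alias."""
  if isinstance(tags, str):
    tags = [tags]
  for tag in tags:
    if device_type == tag or device_type in _TAG_ALIASES.get(tag, ()):
      return True
  return False
```

An equality test against the current device name cannot express this kind of one-to-many match, which is why the matching is routed through a single predicate.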
PiperOrigin-RevId: 568923117
We must ensure that we call jax2tf.convert recursively to ensure
that the proper tf.custom_gradient is used. This means that we can
reuse the conversion of the VJP function between native and graph
serialization.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (where enabling higher-precision matrix multiplication is usually the right choice if precision is needed) and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
--
b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:
[shape_poly] Fix lowering when we have both dimension variables and tokens
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:
* the mechanism did not differentiate between different safety checks.
E.g., in order to disable checking of the custom call targets, one
had to disable checking for all custom call targets, and also the
checking that the serialization and execution platforms are the same.
* the mechanism operated only at serialization time. Now, the
XlaCallModule supports a `disabled_checks` attribute to control
which safety checks should be disabled.
Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
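To illustrate why a sequence of descriptors is more expressive than a single boolean, here is a small sketch. The `DisabledCheck` class and the helper functions are hypothetical and only model the idea; they are not the jax2tf types.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class DisabledCheck:
  """Hypothetical descriptor naming one safety check to skip."""
  kind: str         # e.g. "platform" or "custom_call"
  target: str = ""  # only meaningful when kind == "custom_call"


def custom_call_allowed(target: str, allowlist: Sequence[str],
                        disabled: Sequence[DisabledCheck]) -> bool:
  """A custom call passes if it is allowlisted or its check is disabled."""
  return (target in allowlist or
          DisabledCheck("custom_call", target) in disabled)


def platform_check_enabled(disabled: Sequence[DisabledCheck]) -> bool:
  """The platform check stays on unless it is disabled explicitly."""
  return DisabledCheck("platform") not in disabled
```

Disabling the check for one custom call target leaves the checks for other targets, and the platform check, in force.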
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.
With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<> : tensor<0xi1>
%1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
return %1#1 : tensor<f32>
}
func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
%0 = stablehlo.create_token : !stablehlo.token
%1 = stablehlo.constant dense<0> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
cond {
%4 = stablehlo.constant dense<4> : tensor<i32>
%5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %5 : tensor<i1>
} do {
%4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
%5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
%7 = stablehlo.constant dense<1> : tensor<i32>
%8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
}
%3 = stablehlo.constant dense<> : tensor<0xi1>
return %3, %2#2 : tensor<0xi1>, tensor<f32>
}
}
```
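The wrapper structure in the StableHLO above can be mirrored in plain Python. This is only an analogy under simplifying assumptions: the token is modeled as an opaque placeholder, and the ordered TF call is modeled as appending to a list.

```python
DUMMY_TOKEN = ()  # stand-in for the empty tensor<0xi1> placeholder


def _wrapped_main(token, x, effects):
  """Inner function: takes and returns a token alongside the real value."""
  for _ in range(4):   # mirrors the 4-iteration while loop
    effects.append(x)  # stand-in for the ordered tf.call_tf_function
    x = x + 1.0
  return token, x


def main(x, effects):
  """Public entry point: supplies a constant token and drops the result."""
  _, result = _wrapped_main(DUMMY_TOKEN, x, effects)
  return result
```

Callers of `main` never see the token, so the public calling convention is unchanged by the presence of ordered effects.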
PiperOrigin-RevId: 534926215
The main cleanup is around _code_generator_and_avals, which in
an earlier version of the code was used for both abstract values
and for code generation. That is why it was cached, and why it
returned a code generator and abstract values. A while
ago we did a first round of cleaning to not use it for abstract
values. Now we can actually eliminate the function and inline
it directly.
A second improvement is to add the explicit error message from
TF compilation, instead of just the generic message that
call_tf cannot be used with non-compilable functions.