103 Commits

Author SHA1 Message Date
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
George Necula
9088adda68 [jax2tf] Disable jax2tf with non-native serialization.
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
George Necula
e5bbf3dca1 [jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).

The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowd in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
2024-10-24 17:41:32 +03:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Zhuo Peng
ad74e55dbc Support None leaves in arguments to gradient of a call_tf wrapped function.
PiperOrigin-RevId: 662115139
2024-08-12 09:24:25 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Yash Katariya
d3bfd32667 Remove jax.xla_computation tests from jax2tf. api_test.py has enough coverage for jax.xla_computation
PiperOrigin-RevId: 644605636
2024-06-18 21:01:52 -07:00
Yash Katariya
175183775b Replace jax.xla_computation with the AOT API and add a way to unaccelerate the deprecation in jax tests.
PiperOrigin-RevId: 644535402
2024-06-18 15:47:24 -07:00
George Necula
7c3a4db3e4 [export] Rename some API entry points
We take the opportunity of a new jax.export package to rename some
of the API entry points:

  * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
    because this is more accurate. The dimension variables are global
    constants, but so is the platform index. And we need to run
    global constant propagation and shape refinement for all of these.
  * We rename "serialization version" with "calling convention version".
    Hence we now have `Exported.calling_convention_version`,
    and the configuration flag is renamed from `--jax-serialization-version`
    to `--jax-export-calling-convention-version`. Also,
    `jax.export.minimum_supported_serialization_version` is now
    `jax.export.minimum_supported_calling_convention_version`.
   * We rename `lowering_platforms` to `platforms` both as a field
    of `Exported` and as the kwarg to `export.export`.
   * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
Jake VanderPlas
9a080f4b83 Test: use context manager to set jax_serialization_version 2024-06-04 16:08:16 -07:00
Jake VanderPlas
ca784a09a3 [jax2tf] test: fix jax serialization version tests 2024-06-04 11:03:06 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Jieying Luo
0a3e432745 [PJRT C API] Enable PJRT C API runtime in jax2tf dlpack.
GetDefaultLayout added a fallback for GPU backend so it is no longer blocked by the fact that PJRT C API does not support GetDefaultLayout yet.

PiperOrigin-RevId: 632555239
2024-05-10 11:30:37 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Seunghoon Park
e00149c39f Fix unnecessary memory copies between GPU and CPU when jax2tf.call_tf() is used.
- The root cause of the bug is that dtype lookups are incorrect because hashes behave differently between dtype instances and their types. Added comments to `jax.dlpack.SUPPORTED_DTYPES` about this.
- Added unit test coverage.
- Fixing this bug revealed a limitation of causing "host-to-device" copy in the following two situations. See the details in the unit test comments.:
  - When the dtype is 'int32'.
  - When using PJRT C API runtime.

PiperOrigin-RevId: 610799558
2024-02-27 10:35:50 -08:00
Dateng Lin
199591f135 Used platform_name for call_tf.
PiperOrigin-RevId: 599200102
2024-01-17 09:39:48 -08:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
Sergei Lebedev
41531123f4 Rolling back #18980, because it is not backwards compatible and breaks existing users.
Reverts 91faddd023c2df77df310f3f2f17eb2fa1e60df0

PiperOrigin-RevId: 591200403
2023-12-15 03:24:01 -08:00
George Necula
fd0f007765 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialized(ser)
export.call(exp1)
```

This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
2023-12-15 10:00:05 +02:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
George Necula
db44249afc [call_tf] Fix call_tf lowering for multi-platform lowering
call_tf has per-platform lowering because the lowering
of the called TF function may depend on the platform. When
doing multi-platform lowering this means that we lower
call_tf several times and wrap the lowerings with a
conditional. This results in an assertion failure
in add_to_call_tf_concrete_function_list, because we
are attempting to add the same function multiple times.

Here we remove the assertion (afaik, it is Ok to add
multiple functions with the same name, because all
we care about is the index of the called function in
the list). We also reuse the existing function if
we are adding an identical one.

We add tests for call_tf with multi-platform lowering.
2023-10-24 18:57:58 +02:00
George Necula
70f6a9e725 [export] Add support for exporting functions with effects
In presence of ordered effects JAX lowering produces a main
function that takes token
inputs and returns token outputs. Previously, when exporting
such a module, we would wrap the main function with a function
that does not use tokens on inputs and outputs. With this
change we actually leave the token inputs and outputs and
rely on consumers of the exported function to know how to
invoke a function with tokens.

Due to the fact that PJRT does not support passing tokens
as input and output to the top-level function, JAX native
lowering uses dummy bool[0] arrays in lieu of tokens for
the top-level function, and uses stablehlo tokens for the
inner functions. When we export a function for serialization
we want to use stablehlo tokens even at top-level, to enable
calling that function from a larger JAX computation later.

See more details about the calling convention in the
docstring for `export.export`.

We also fix and test multi-platform lowering in presence
of effects.

This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.
2023-10-20 22:27:27 +02:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
George Necula
5b8f91fed7 [jax2tf] Fix higher-order differentiation.
We must ensure that we call jax2tf.convert recursively to ensure
that the proper tf.custom_gradient is used. This means that we can
reuse the conversion of the VJP function between native and graph
serialization.
2023-09-22 07:53:45 +02:00
George Necula
d873ba7b0b [shape_poly] Ensure we can have both ordered effects and dead variables
Fix the case when we export a function with ordered effects and some
of the inputs are dead.
2023-09-21 11:13:37 +02:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
John QiangZhang
dec2366c16 Create the failure test when tf.SavedModel miss the XLACallModule function_list after loading.
PiperOrigin-RevId: 554726455
2023-08-08 00:50:50 -07:00
John QiangZhang
2e35f25b4b Consolidate the code path for both call_tf_graph=True or call_tf_graph=False.
PiperOrigin-RevId: 552605585
2023-07-31 15:18:33 -07:00
John QiangZhang
4114e6c428 Improve the default value of output_shape_dtype.
PiperOrigin-RevId: 549988693
2023-07-21 10:45:02 -07:00
George Necula
46aa9e0b31 Copybara import of the project:
--
b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:

[shape_poly] Fix lowering when we have both dimension variables and tokens

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
2023-06-28 22:15:16 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
George Necula
0961fb9eba [jax2tf] Add native_lowering_disabled_checks parameter to jax2tf.convert.
Previously, we had a boolean `native_serialization_strict_checks` parameter
that was disabling all safety checks. This mechanism had several
disadvantages:

  * the mechanism did not differentiate between different safety checks.
    E.g., in order to disable checking of the custom call targets, one
    had to disable checking for all custom call targets, and also the
    checking that the serialization and execution platforms are the same.
  * the mechanism operated only at serialization time. Now, the
    XlaCallModule supports a `disabled_checks` attribute to control
    which safety checks should be disabled.

Here we replace the `native_serialization_strict_checks` with
`native_serialization_disabled_checks`, whose values are sequences
of disabled check descriptors.
2023-06-13 08:04:58 +03:00
John QiangZhang
886185831f Clean up the called_name of tf.call_tf_function custom_call.
PiperOrigin-RevId: 537480979
2023-06-02 21:27:18 -07:00
John QiangZhang
277e461046 Flip native serialization strict_check to True.
PiperOrigin-RevId: 537399539
2023-06-02 13:45:08 -07:00
John QiangZhang
ed10293f9c Add new called_index to custom_call tf.backend_config DictAttr.
Here, `called_index` indicates the tf concrete function index in the `function_list` of the parent XLACallModule.

PiperOrigin-RevId: 535417558
2023-05-25 15:58:50 -07:00
jax authors
7de1677011 Add (optional) ordered effects for jax2tf.call_tf
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.

With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:

* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.

For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.

Example StableHLO produced from the added test:

```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare  LT, %iterArg_0, %4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```

PiperOrigin-RevId: 534926215
2023-05-24 11:48:35 -07:00
John QiangZhang
2c05fe996e Add a new test to cover multiple calls to same tf function when call_tf_graph = True.
PiperOrigin-RevId: 531578811
2023-05-12 12:42:42 -07:00
John QiangZhang
47df8628a0 Fix the problem for tf function return StatefulPartitionedCall during jax2tf.call_tf.
PiperOrigin-RevId: 529964653
2023-05-06 08:30:26 -07:00
John QiangZhang
8acbe1557c Update stablehlo.custom_call call_target name based on design doc discussion.
PiperOrigin-RevId: 529281826
2023-05-03 21:23:13 -07:00
John QiangZhang
7fe62b5406 Bump XLACallModule to version 5 and add the function_list.
PiperOrigin-RevId: 529106145
2023-05-03 09:05:08 -07:00
Christina Sorokin
63d87c6c3d Add new attribute function_list to XLACallModule and bump the version.
PiperOrigin-RevId: 528076798
2023-04-28 22:34:41 -07:00
jax authors
566f17513b Merge pull request #15770 from gnecula:clean_call_tf
PiperOrigin-RevId: 527841341
2023-04-28 03:54:21 -07:00
George Necula
161664e858 [call_tf] Some cleanup of call_tf
The main cleanup is around _code_generator_and_avals, which in
an earlier version of the code was used for both abstract values
and for code generation. That is why it was cached, and why it
returned a code generator and abstract values. A while
ago we did a first round of cleaning to not use it for abstract
values. Now we can actually eliminate the function and inline
it directly.

A second improvement is to add the explicit error message from
TF commpilation, instead of just the generic message that
call_tf cannot be used with non-compileable functions.
2023-04-28 12:38:27 +02:00
John QiangZhang
5b4388ad03 Add new attribute function_list to XLACallModule and bump the version.
PiperOrigin-RevId: 527741961
2023-04-27 18:28:12 -07:00
John QiangZhang
af051ba49c [2/n] Embed the tf.Graph into the stablehlo.custom_call.
PiperOrigin-RevId: 527302563
2023-04-26 10:20:46 -07:00