10505 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
968237080f Add importlib_metadata to project requirements.
This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.

Prior to this change, there was conditional logic to handle if
importlib_metadata wasn't installed at all. However, it doesn't handle
the case where importlib_metadata is installed by not high enough
version to support Python 3.10 compat. This change gets rid of that
logic and just ensures the right version is installed.

All of this logic can be removed if/when jax requires Python version
>= 3.10

This also removes an unnecessary `requests` dep for the [tpu] install.
2023-05-31 21:03:12 +00:00
jax authors
758d68df13 Restore call_tf_concrete_function_list to previous state
In the following case of nested call:

```
inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
	return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)

tf_fn_too(inputs)  # FAIL
```

Without the fix, it fails with the following error:

```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
    _thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```

because we call `_restore_context` twice when executing `jax2tf.convert`ed functions,
the first time we call `_restore_context`, `call_tf_concrete_function_list` is set to `None`
instead of restoring it to the previous state, so the second time we call `_restore_context`,
`call_tf_concrete_function_list.clear()` throws the above error since `call_tf_concrete_function_list` is `None`.

PiperOrigin-RevId: 536650377
2023-05-31 02:23:14 -07:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
George Necula
9ad8c3b9f1 [shape_poly] Add static constraint checking to the computation of dim vars
Previously we had one function `shape_poly.unify_avals_with_args` that was
solving the dimension variables and was also used for generating the code
to compute them. Now we separate the solving part, which is now using just
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).

We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.

For now we implement the static checking of the shape constraints, e.g., when
the dimension expressions are constant or TF EagerTensor. We do not yet
have compile-time checking of the constraints. This matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
2023-05-31 04:48:44 +03:00
Zac Cranko
a192b5e541 improve data parallel example
fix example

fix example

fix example

fix example

fix example

fix example
2023-05-30 01:25:17 +00:00
Jake VanderPlas
7a87995ecd Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p 2023-05-28 07:15:34 -07:00
Jieying Luo
cb3b7ec93a [PJRT PLUGIN] Add num_processes to distributed.global_state.
The number of processes is needed for multi-process GPU when plugin is used.

PiperOrigin-RevId: 535696950
2023-05-26 13:14:40 -07:00
Yash Katariya
d62bc0f795 Fix the jax2tf failure in mypy: https://github.com/google/jax/actions/runs/5094063162/jobs/9157426652?pr=16155
PiperOrigin-RevId: 535692853
2023-05-26 12:57:37 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
jax authors
9508f3ad9d Merge pull request #16148 from gnecula:export_poly
PiperOrigin-RevId: 535628086
2023-05-26 08:44:41 -07:00
George Necula
46a258ba17 [shape_poly] Add partial support for call_exported with polymorphic shapes
Until now the jax_export.call_exported did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.

The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported that includes shape polymorphism.
2023-05-26 17:27:44 +02:00
Yash Katariya
2858df24ff Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX.
Currently, no function producing HloSharding is being used. I will do that in follow up CLs.

PiperOrigin-RevId: 535622806
2023-05-26 08:19:14 -07:00
jax authors
7833528765 Merge pull request #16143 from jakevdp:fix-shape-poly
PiperOrigin-RevId: 535427698
2023-05-25 16:31:09 -07:00
John QiangZhang
ed10293f9c Add new called_index to custom_call tf.backend_config DictAttr.
Here, `called_index` indicates the tf concrete function index in the `function_list` of the parent XLACallModule.

PiperOrigin-RevId: 535417558
2023-05-25 15:58:50 -07:00
Jake VanderPlas
bbae2edd12 jax2tf: correctly handle opaque dtype in jax2tf pure()
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().

PiperOrigin-RevId: 535408671
2023-05-25 15:32:47 -07:00
Jake VanderPlas
b853ce9967 jax2tf: make shape_poly_test pass with custom PRNG 2023-05-25 15:16:46 -07:00
Roy Frostig
3238b627a1 outline jitted jax.random functions
We may want to continue to inline these in Jaxpr somehow, but it's
useful to outline them in HLO for visualization and debugging.
2023-05-25 15:01:04 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code 2023-05-25 09:32:14 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Eugene Burmako
e25052c6f8 Use stablehlo.get_minimum_version in jax_export.py
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.

PiperOrigin-RevId: 535091927
2023-05-24 21:15:16 -07:00
John QiangZhang
5e82d6b5d5 Fix jax2tf_test regression failure.
PiperOrigin-RevId: 535002015
2023-05-24 15:27:57 -07:00
Sharad Vikram
4fb834b351 Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
2023-05-24 14:06:45 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
jax authors
7de1677011 Add (optional) ordered effects for jax2tf.call_tf
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.

With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:

* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.

For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.

Example StableHLO produced from the added test:

```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare  LT, %iterArg_0, %4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```

PiperOrigin-RevId: 534926215
2023-05-24 11:48:35 -07:00
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
Jake VanderPlas
4cfa96ef8f deprecate jax.lax.prod 2023-05-23 17:33:50 -07:00
jax authors
2d525b815d Merge pull request #16103 from jakevdp:deprecation-stacklevel
PiperOrigin-RevId: 534616543
2023-05-23 17:32:17 -07:00
Parker Schuh
016eae4141 Allow disabling the parsing of GSPMDSharding -> NamedSharding.
Because this is best effort, users writing code to handle GPSMDSharding
should be able to deal only with the GSPMDSharding type.

PiperOrigin-RevId: 534612265
2023-05-23 17:16:56 -07:00
Jake VanderPlas
7f7f995bf4 Export jax.lax.sharding_constraint_p
PiperOrigin-RevId: 534566582
2023-05-23 14:50:46 -07:00
Jake VanderPlas
2623473a44 Make deprecation warnings warn at appropriate stacklevel 2023-05-23 14:43:38 -07:00
John Cater
db8716701f Migrate exec_tools back to tools.
PiperOrigin-RevId: 534549617
2023-05-23 14:00:34 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Jake VanderPlas
62fb0cd8a2 explicitly convert jnp.var scalar normalizer to float (from int)
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.

Fixes #15068

Co-authored by: Matthew Johnson <mattjj@google.com>
2023-05-23 09:44:08 -07:00
jax authors
a7b8129ffb Merge pull request #16073 from stellaraccident:extplugin
PiperOrigin-RevId: 534237189
2023-05-22 17:34:51 -07:00
jax authors
13f5090c4c Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf
PiperOrigin-RevId: 534165099
2023-05-22 13:31:04 -07:00
jax authors
69d6c1b13c Merge pull request #16086 from froystig:upgraded-key-ctor
PiperOrigin-RevId: 534152508
2023-05-22 12:46:41 -07:00
Roy Frostig
b7b90e62e3 add new random key constructor
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
2023-05-22 11:35:10 -07:00
Matthew Johnson
61b106ec8f allow lax.dot_general to accept different input dtypes
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.

The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.

However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
2023-05-22 10:33:42 -07:00
Peter Hawkins
ee14ca2628 Add option jax_include_full_tracebacks_in_locations.
If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment.

PiperOrigin-RevId: 534060827
2023-05-22 07:41:29 -07:00
kmillikin
fb99b9efd1 Remove some dead code
`tuple_args` is annotated as `bool`, it should not be `None`.

PiperOrigin-RevId: 534039682
2023-05-22 05:59:39 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Sholto Douglas
e0b5003880 Allow unconstrained dimensions when using NamedShardings.
PiperOrigin-RevId: 533752415
2023-05-20 16:28:12 -07:00
Parker Schuh
56ca8af9bb Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
2023-05-19 16:58:54 -07:00
Stella Laurenzo
f832710574 Update comments 2023-05-19 14:29:41 -07:00
Stella Laurenzo
368e20e2a3 Appease flake8 2023-05-19 14:18:56 -07:00
Stella Laurenzo
221aa76d81 Extend plugin discovery to also include entry-points.
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()).

A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:

```
    entry_points={
        "jax_plugins": [
          "openxla-cpu = jax_plugins.openxla_cpu",
        ],
    }
```
2023-05-19 14:10:08 -07:00
Peter Hawkins
26f2711aeb Fix typo in config.py.
Fixes #16066
2023-05-19 10:07:12 -04:00