16218 Commits

Author SHA1 Message Date
George Necula
5cbc38d4f5 [shape_poly] Keep track of whether a lowering contains shape polymorphism
Previously, we kept the `dim_vars` in the `mlir.ModuleContext`. Now we
replace that with a mutable `ShapePolyLoweringState` that also tracks
whether we encounter shape polymorphism anywhere in the lowering.
For this purpose, we also add `shape_poly_state` to the lowering.compile_args.

We need to keep track of whether a module contains dimension variables
because such modules need shape refinement before they can be converted
to MHLO and compiled. For now, we just test that we set the
`Exported.module_uses_dim_vars` correctly.
2023-05-31 11:40:50 +03:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl, so standardize it via a helper function.

In the test, check for re-compilation, which is what that test was doing before cl/535630905.

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
jax authors
3ad756f7e0 Merge pull request #16176 from gnecula:poly_constraints
PiperOrigin-RevId: 536571493
2023-05-30 19:16:52 -07:00
George Necula
9ad8c3b9f1 [shape_poly] Add static constraint checking to the computation of dim vars
Previously we had one function `shape_poly.unify_avals_with_args` that was
solving the dimension variables and was also used for generating the code
to compute them. Now we separate the solving part, which is now using just
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).

We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.

For now we implement the static checking of the shape constraints, e.g., when
the dimension expressions are constants or TF EagerTensor values. We do not yet
have compile-time checking of the constraints. This matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
2023-05-31 04:48:44 +03:00
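As an illustration of the split described in the 9ad8c3b9f1 message above, here is a minimal sketch of solving dimension variables first and then statically checking the constraints under which that solution is valid. It is not the `shape_poly` implementation: the function name `solve_dim_vars_sketch`, the tuple-based spec encoding, and the `>= 1` constraint on every variable are assumptions made for this example only.

```python
def solve_dim_vars_sketch(spec, shape):
    """Hypothetical helper, not part of jax.

    Each entry of `spec` is an int constant, a variable name (str), or a
    (coefficient, variable) pair standing for coefficient * variable.
    """
    if len(spec) != len(shape):
        raise ValueError("rank mismatch between spec and argument shape")

    # Step 1: solve. Bare variables are read directly off the argument shape.
    env = {}
    for dim_spec, size in zip(spec, shape):
        if isinstance(dim_spec, str):
            if env.setdefault(dim_spec, size) != size:
                raise ValueError(f"conflicting values for {dim_spec}")

    # Step 2: collect the constraints under which the solution is valid.
    constraints = []
    for dim_spec, size in zip(spec, shape):
        if isinstance(dim_spec, int):
            constraints.append((f"{dim_spec} == {size}", dim_spec == size))
        elif isinstance(dim_spec, tuple):
            coef, var = dim_spec
            if var not in env:
                raise ValueError(f"cannot solve for {var}")
            constraints.append((f"{coef}*{var} == {size}", coef * env[var] == size))
    for var, value in env.items():
        # Assumption for this sketch: every dimension variable is at least 1.
        constraints.append((f"{var} >= 1", value >= 1))

    # Step 3: static check, possible here because every size is a constant.
    failed = [text for text, ok in constraints if not ok]
    if failed:
        raise ValueError("shape constraints violated: " + ", ".join(failed))
    return env


solution = solve_dim_vars_sketch(("b", (2, "b"), 3), (4, 8, 3))

try:
    solve_dim_vars_sketch(("b", (2, "b"), 3), (4, 9, 3))
    rejected = False
except ValueError:
    rejected = True
```

Keeping the solver separate from the checker means the same constraint list could later be checked at compile time instead, which is the direction the commit message describes.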
jax authors
acfeb9bb13 Merge pull request #16169 from ZacCranko:data_parallel_example
PiperOrigin-RevId: 536260245
2023-05-29 18:39:44 -07:00
Zac Cranko
a192b5e541 improve data parallel example
fix example
2023-05-30 01:25:17 +00:00
jax authors
ae9160a4e9 Merge pull request #16159 from jakevdp:deprecations
PiperOrigin-RevId: 536003451
2023-05-28 07:27:45 -07:00
Jake VanderPlas
7a87995ecd Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p
2023-05-28 07:15:34 -07:00
Sharad Vikram
1279418ce5 Link in CUDA runtime for triton in jaxlib
PiperOrigin-RevId: 535708416
2023-05-26 14:02:16 -07:00
Jieying Luo
cb3b7ec93a [PJRT PLUGIN] Add num_processes to distributed.global_state.
The number of processes is needed for multi-process GPU when plugin is used.

PiperOrigin-RevId: 535696950
2023-05-26 13:14:40 -07:00
Yash Katariya
d62bc0f795 Fix the jax2tf failure in mypy: https://github.com/google/jax/actions/runs/5094063162/jobs/9157426652?pr=16155
PiperOrigin-RevId: 535692853
2023-05-26 12:57:37 -07:00
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint, now that 3 months have passed since its deprecation, as per the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
jax authors
25a9a978fb Merge pull request #16151 from hawkinsp:cudnn
PiperOrigin-RevId: 535642800
2023-05-26 09:48:40 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but instead go via pjit_p.bind directly, which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
jax authors
9508f3ad9d Merge pull request #16148 from gnecula:export_poly
PiperOrigin-RevId: 535628086
2023-05-26 08:44:41 -07:00
George Necula
46a258ba17 [shape_poly] Add partial support for call_exported with polymorphic shapes
Until now, `jax_export.call_exported` did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may include
dimension variables), and then computing the output shape of the
called function.

The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
to invoke exported functions that include shape polymorphism.
2023-05-26 17:27:44 +02:00
Yash Katariya
2858df24ff Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX.
Currently, no function producing HloSharding is being used. I will do that in follow-up CLs.

PiperOrigin-RevId: 535622806
2023-05-26 08:19:14 -07:00
Peter Hawkins
69cf67f252 Bump the minimum CUDNN version for CUDA 12 wheels to 8.9.
2023-05-26 10:04:34 -04:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
jax authors
7833528765 Merge pull request #16143 from jakevdp:fix-shape-poly
PiperOrigin-RevId: 535427698
2023-05-25 16:31:09 -07:00
John QiangZhang
ed10293f9c Add new called_index to custom_call tf.backend_config DictAttr.
Here, `called_index` indicates the tf concrete function index in the `function_list` of the parent XLACallModule.

PiperOrigin-RevId: 535417558
2023-05-25 15:58:50 -07:00
Jake VanderPlas
bbae2edd12 jax2tf: correctly handle opaque dtype in jax2tf pure()
In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure().

PiperOrigin-RevId: 535408671
2023-05-25 15:32:47 -07:00
jax authors
8534f0bfc3 Merge pull request #16142 from froystig:outline-random-functions
PiperOrigin-RevId: 535406588
2023-05-25 15:25:18 -07:00
Jake VanderPlas
b853ce9967 jax2tf: make shape_poly_test pass with custom PRNG
2023-05-25 15:16:46 -07:00
Roy Frostig
3238b627a1 outline jitted jax.random functions
We may want to continue to inline these in Jaxpr somehow, but it's
useful to outline them in HLO for visualization and debugging.
2023-05-25 15:01:04 -07:00
jax authors
14089fb2f8 Merge pull request #16138 from hawkinsp:cudnn
PiperOrigin-RevId: 535367462
2023-05-25 13:40:34 -07:00
Peter Hawkins
2b7790290b Bump minimum CUDNN version in pip installation to 8.8.
There are known wrong-output bugs observed in JAX for earlier versions, in particular related to RNNs.
2023-05-25 14:46:39 -04:00
Peter Hawkins
16368bc672 [XLA:Python] Clean up handling of unsupported types in buffer protocol.
Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails.

Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants.

PiperOrigin-RevId: 535315975
2023-05-25 11:10:19 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Peter Hawkins
32026ad18b Disable random_test_with_custom_prng on CPU under msan.
This test flakily times out in CI.

PiperOrigin-RevId: 535293997
2023-05-25 10:10:01 -07:00
jax authors
24928a507b Merge pull request #16117 from jakevdp:matrix-transpose
PiperOrigin-RevId: 535292507
2023-05-25 10:02:26 -07:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code
2023-05-25 09:32:14 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
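A short usage sketch for the API named in the 333ff4abbc message above, assuming a JAX version that includes this commit. The array contents are arbitrary; the point is that only the last two axes are swapped, so leading batch dimensions are left in place, unlike `.T`, which reverses every axis.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)   # a batch of two 3x4 matrices

y = jnp.matrix_transpose(x)           # swaps only the last two axes
z = x.mT                              # the equivalent property on jax.Array

batch_shape = y.shape                 # batch axis stays first
full_transpose_shape = x.T.shape      # .T reverses all axes instead
```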
Peter Hawkins
e464dc8700 Reland: [XLA:Python] Add buffer protocol support to jax.Array
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 535248553
2023-05-25 07:20:42 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Eugene Burmako
e25052c6f8 Use stablehlo.get_minimum_version in jax_export.py
The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API.

PiperOrigin-RevId: 535091927
2023-05-24 21:15:16 -07:00
Ce Zheng
8e397f7f08 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors.
PiperOrigin-RevId: 535079635
2023-05-24 20:26:28 -07:00
John QiangZhang
5e82d6b5d5 Fix jax2tf_test regression failure.
PiperOrigin-RevId: 535002015
2023-05-24 15:27:57 -07:00
Sharad Vikram
4fb834b351 Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
2023-05-24 14:06:45 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Sharad Vikram
557ca52f10 Add cuda_pip extra for jaxlib
PiperOrigin-RevId: 534957585
2023-05-24 13:19:27 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
jax authors
7de1677011 Add (optional) ordered effects for jax2tf.call_tf
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.

With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:

* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.

For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.

Example StableHLO produced from the added test:

```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare  LT, %iterArg_0, %4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```

PiperOrigin-RevId: 534926215
2023-05-24 11:48:35 -07:00
Jake Vanderplas
399e4ee87f Copybara import of the project:
--
8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>:

Remove several deprecated APIs

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244
PiperOrigin-RevId: 534897394
2023-05-24 10:35:37 -07:00
jax authors
1831b3cd95 Merge pull request #16105 from kmillikin:main
PiperOrigin-RevId: 534854308
2023-05-24 08:41:51 -07:00
jax authors
2f7cc7d575 Merge pull request #16109 from michaeldeistler:readme-fix
PiperOrigin-RevId: 534844507
2023-05-24 08:11:47 -07:00
Michael Deistler
5f1952df4d fix typo
2023-05-24 10:43:03 +02:00
Kevin Millikin
921fd222bf Refer to the original map/zip classes via builtins
Referring to them as simply `map` or `zip` will create recursive
reimplementations (with no base case!) if the cell is reevaluated in
the same runtime.
2023-05-24 07:47:50 +01:00
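The failure mode described in the 921fd222bf message above can be reproduced with a small sketch that evaluates a "cell" twice in one shared namespace. The cell contents below are hypothetical stand-ins for illustration, not the code the commit changed.

```python
# A cell that saves `map` under another name and then redefines `map`.
# On re-evaluation, `map_` captures the previous wrapper instead of the builtin.
BROKEN_CELL = """
map_ = map
def map(f, *xs):
    return list(map_(f, *xs))
"""

# The same wrapper, but referring to the original class via `builtins`.
FIXED_CELL = """
import builtins
def map(f, *xs):
    return list(builtins.map(f, *xs))
"""


def run_cell_twice(cell_source):
    """Evaluate the cell twice in one namespace, like re-running a notebook cell."""
    namespace = {}
    exec(cell_source, namespace)
    exec(cell_source, namespace)
    return namespace["map"]


def double(x):
    return 2 * x


try:
    run_cell_twice(BROKEN_CELL)(double, [1, 2])
    broken_recursed = False
except RecursionError:
    # The old wrapper looks up `map_` at call time and finds itself.
    broken_recursed = True

fixed_result = run_cell_twice(FIXED_CELL)(double, [1, 2])
```

After the second evaluation of the broken cell, the global `map_` points at the first wrapper, which itself calls `map_`, so the call never reaches the builtin.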
jax authors
d9e7a2abf8 Merge pull request #16102 from jakevdp:deprecate-lax-prod
PiperOrigin-RevId: 534618632
2023-05-23 17:40:18 -07:00