4005 Commits

Author SHA1 Message Date
Yash Katariya
fe3fed3627 Remove axis_resources from with_sharding_constraint, since 3 months have passed since its deprecation, as required by the API deprecation policy.
PiperOrigin-RevId: 535687618
2023-05-26 12:35:16 -07:00
Yash Katariya
4f074718d4 Make pjit_call_impl go via C++ dispatch.
This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl.

PiperOrigin-RevId: 535630905
2023-05-26 08:57:30 -07:00
Yash Katariya
2858df24ff Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX.
Currently, no function producing HloSharding is being used. That will be done in follow-up CLs.

PiperOrigin-RevId: 535622806
2023-05-26 08:19:14 -07:00
Roy Frostig
3238b627a1 outline jitted jax.random functions
We may want to continue to inline these in Jaxpr somehow, but it's
useful to outline them in HLO for visualization and debugging.
2023-05-25 15:01:04 -07:00
Mark Sandler
bc547aa318 Adds a note that pjit is equivalent to jit.
PiperOrigin-RevId: 535296532
2023-05-25 10:17:25 -07:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code
2023-05-25 09:32:14 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
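The behaviour added in commit 333ff4abbc can be illustrated with a short sketch. It assumes a JAX release that includes this commit; the array shape is arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

# A batch of two 3x4 matrices (illustrative shape).
x = jnp.arange(24).reshape(2, 3, 4)

# matrix_transpose swaps only the last two axes, leaving batch axes alone.
y = jnp.matrix_transpose(x)

# The .mT attribute is the method-style spelling of the same operation.
z = x.mT

# For comparison, .T reverses every axis, which is rarely wanted for batches.
full_transpose = x.T
```

Here `y` and `z` have shape `(2, 4, 3)`, whereas `full_transpose` has shape `(4, 3, 2)`.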
Sharad Vikram
4fb834b351 Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
2023-05-24 14:06:45 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
jax authors
7de1677011 Add (optional) ordered effects for jax2tf.call_tf
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens.

With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:

* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.

For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.

Example StableHLO produced from the added test:

```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare  LT, %iterArg_0, %4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```

PiperOrigin-RevId: 534926215
2023-05-24 11:48:35 -07:00
jax authors
2d525b815d Merge pull request #16103 from jakevdp:deprecation-stacklevel
PiperOrigin-RevId: 534616543
2023-05-23 17:32:17 -07:00
Jake VanderPlas
2623473a44 Make deprecation warnings warn at appropriate stacklevel
2023-05-23 14:43:38 -07:00
Matthew Johnson
d42350f879 disable custom_jvp for softmax by default
Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
2023-05-23 11:56:50 -07:00
Jake VanderPlas
62fb0cd8a2 explicitly convert jnp.var scalar normalizer to float (from int)
This way we don't pass a potentially-large (Python builtin) int value to an
int32 JAX computation parameter and get an error.

Fixes #15068

Co-authored by: Matthew Johnson <mattjj@google.com>
2023-05-23 09:44:08 -07:00
jax authors
a7b8129ffb Merge pull request #16073 from stellaraccident:extplugin
PiperOrigin-RevId: 534237189
2023-05-22 17:34:51 -07:00
jax authors
13f5090c4c Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf
PiperOrigin-RevId: 534165099
2023-05-22 13:31:04 -07:00
jax authors
69d6c1b13c Merge pull request #16086 from froystig:upgraded-key-ctor
PiperOrigin-RevId: 534152508
2023-05-22 12:46:41 -07:00
Roy Frostig
b7b90e62e3 add new random key constructor
This constructor unconditionally returns a typed key array, regardless
of the value of `jax.config.enable_custom_prng`. We can switch to
referring to it in randomness docs and tutorials as we complete the
typed key upgrade.
2023-05-22 11:35:10 -07:00
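A minimal sketch of what a typed key array looks like in practice. The commit message for b7b90e62e3 does not name the constructor; this example assumes it is the one exposed as `jax.random.key` in later JAX releases, and it assumes `jax.dtypes.prng_key` is available for the dtype check.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.key is the new constructor described above.
key = jax.random.key(0)

# A typed key is a scalar array whose dtype is a PRNG key dtype,
# rather than a raw array of unsigned integers.
is_typed = bool(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))

# Typed keys are accepted by the usual jax.random functions.
subkey, _ = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
```

The older `jax.random.PRNGKey` constructor, by contrast, returns a typed key only when `jax.config.enable_custom_prng` is set.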
Matthew Johnson
61b106ec8f allow lax.dot_general to accept different input dtypes
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.

The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.

However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
2023-05-22 10:33:42 -07:00
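As a hedged illustration of commit 61b106ec8f, the sketch below passes operands of different dtypes to `lax.dot_general` and fixes the output dtype with `preferred_element_type`. The shapes and dtypes are arbitrary choices for the example, and it assumes a JAX release that includes this change.

```python
import jax.numpy as jnp
from jax import lax

# Operands with different dtypes (illustrative choice: bfloat16 and float32).
lhs = jnp.ones((2, 3), dtype=jnp.bfloat16)
rhs = jnp.ones((3, 4), dtype=jnp.float32)

# Plain matrix multiply: contract lhs axis 1 with rhs axis 0, no batch axes.
dimension_numbers = (((1,), (0,)), ((), ()))

# preferred_element_type sets the output dtype independently of the inputs.
out = lax.dot_general(
    lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32
)
```

Any conversion of the inputs needed by a particular backend happens inside the lowering rule, so the caller does not have to cast the operands by hand.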
Peter Hawkins
ee14ca2628 Add option jax_include_full_tracebacks_in_locations.
If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment.

PiperOrigin-RevId: 534060827
2023-05-22 07:41:29 -07:00
kmillikin
fb99b9efd1 Remove some dead code
`tuple_args` is annotated as `bool`, so it should never be `None`.

PiperOrigin-RevId: 534039682
2023-05-22 05:59:39 -07:00
Yash Katariya
b71829f882 Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh).
`with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD.

PiperOrigin-RevId: 533801596
2023-05-20 23:00:35 -07:00
Sholto Douglas
e0b5003880 Allow unconstrained dimensions when using NamedShardings.
PiperOrigin-RevId: 533752415
2023-05-20 16:28:12 -07:00
Stella Laurenzo
f832710574 Update comments
2023-05-19 14:29:41 -07:00
Stella Laurenzo
368e20e2a3 Appease flake8
2023-05-19 14:18:56 -07:00
Stella Laurenzo
221aa76d81 Extend plugin discovery to also include entry-points.
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()).

A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:

```
    entry_points={
        "jax_plugins": [
          "openxla-cpu = jax_plugins.openxla_cpu",
        ],
    }
```
2023-05-19 14:10:08 -07:00
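The two discovery mechanisms described in commit 221aa76d81 can be sketched with the standard library alone. This is not JAX's implementation: the function name `discover_plugin_modules` is hypothetical, and the sketch only collects candidate module names without importing or initializing the plugins.

```python
import importlib
import importlib.metadata
import pkgutil


def discover_plugin_modules(namespace="jax_plugins", group="jax_plugins"):
    """Return sorted module names found by path iteration and entry points."""
    found = set()

    # Path-based discovery: iterate over modules inside the namespace package.
    # This covers unpackaged code that is simply present on PYTHONPATH.
    try:
        package = importlib.import_module(namespace)
    except ImportError:
        package = None
    if package is not None:
        search_path = getattr(package, "__path__", [])
        for module_info in pkgutil.iter_modules(search_path, package.__name__ + "."):
            found.add(module_info.name)

    # Entry-point discovery: covers installs where iter_modules() is unreliable,
    # such as editable installs that use a custom finder.
    entry_points = importlib.metadata.entry_points()
    if hasattr(entry_points, "select"):  # Python 3.10 and later
        selected = entry_points.select(group=group)
    else:  # older Pythons return a dict keyed by group name
        selected = entry_points.get(group, [])
    for entry_point in selected:
        # For "openxla-cpu = jax_plugins.openxla_cpu" the value is the module path.
        found.add(entry_point.value)

    return sorted(found)
```

Collecting both sources into one set means a plugin that is reachable both ways is reported once.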
Peter Hawkins
26f2711aeb Fix typo in config.py.
Fixes #16066
2023-05-19 10:07:12 -04:00
Jieying Luo
9da52e8905 [PJRT PLUGIN] Provide a register_plugin method that plugin can use to register their backend factory.
The plugin is expected to call jax._src.xla_bridge.register_plugin from its initialize() method, passing its plugin_name, priority (defaults to 400), path to the .so file, and optional create options.

The logic for registering a plugin from ENV is retained to facilitate development with ENV.

PiperOrigin-RevId: 533280115
2023-05-18 16:13:02 -07:00
Yash Katariya
4a5c6f8200 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help improve tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
jax authors
6034e87ddf Convert tuple to DeviceAssignment on the replicated compilation path.
PiperOrigin-RevId: 533258935
2023-05-18 14:55:29 -07:00
Peter Hawkins
39097df02e Add some preliminary support for int4/uint4 types to JAX.
PiperOrigin-RevId: 533251630
2023-05-18 14:27:33 -07:00
Roy Frostig
717d3c88fc inline and remove eq_mlir and ne_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
f18bff5371 inline and remove scatter_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
cc54b6e6ad inline and remove select_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
301d058b3d inline and remove gather_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
071c77e5bb inline and remove transpose_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
06132ac764 inline and remove broadcast_in_dim_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
0ac792f4ed inline and remove dynamic_update_slice_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
2dbdf1a6c1 inline and remove dynamic_slice_mlir rules
2023-05-17 20:07:59 -07:00
Roy Frostig
aed77c5031 inline and remove slice_mlir rules
2023-05-17 20:07:58 -07:00
Roy Frostig
129a4a5f35 inline and remove empty_mlir rules
2023-05-17 20:07:58 -07:00
Roy Frostig
180e26dafb remove physical_avals rule in favor of physical_element_aval
2023-05-17 20:07:58 -07:00
jax authors
f3cecd07c7 Merge pull request #16047 from jakevdp:prngkey-replicated
PiperOrigin-RevId: 532959862
2023-05-17 17:21:54 -07:00
Jake VanderPlas
48abe7c684 PRNGKeyArray: add several missing attributes & methods
2023-05-17 14:47:22 -07:00
Peter Hawkins
389564551b Second attempt at fixing warnings from jax.dtypes.issubdtype.
PiperOrigin-RevId: 532902714
2023-05-17 14:10:22 -07:00
Jieying Luo
2aa2282ea1 [PJRT PLUGIN] Add automatic discovery of PJRT plugins from pip installed packages.
Plugins in the namespace package `jax_plugins` will be imported. A plugin needs to (1) be placed in a root folder named `jax_plugins` and follow the other namespace package requirements, and (2) implement an initialize() method that appends `plugin_name:file_path` to the env var `PJRT_NAMES_AND_LIBRARY_PATHS`.

Appending to PJRT_NAMES_AND_LIBRARY_PATHS is a short-term solution; what initialize() should do is still under discussion.

PiperOrigin-RevId: 532897890
2023-05-17 13:56:22 -07:00
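A rough sketch of what such an initialize() method might look like under the short-term scheme from commit 2aa2282ea1. The helper `append_plugin_entry`, the plugin name `my_plugin`, and the library path are all hypothetical. The sketch also assumes that multiple entries in the env var are comma-separated, which the commit message does not state.

```python
import os

ENV_VAR = "PJRT_NAMES_AND_LIBRARY_PATHS"


def append_plugin_entry(plugin_name, library_path, environ=os.environ):
    """Append `plugin_name:library_path` to the env var, skipping duplicates."""
    entry = f"{plugin_name}:{library_path}"
    existing = environ.get(ENV_VAR, "")
    # Assumption: entries are comma-separated.
    entries = [item for item in existing.split(",") if item]
    if entry not in entries:
        entries.append(entry)
    environ[ENV_VAR] = ",".join(entries)
    return environ[ENV_VAR]


def initialize():
    # Hypothetical plugin name and path, shown for illustration only.
    append_plugin_entry("my_plugin", "/path/to/my_plugin.so")
```

Passing the environment mapping as a parameter keeps the helper easy to exercise against a plain dict instead of the real process environment.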
jax authors
77f8bbc08d Make device count assert in _to_xla_op_sharding() more informative.
PiperOrigin-RevId: 532861091
2023-05-17 11:58:15 -07:00
Yash Katariya
f1c2711292 Add impl rule for with_sharding_constraint so that users can use their functions with and without a jit.
The semantics of eager wsc are the same as within a jit, i.e. it will reshard to the given sharding only if the devices are the same and in the same order.

Eager wsc won't work as expected with AD transpose because there is no `src` argument to reverse the shardings when transposing; it was decided that this is fine for now. jax.device_put should be the API to use for that.

PiperOrigin-RevId: 532858670
2023-05-17 11:50:12 -07:00
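The point of commit f1c2711292 is that one function body can run both eagerly and under `jax.jit`. The sketch below shows that usage under stated assumptions: a one-dimensional mesh over all local devices, an axis name `"x"`, and a function `scale`, all chosen here for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever devices are available (possibly just one).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))


def scale(a):
    # The same constraint is used whether or not the function is jitted.
    a = jax.lax.with_sharding_constraint(a, sharding)
    return a * 2


# The array length is a multiple of the device count so it shards evenly.
x = jnp.arange(4 * devices.size, dtype=jnp.float32)

eager_out = scale(x)         # runs through the impl rule, outside jit
jit_out = jax.jit(scale)(x)  # runs through the usual jit lowering
```

Both calls should produce the same values; only the path through which the sharding constraint is applied differs.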
jax authors
79be3482ca Merge pull request #16041 from LenaMartens:checkify-leaks
PiperOrigin-RevId: 532843752
2023-05-17 11:04:45 -07:00