16268 Commits

Author SHA1 Message Date
Peter Hawkins
cb33fdf3f7 Include SASS/PTX for Hopper GPUs. 2023-06-05 09:42:12 -04:00
Jake VanderPlas
025e348323 Add empty .editorconfig
This will allow us to merge https://github.com/google/jax/pull/16246

PiperOrigin-RevId: 537848620
2023-06-05 05:51:11 -07:00
George Necula
ec8b855fa1 [shape_poly] Add a polymorphic shape refinement MLIR pass accessible to JAX Python.
At the moment we can run the StableHLO module lowered by jax2tf
with polymorphic shapes only with jax2tf, because the tf.XlaCallModule op has the
necessary shape refinement logic (which is necessary to legalize
the StableHLO module with dynamic shapes to MHLO). Here we
expose the shape refinement MLIR transformation to JAX Python.

For now this is used only in a test in jax_export_test.py.

PiperOrigin-RevId: 537485288
2023-06-02 21:49:20 -07:00
John QiangZhang
886185831f Clean up the called_name of tf.call_tf_function custom_call.
PiperOrigin-RevId: 537480979
2023-06-02 21:27:18 -07:00
Yu Emma Wang
983e1c0fd1 Cast "axis" arg of tf.concat to tf.int32.
PiperOrigin-RevId: 537478940
2023-06-02 21:12:31 -07:00
jax authors
5639e194be Merge pull request #16231 from jakevdp:product
PiperOrigin-RevId: 537454290
2023-06-02 18:13:56 -07:00
John QiangZhang
277e461046 Flip native serialization strict_check to True.
PiperOrigin-RevId: 537399539
2023-06-02 13:45:08 -07:00
jax authors
9a76bfb02e Merge pull request #16230 from jakevdp:nightly-maxfail
PiperOrigin-RevId: 537391943
2023-06-02 13:14:29 -07:00
André Susano Pinto
cfabad5886 Avoid IndexError when constructing a ValueError for a DeviceAssignmentMismatchError.
_get_arg_names was throwing IndexError when handling functions with variadic args.

PiperOrigin-RevId: 537308439
2023-06-02 07:43:59 -07:00
jax authors
2e12add64a Merge pull request #16232 from jakevdp:deprecations
PiperOrigin-RevId: 537296369
2023-06-02 06:39:19 -07:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
Jake VanderPlas
592833e02a Change uses of np.product to np.prod
product is deprecated as of numpy 1.25.0.
2023-06-02 03:57:30 -07:00
Jake VanderPlas
1d7dced221 CI: set maxfail=20 for nightly tests 2023-06-02 03:43:29 -07:00
jax authors
6a89abcc76 Merge pull request #16213 from jakevdp:keyarray-shards
PiperOrigin-RevId: 537262590
2023-06-02 03:12:55 -07:00
jax authors
3d2f8267a5 Merge pull request #16224 from ntessore:patch-1
PiperOrigin-RevId: 537260906
2023-06-02 03:04:39 -07:00
jax authors
1ab6e8dd08 Merge pull request #16217 from jakevdp:static-roll
PiperOrigin-RevId: 537260829
2023-06-02 02:56:35 -07:00
Jake VanderPlas
5497bb03b6 jnp.roll: more efficient implementation for static shifts 2023-06-02 01:20:19 -07:00
jax authors
e99045381d Update mentioning of DeviceArray and ShardedDeviceArray to jax.Array in the parallelism tutorial
`jax.Array` is now a unified type for all kinds of arrays.

PiperOrigin-RevId: 537155869
2023-06-01 16:12:59 -07:00
Nicolas Tessore
a835cafdad
Fix incorrect wrapped docstring of jax.scipy.special.gamma
Fixes the docstring `jax.scipy.special.gamma`, which was wrapping `scipy.special.gammaln` by mistake. Also adds a note that the function currently only accepts real inputs.
2023-06-01 20:13:37 +01:00
Parker Schuh
5c2070c204 custom_parititioning: in lower sharding, Sharding should be XLACompatibleSharding.
PiperOrigin-RevId: 537077304
2023-06-01 11:16:08 -07:00
Matthew Johnson
c8311c673e Internal change
PiperOrigin-RevId: 537067965
2023-06-01 10:46:21 -07:00
jax authors
e13cfe71a0 Merge pull request #16209 from jakevdp:softmax-where
PiperOrigin-RevId: 537064949
2023-06-01 10:37:31 -07:00
George Necula
37e254e982 [jax2tf] Add a backward compatibility test for tf.call_tf_function
`tf.call_tf_function` arises from `jax2tf.call_tf(tf_fun, call_tf_graph)`. However, a function that contains this can be lowered and executed only with `jax2tf.convert` and ought to be serialized as ` tf.Graph` because the serialization includes a tf.function as well.

In order to support this we need to add some code to back_compat_test.py to serialize and run the serialized code as tf.Graph.

PiperOrigin-RevId: 537062963
2023-06-01 10:29:44 -07:00
Jake VanderPlas
c474de424a jax.nn.softmax: fix fill value when where is specified 2023-06-01 10:18:05 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
jax authors
adca0fa9b8 Merge pull request #16123 from mattjj:refine-vmap-frame-getting
PiperOrigin-RevId: 537047149
2023-06-01 09:42:43 -07:00
jax authors
ae78de1a49 Merge pull request #16189 from skye:profiling_docs
PiperOrigin-RevId: 537046864
2023-06-01 09:34:46 -07:00
jax authors
6adbead17d Merge pull request #16168 from jakevdp:dead-code
PiperOrigin-RevId: 537012454
2023-06-01 07:07:40 -07:00
Jake VanderPlas
3a7ccf70f2 custom prng: add shard methods to PRNGKeyArrayImpl 2023-06-01 04:10:12 -07:00
jax authors
eb41e9c331 Merge pull request #16212 from jakevdp:fix-doc
PiperOrigin-RevId: 536977403
2023-06-01 03:59:29 -07:00
Jake VanderPlas
e5cd69479b DOC: fix doc formatting 2023-06-01 03:37:21 -07:00
jax authors
4ffeafed85 Merge pull request #15382 from IvyZX:pytree
PiperOrigin-RevId: 536932961
2023-06-01 00:06:16 -07:00
jax authors
a45fbef807 Merge pull request #16199 from gnecula:poly_state
PiperOrigin-RevId: 536932763
2023-05-31 23:57:56 -07:00
Yash Katariya
4c48611fba Finish jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00
jax authors
15299eb2ee Merge pull request #16150 from jakevdp:loop-error
PiperOrigin-RevId: 536913485
2023-05-31 21:54:25 -07:00
ivyzheng
6bf1cbc667 Add key path related guide & code to the documentation. 2023-05-31 20:15:56 -07:00
Yash Katariya
48ad9a6f3e Start jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536860076
jax-v0.4.11 jaxlib-v0.4.11 jax-v0.4.11-rc
2023-05-31 16:48:52 -07:00
jax authors
525ba49ba7 Merge pull request #16204 from skye:importlib_metadata_version
PiperOrigin-RevId: 536823622
2023-05-31 14:27:53 -07:00
Skye Wanderman-Milne
968237080f Add importlib_metadata to project requirements.
This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.

Prior to this change, there was conditional logic to handle if
importlib_metadata wasn't installed at all. However, it doesn't handle
the case where importlib_metadata is installed by not high enough
version to support Python 3.10 compat. This change gets rid of that
logic and just ensures the right version is installed.

All of this logic can be removed if/when jax requires Python version
>= 3.10

This also removes an unnecessary `requests` dep for the [tpu] install.
2023-05-31 21:03:12 +00:00
Jieying Luo
b35c20ce5d Use xla_extension_version and remove some dead version check in xla_bridge_test.py.
Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
2023-05-31 13:50:07 -07:00
jax authors
727c121169 Merge pull request #16188 from nouiz:ci_jestimator
PiperOrigin-RevId: 536810121
2023-05-31 13:41:29 -07:00
jax authors
c587dac134 Merge pull request #16203 from skye:tpu_py_version2
PiperOrigin-RevId: 536776189
2023-05-31 11:39:06 -07:00
Skye Wanderman-Milne
131d28ba0f Use default Python version on Cloud TPU CI 2023-05-31 18:04:39 +00:00
Yash Katariya
6d6ba70c78 Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8
PiperOrigin-RevId: 536693061
2023-05-31 06:21:01 -07:00
jax authors
758d68df13 Restore call_tf_concrete_function_list to previous state
In the following case of nested call:

```
inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
	return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)

tf_fn_too(inputs)  # FAIL
```

Without the fix, it fails with the following error:

```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
    _thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```

because we call `_restore_context` twice when executing `jax2tf.convert`ed functions,
the first time we call `_restore_context`, `call_tf_concrete_function_list` is set to `None`
instead of restoring it to the previous state, so the second time we call `_restore_context`,
`call_tf_concrete_function_list.clear()` throws the above error since `call_tf_concrete_function_list` is `None`.

PiperOrigin-RevId: 536650377
2023-05-31 02:23:14 -07:00
George Necula
5cbc38d4f5 [shape_poly] Keep track of whether a lowering contains shape polymorphism
Previously, we kept the `dim_vars` in the `mlir.ModuleContext`. Now we
replace that with a mutable `ShapePolyLoweringState` that also tracks
whether we encounter shape polymorphism anywhere in the lowering.
For this purpose, we also add `shape_poly_state` to the lowering.compile_args.

We need to keep track of whether a module contains dimension variables
because such modules need shape refinement before they can be converted
to MHLO and compiled. For now, we just test that we set the
`Exported.module_uses_dim_vars` correctly.
2023-05-31 11:40:50 +03:00
Yash Katariya
f884b4d13f Fix the test_sharding_on_output_with_vmap failure in Pathways which was getting a cache miss in pjit_call_impl.
There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function.

In the test, check for re-compilation which is what that test was doing before cl/535630905

PiperOrigin-RevId: 536575987
2023-05-30 19:51:48 -07:00
jax authors
3ad756f7e0 Merge pull request #16176 from gnecula:poly_constraints
PiperOrigin-RevId: 536571493
2023-05-30 19:16:52 -07:00
George Necula
9ad8c3b9f1 [shape_poly] Add static constraint checking to the computation of dim vars
Previously we had one function `shape_poly.unify_avals_with_args` that was
solving the dimension variables and was also used for generating the code
to compute them. Now we separate the solving part, which is now using just
symbolic expressions (`shape_poly.solve_dim_vars`), from the code
generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`).

We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or
`dimexpr1 >= dimexpr2`, under which the solution for the dimension variables
is valid.

For now we implement the static checking of the shape constraints, e.g., when
the dimension expressions are constant or TF EagerTensor. We do not yet
have compile-time checking of the constraints. This matches
the previous behavior. However, the code is now ready for implementing
compile-time checking of the constraints that cannot be checked statically.
2023-05-31 04:48:44 +03:00
Skye Wanderman-Milne
1d1429fe8b Update profiling docs.
* Mention that Tensorboard profiling supports device memory usage
* Recommend TB profiling instead of the pprof-based device memory profiling
* Minor updates to GCP instructions

Inspired by https://github.com/google/jax/issues/1491
2023-05-30 14:27:11 -07:00