14082 Commits

Author SHA1 Message Date
Chao Chen
a2c9fc02e4 jax-rocm runtime/ci dockerfile multistages 2022-12-12 07:45:12 -08:00
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
jax authors
94590e27bc Merge pull request #13562 from gnecula:opaque_shape_poly
PiperOrigin-RevId: 494632979
2022-12-11 22:32:34 -08:00
George Necula
27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.

.
2022-12-12 05:19:04 +01:00
George Necula
2f1354ee04 Add workaround for imprecise shape inference for DynamicGatherOp
This is needed for gather in presence of dynamic shapes.

PiperOrigin-RevId: 494613303
2022-12-11 20:18:15 -08:00
jax authors
4af4234f67 Merge pull request #13602 from mattjj:tweak-array-notebook
PiperOrigin-RevId: 494480882
2022-12-10 22:45:30 -08:00
Matthew Johnson
1185c895ca in jax.Array notebook, polish beginning and tweak title and some wording 2022-12-10 22:16:54 -08:00
Anselm Levskaya
ffb4711969 Expose channel_id in AllToAllOp in both XLA builder and MHLO.
PiperOrigin-RevId: 494334791
2022-12-09 21:58:28 -08:00
Jake VanderPlas
0a2d1cd45e Set bcoo_cusparse_lowering to False by default
This was causing out-of-bound writes on some CUDA backends

PiperOrigin-RevId: 494280591
2022-12-09 15:41:49 -08:00
Tianjian Lu
a8b90b325a [sparse] Fix a bug in BCSR tree_flatten.
PiperOrigin-RevId: 494276415
2022-12-09 15:22:58 -08:00
jax authors
d506770313 Merge pull request #13579 from gnecula:roll_poly
PiperOrigin-RevId: 494274916
2022-12-09 15:16:03 -08:00
jax authors
0777ca6424 Merge pull request #13591 from google:ci_v3-8
PiperOrigin-RevId: 494271663
2022-12-09 15:01:33 -08:00
Skye Wanderman-Milne
8d4b50e397 [TPU CI] Run build matrix on v3-8 as well as v4-8
We're seeing failures on v3-8 that don't appear on the current v4-8
testing. v3-8 also exposes 8 devices (vs. v4-8 exposes 4), and some
tests needs 8 devices to run.

I just added a v3-8 runner VM.

Also adds a missing pip install command (I only caught this with a
fresh runner since it only needs to be installed once).
2022-12-09 22:32:09 +00:00
jax authors
b4fbd835a0 Merge pull request #13587 from jakevdp:callback-doc
PiperOrigin-RevId: 494261419
2022-12-09 14:16:49 -08:00
Jake VanderPlas
df02d7035e DOC: add example of pure_callback with custom_jvp 2022-12-09 12:43:04 -08:00
jax authors
f2c5d287a3 Merge pull request #13568 from hawkinsp:npy124
PiperOrigin-RevId: 494206861
2022-12-09 10:35:07 -08:00
Tianjian Lu
11fbe5542e [sparse] Add rand_bcsr to generate a random BCSR array.
PiperOrigin-RevId: 494201364
2022-12-09 10:10:28 -08:00
jax authors
02f96a22db Merge pull request #13569 from LenaMartens:typecheck
PiperOrigin-RevId: 494167850
2022-12-09 07:34:15 -08:00
lenamartens
7fe466c548 Small fix to scan type-check error message. 2022-12-09 11:41:41 +00:00
George Necula
86a70ab811 [jax2tf] Fix for jnp.roll with shape polymorphism
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.

There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.
2022-12-09 08:08:28 +02:00
Tianjian Lu
942aa7a907 [sparse] Move _dot_general_validated_shape to sparse util.
PiperOrigin-RevId: 494031113
2022-12-08 16:54:43 -08:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Eugene Burmako
2c92037150 Fail lower_jaxpr_to_module if the module fails verification
When working with George on https://github.com/google/jax/pull/13427, I discovered that modules with verifier errors can happily cross API boundaries and create confusion downstream.

As discussed, this is unintentional - the expectation was that `ctx.module.operation.verify()` will throw an exception when verification fails. This CL addresses that and throws an exception accordingly.

Not sure how to test this, given that passing a module with verifier errors to module_to_string indicates a logic error (i.e. such module shouldn't have been produced by JAX in the first place). As a result, I didn't write any tests, but I'm happy to write them if there's a good way to do that.

PiperOrigin-RevId: 493940591
2022-12-08 10:55:49 -08:00
jax authors
02ba16ece8 Merge pull request #13251 from yotarok:toeplitz
PiperOrigin-RevId: 493931922
2022-12-08 10:26:00 -08:00
Yash Katariya
a618f2772d Add device_ids and axis_names to the Mesh repr
PiperOrigin-RevId: 493916858
2022-12-08 09:29:55 -08:00
jax authors
da285b6536 Merge pull request #13566 from hawkinsp:flake8
PiperOrigin-RevId: 493906873
2022-12-08 08:48:36 -08:00
Peter Hawkins
aacf44ae3a flake8 now rejects inline comments.
See:
https://flake8.pycqa.org/en/latest/user/configuration.html (search for
"inline comments").
2022-12-08 16:22:28 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
jax authors
440b25bf5d Merge pull request #13427 from gnecula:tf_native_poly
PiperOrigin-RevId: 493800880
2022-12-07 22:41:23 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Yash Katariya
0118f8d568 Prepare for jax and jaxlib 0.4.0 release
PiperOrigin-RevId: 493733609
jax-v0.4.0-rc
2022-12-07 16:02:24 -08:00
Yash Katariya
dd647601c6 Make jnp.copy work with all shardings especially PmapSharding. This fixes the problem where a jax.Array with a PmapSharding round tripped through host and returned a jax.Array with a SingleDeviceSharding.
Now, `jnp.copy` works without going through a round trip via host and maintains the sharding of the input array across all the Shardings we have.

PiperOrigin-RevId: 493728354
2022-12-07 15:41:47 -08:00
jax authors
3159642407 Merge pull request #13547 from jakevdp:numpy-msort
PiperOrigin-RevId: 493649811
2022-12-07 10:50:33 -08:00
jax authors
ce990cfd61 Merge pull request #13546 from jakevdp:shard-types
PiperOrigin-RevId: 493637391
2022-12-07 10:11:53 -08:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Jake VanderPlas
777754d595 [typing] fix Shard/Sharding type TODO 2022-12-07 09:44:03 -08:00
jax authors
794cec15cf Merge pull request #13544 from jakevdp:fix-readme
PiperOrigin-RevId: 493625570
2022-12-07 09:28:25 -08:00
Jake VanderPlas
79c8d6679f README: fix badge URL 2022-12-07 08:48:40 -08:00
Shaobo Hou
6e0c8029cc Fix type annotation for bcoo_update_layout.
PiperOrigin-RevId: 493567424
2022-12-07 04:25:25 -08:00
Peter Hawkins
ac72346ad3 Ensure that the initial dynamic_trace_state is canonicalized.
The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation.

PiperOrigin-RevId: 493489154
2022-12-06 20:39:53 -08:00
jax authors
b168df0aeb Merge pull request #13539 from wookayin:fix/config-typing
PiperOrigin-RevId: 493473725
2022-12-06 19:10:02 -08:00
Jongwook Choi
cd225853f7 Fix a false-positive typing warning on jax.default_device
Consider the following code where static type checkers can report an
error:

```python
CPU = jax.devices('cpu')[0]
with jax.default_device(CPU):
  ...                 # ^^^
```

Error message:
```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
  "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```

This is because `_StateContextManager.__call__` does not have a proper
type annotation on the parameter, unlike the attribute `_default_value`
which has a type annotation. Adding a `Any` to the parameter would
make the error disappear.
2022-12-06 21:05:35 -05:00
jax authors
61398ff409 Merge pull request #13536 from jakevdp:quickstart-timing
PiperOrigin-RevId: 493414640
2022-12-06 14:40:44 -08:00
jax authors
23b808f7d0 Merge pull request #13446 from google:maxfail
PiperOrigin-RevId: 493414635
2022-12-06 14:34:01 -08:00
Jake VanderPlas
7b59ce2f89 DOC: pre-execute the quickstart notebook on GPU 2022-12-06 13:24:02 -08:00
Jieying Luo
1132098c90 [PJRT:C] Separate loading PJRT plugin from creating PJRT client.
- Add xla_client.maybe_load_pjrt_plugins which maybe load PJRT plugins from a hardcoded set.
- Call xla_client.maybe_load_pjrt_plugins in xla_bridge beforer initializing backends.
- Add binding of python method load_pjrt_plugin to LoadPjrtPlugin which does dlopen and dlsym.
- Remove loading PJRT plugin from tpu_initializer_helper.cc.
- Add an extra call to LoadPjrtPlugin when getting the PJRT_Api* to be backward compatible.

PiperOrigin-RevId: 493381393
2022-12-06 12:29:38 -08:00
Peter Hawkins
1bbcec79c1 Update FAQ since buffer donation is implemented on CPU.
PiperOrigin-RevId: 493372426
2022-12-06 11:57:34 -08:00
jax authors
2fb9238bc5 Merge pull request #13534 from jakevdp:fix-rand-test
PiperOrigin-RevId: 493349509
2022-12-06 10:40:35 -08:00
Jake VanderPlas
7e3f6748ec random_test: skip singular covariance test on accelerators 2022-12-06 09:26:23 -08:00
jax authors
edaa5620ed Merge pull request #13517 from froystig:rng-part-jax2tf
PiperOrigin-RevId: 493317025
2022-12-06 08:46:28 -08:00