15026 Commits

Author SHA1 Message Date
Yash Katariya
1ee750e795 Pass the jaxpr from pjit since there is no need to trace it again in lower_sharding_computation. It also helps in preserving debug_info that already exists on the jaxpr to surface it in MHLO eventually.
PiperOrigin-RevId: 513268085
2023-03-01 10:05:45 -08:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
jax authors
fa1ea37704 Merge pull request #14658 from JiaYaobo:chisq_and_f_dist
PiperOrigin-RevId: 513220241
2023-03-01 06:35:34 -08:00
Jake VanderPlas
ae6c4676d4 [sparse] add low-level primitives wrapping cuda SpMV & SpMM
This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.

PiperOrigin-RevId: 513212715
2023-03-01 05:56:31 -08:00
Jake VanderPlas
97f819b1ed [sparse] fix dot_general precision in test
PiperOrigin-RevId: 513205756
2023-03-01 05:10:42 -08:00
Anish Tondwalkar
3bad6fa223 [CHLO] Add erf_inv and lowering to mhlo
PiperOrigin-RevId: 513183138
2023-03-01 02:52:52 -08:00
jiayaobo
fdf8ac18d6 add random.chisquare and random.f
add chi2 and F random variables methods

add chi2 and F random variables methods

fix F rv shape broadcasting

fix shape broadcasting
2023-03-01 15:03:50 +08:00
Anish Tondwalkar
713bc2687d [mhlo] Use XLA pretty-printed format for shardingattr
PiperOrigin-RevId: 513116413
2023-02-28 20:22:16 -08:00
jax authors
b1adbfc57b [XLA:Python] Add buffer protocol support to jax.Array.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 513086379
2023-02-28 17:35:40 -08:00
Yash Katariya
571ade4fde Use math.prod instead of util.prod
PiperOrigin-RevId: 513065029
2023-02-28 16:04:00 -08:00
jax authors
7bdad987ea Merge pull request #14722 from jakevdp:gamma-doc
PiperOrigin-RevId: 513049078
2023-02-28 15:07:01 -08:00
Peter Hawkins
2976431b1a [XLA:Python] Add buffer protocol support to jax.Array.
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.

Fixes https://github.com/google/jax/issues/14713

PiperOrigin-RevId: 513047925
2023-02-28 14:59:08 -08:00
Peter Hawkins
a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Johannes Reifferscheid
3ecff30b4e Don't create invalid bools in lax_numpy_test/testView.
Currently, JAX is generating random 8 bit ints for bools, which usually doesn't cause any issues, but in some special cases does. One example is the HLO snapshot dumping code, which surprisingly creates unparseable protos for such inputs.

PiperOrigin-RevId: 513032802
2023-02-28 14:03:08 -08:00
Jake VanderPlas
93637f6ff9 Refactor bcoo_dot_general GPU lowering
The goal of this is to make it easier to address the out-of-bound index issue. Our current GPU logic grew somewhat organically over time, and the logic for which sub-routine is called is spread over multiple locations. This change updates the branching such that the logic for each sub-routine appears directly adjacent to its call site; the tradeoff is that other considerations (such as whether to raise a warning) have to be duplicated between the cases.

Additionally, I simplified some of the hlo operation calls to make the code easier to follow.

PiperOrigin-RevId: 513025719
2023-02-28 13:36:52 -08:00
Jake VanderPlas
9fb4bcb1e1 DOC: mention scale/rate parameter in random.gamma 2023-02-28 13:33:46 -08:00
Jake VanderPlas
06441883b9 [sparse] temporarily disable bcoo_dot_general_sampled fast cases test on GPU
This is failing with precision issues on some GPU architectures; it's not clear why.

PiperOrigin-RevId: 513021864
2023-02-28 13:23:54 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Lena Martens
4f48f94649 Update api_benchmark to not use any deprecated APIs.
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
c76ccf9ed7 Merge pull request #14587 from gnecula:tf_cross_platform
PiperOrigin-RevId: 512868892
2023-02-28 02:09:24 -08:00
George Necula
40d9fad9eb A different way to achieve cross-platform lowering, withouth any
changes to JAX core.

Sets up a new mesh context manager.
Also needs to override _get_and_check_device_assignment.
2023-02-28 04:57:51 +02:00
George Necula
5ef6f15d16 [jax2tf] Add support for cross-platform lowering in native serialization
Allow the user of native serialization to specify the platform for which
the serialization to be done. This relies on newly added support for
platform checking in XlaCallModule op (version 3).
2023-02-28 04:50:18 +02:00
Parker Schuh
eef3e69c61 Add PyArrayResultHandler which behaves like
functools.partial(jax.arrays.ArrayImpl) with the added benefit
that the new PyExecuteResults type can explode directly into
ArrayImpls if passed to explode_with_handlers().

Note that this also helps with deprecating PyBuffer as the fastpath
does not need to call the PyBuffer constructor.

PiperOrigin-RevId: 512788757
2023-02-27 18:26:53 -08:00
jax authors
586fe8d552 Merge pull request #14570 from mattjj:custom-jvp-symbolic-zeros-2
PiperOrigin-RevId: 512773473
2023-02-27 17:10:21 -08:00
jax authors
41ad78125b Merge pull request #14708 from skye:readme
PiperOrigin-RevId: 512751327
2023-02-27 15:41:30 -08:00
Skye Wanderman-Milne
56b237cfbc Update Cloud TPU install command to be simpler.
We used to need the extra stuff for a very old Cloud TPU VM image, but we don't anymore.
2023-02-27 23:15:19 +00:00
Yash Katariya
38ba6683dc Mention that Pspecs are not allowed to be passed to jax.jit
PiperOrigin-RevId: 512727888
2023-02-27 14:13:45 -08:00
jax authors
fa3a7d0593 Merge pull request #14703 from jakevdp:bcoo-precision
PiperOrigin-RevId: 512705050
2023-02-27 12:48:34 -08:00
Jake VanderPlas
f911acee05 [sparse] use precision=HIGHEST in bcoo_dot_general_sampled 2023-02-27 12:12:11 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
jax authors
bcf378f6b4 Merge pull request #14701 from jakevdp:doc-devicearray
PiperOrigin-RevId: 512684443
2023-02-27 11:33:07 -08:00
jax authors
f0d816f899 Merge pull request #14673 from nouiz:gpu_doc
PiperOrigin-RevId: 512669380
2023-02-27 10:49:52 -08:00
Jake VanderPlas
b09b4ba51f DOC: fix jax.numpy.Array discussion 2023-02-27 10:45:06 -08:00
Peter Hawkins
055fa6b90f Remove pytype suppression for jax/_src/config.py
This file no longer seems to make pytype unhappy.

PiperOrigin-RevId: 512668863
2023-02-27 10:39:55 -08:00
jax authors
5035c80589 Merge pull request #14674 from jakevdp:dot-general-doc
PiperOrigin-RevId: 512665258
2023-02-27 10:27:40 -08:00
Jake VanderPlas
4918b9d1d0 DOC: improve lax.dot_general documentation 2023-02-27 09:46:04 -08:00
George Necula
0cdb7f9997 [jax2tf] Include more sharding annotations in the TF graph
In the past we had encountered errors with sharding annotations for CPU/GPU (e.g., crashes; these have been fixed) and when executing in TF eager mode. To work around those we had decided to skip the replicated sharding annotations, which arise often now that all `jit` functions will assume by default replicated shardings. Then we have discovered that we were skipping too many sharding annotations and we made changes to include all inner sharding annotations, but still skip the replicated sharding annotations on inputs and outputs.

It is unsafe to skip annotations, and here we try to include as many sharding annotations as we can. The only case when we cannot include sharding annotations is under TF eager mode. There is should be safe to skip the replicated annotations in eager mode, counting on the fact that we will raise an error if we encounter non-replicated annotations. Such functions must be executed in tf.function mode.

Specifically under tf.function, which is the most important use case, we now include all sharding annotations.

At the same time, I added more tests and I strengthened some tests to check the presence of the sharding annotations in the TF HLO.

PiperOrigin-RevId: 512417862
2023-02-26 04:38:12 -08:00
jax authors
7217686d94 Merge pull request #14684 from sharadmv:flake-fix
PiperOrigin-RevId: 512318604
2023-02-25 11:52:03 -08:00
Sharad Vikram
18c6cbeaf7 Remove TokenSet needing to have effects in a certain order 2023-02-25 11:15:23 -08:00
jax authors
8ebfb0be48 Merge pull request #14614 from sharadmv:ref
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Peter Hawkins
b61d5d5654 Remove jax._src deletion.
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.

It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:

```
from jax._src.lib import xla_bridge

@mock.patch.object(xla_bridge, 'process_index')
...
```

A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:

```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```

However, this solution requires the `jax._src` be present in the JAX namespace.

Ideally users wouldn't mock our internals at all, but that requires significantly more work.

PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
pizzud
0292f5d0a6 lax_scipy_test: Revert split into three targets.
Somehow the spectral_dac functionality is flaky on its own when run on CPU.

PiperOrigin-RevId: 512195860
2023-02-24 16:56:40 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Jake VanderPlas
7f6826659e BUG: raise error when shaped_abstractify is called on JAX scalar types
PiperOrigin-RevId: 512163825
2023-02-24 14:27:57 -08:00
Frederic Bastien
ec817974aa Add a new link instead of a TODO. 2023-02-24 13:54:16 -08:00
Yash Katariya
d277358200 Create avals and pass them to _check_sharding rather than the actual value.
PiperOrigin-RevId: 512142679
2023-02-24 12:56:16 -08:00
Frederic Bastien
86191077ff Small fix as the module name changed. 2023-02-24 12:37:56 -08:00