115 Commits

Author SHA1 Message Date
Peter Hawkins
69da839358 Remove test code that checks for the se_tpu runtime.
This runtime no longer exists.

PiperOrigin-RevId: 568242078
2023-09-25 09:30:07 -07:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
jax authors
6b5af15eea Merge pull request #17593 from jakeh-gc:test_changes
PiperOrigin-RevId: 565428268
2023-09-14 11:30:55 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Yash Katariya
c41d271175 Add memories support to remat.
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.

Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

```
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00
Jake VanderPlas
9289f3250b Add missing preferred_element_type tests
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Peter Hawkins
cf2d61c701 Move --jax_test_dut into the non-public test utilities.
This flag only exists for the use of JAX's own tests, and doesn't need to exist for most JAX users.

4f805c2d8f allows this move since none of the test comparison utilities now depend on the choice of backend. (That dependency was only an administrative dependency for external users of JAX, since the only public users of the test comparison utilites are the gradient utilities, which always override the default tolerance with their own tolerances.)

PiperOrigin-RevId: 563195002
2023-09-06 13:21:48 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
jax authors
fe107fd119 Create a context manager to register/unregister event duration listeners.
PiperOrigin-RevId: 559209418
2023-08-22 13:35:45 -07:00
jax authors
b3a36efafa Merge pull request #16965 from jakeh-gc:fp8_fnuz
PiperOrigin-RevId: 554523823
2023-08-07 10:43:18 -07:00
Jake Hall
85f124c18d Add support for float8_e4m3fnuz and float8_e5m2fnuz. 2023-08-07 11:48:53 +01:00
George Necula
f2dbde5b2e [jax2tf] Sanitize the parameterized test case names to be friendly to -k
Sanitizes the name of tests so that a name matches the rules of an
identifier for pytest -k and unittest -k test filtering. Sequences
of problematic characters are replaced with a single "_".
2023-08-07 12:08:21 +03:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
George Necula
e643f98558 [shape_poly] Reimplement the shape constraint checking using shape assertions.
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.

As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
Jake VanderPlas
31c5044c1d Make jit work with custom float inputs 2023-07-12 13:06:03 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
George Necula
5b035d5f76 [shape_poly] Improve compile-time shape checking.
JAX shape polymorphism relies on implicit assumptions.
For example, when tracing with input specification `(a, a)`,
we assume that the first two dimensions have the same size
greater or equal to 1.

Here we extend the checking that these assumptions hold. When
we call an `Exported` module from jax, with `jax_export.call_exported`
we check these assumptions statically. However, when we
stage an `Exported` using `XlaCallModule` to be called from
TensorFlow, or when we use TF graph serialization we need
to check these assumptions when we execute and compile
the op (that is when the shapes are available).

To prepare for this compile-time shape checking we add
`Exported.shape_check_module` to produce a serialized
MLIR module containing the shape checking code. This
will be added in a future change to `XlaCallModule`.
2023-06-13 20:47:30 +03:00
Yash Katariya
4a5c6f8200 For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the tracing time.
PiperOrigin-RevId: 533263584
2023-05-18 15:09:54 -07:00
Yash Katariya
8b9e6bcbd4 For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the tracing time.
PiperOrigin-RevId: 532155068
2023-05-15 10:32:24 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Jake VanderPlas
e46d7f673b KeyArray: use assertArraysEqual in place of assertKeysEqual 2023-04-26 13:15:03 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
78678ee9e1 Rename count_pjit_cache_miss with count_pjit_cpp_cache_miss because it is confusing which cache the first function is taking about as pjit has many caches
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Yash Katariya
f9468d3879 Remove the helper jit functions from api.py
PiperOrigin-RevId: 517152277
2023-03-16 10:08:00 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Parker Schuh
81507d97f6 Convert shard_args to return arrays when jax.config.jax_array is True.
PiperOrigin-RevId: 515205284
2023-03-08 19:13:20 -08:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Adam Paszke
1638313a99 Slightly increase the tolerance in sparse tests to avoid flakiness
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Yash Katariya
b476661b4a Add clear_cache endpoint to python pjit and cpp pjit functions.
PiperOrigin-RevId: 509696516
2023-02-14 18:46:25 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Skye Wanderman-Milne
eb13c053e9 Add option to run tests with persistent compilation cache enabled.
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.

PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Tres Popp
edd09c65f3 Remove xla_cpu_enable_mlir_lowering
This is now replaced with --xla_cpu_use_xla_runtime

PiperOrigin-RevId: 500955199
2023-01-10 04:19:03 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
73de02d5ce Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
2022-12-08 19:46:10 +00:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Johannes Reifferscheid
575c2f3783 Skip unsupported tests on XLA:CPU MLIR.
PiperOrigin-RevId: 490754048
2022-11-24 09:56:59 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Parker Schuh
5cfc708843 Remove error-prone most_recent_entry() support from lu.cache.
PiperOrigin-RevId: 484382188
2022-10-27 16:41:44 -07:00