6392 Commits

Author SHA1 Message Date
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Yash Katariya
c41d271175 Add memories support to remat.
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.

Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

```
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00
jax authors
7c0abb1c85 Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance
PiperOrigin-RevId: 564891104
2023-09-12 18:21:31 -07:00
jax authors
d3950b93cb Merge pull request #17568 from zhenying-liu:testdotgeneral
PiperOrigin-RevId: 564871912
2023-09-12 16:46:39 -07:00
Jane Liu
90da4f153a Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32 2023-09-12 16:31:22 -07:00
Jake VanderPlas
56791eb9ec lax_test: adjust TPU tolerance for igamma & friends
PiperOrigin-RevId: 564859109
2023-09-12 15:59:41 -07:00
Andrey Portnoy
34ea2b2e8a Increase comparison tolerance in SciPy spatial RotationMean subtest
Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014

The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it's better to loosen the tolerance but preserve the original
execution mode.

The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
2023-09-12 16:12:14 -04:00
jax authors
63e51fe1e6 Merge pull request #17524 from jakevdp:pref-eltype-test
PiperOrigin-RevId: 564792174
2023-09-12 12:01:21 -07:00
Yash Katariya
76a5dc3cac Move memories_test.py to JAX
PiperOrigin-RevId: 564551723
2023-09-11 17:41:55 -07:00
Qiao Zhang
d4adf0095f Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
2023-09-11 16:35:44 -07:00
jax authors
23e4f0b471 Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.

Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
2023-09-11 12:08:26 -07:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Jake VanderPlas
9289f3250b Add missing preferred_element_type tests
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00
Jake VanderPlas
bc91f2d182 Add more extensive tests for version strings 2023-09-08 11:33:49 -07:00
Jake VanderPlas
9e9ea5a90c jax.dtypes: fix dtypes.safe_to_cast() when output dtype is weak 2023-09-08 09:04:05 -07:00
Adam Paszke
832820e8ff CI changes
PiperOrigin-RevId: 563754696
2023-09-08 08:08:05 -07:00
Adam Paszke
592eb44d7e Fix Pallas tests that broke after recent changes
PiperOrigin-RevId: 563750775
2023-09-08 07:50:20 -07:00
Yash Katariya
b8eccb13f0 Remove the date check from jaxlib and jax version checks since it causes problem when jaxlib runs ahead of jax in CI (depending on timezones).
PiperOrigin-RevId: 563614108
2023-09-07 19:52:35 -07:00
Sharad Vikram
cb114f247a [Pallas] Refactor memory space handling
PiperOrigin-RevId: 563586933
2023-09-07 17:08:57 -07:00
jax authors
311dc9cfde Add truncated normal initializer to jax.nn
PiperOrigin-RevId: 563576354
2023-09-07 16:23:42 -07:00
Parker Schuh
bda9292523 Propagate ad.Zeros to the scan body function for jax.lax.scan for the output 'ys'.
Example of what this fixes:

```
def grad_fn(x):
  def scan_body(x, params):
    return x, x.sum()

  pred, state = jax.lax.scan(scan_body, x, None, length=2)
  return pred.sum(), state
x = np.zeros((5, 10), dtype=np.float32)
loss_grad_fn = jax.value_and_grad(grad_fn, has_aux=True)
print(jax.make_jaxpr(loss_grad_fn)(x))
```
PiperOrigin-RevId: 563544684
2023-09-07 14:36:53 -07:00
Jake VanderPlas
8412781127 Internal: add dtypes.safe_to_cast utility & use to generate indexing warning 2023-09-07 12:18:14 -07:00
jax authors
f64235acc8 Merge pull request #17453 from jakevdp:fix-version-string
PiperOrigin-RevId: 563466394
2023-09-07 10:06:32 -07:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Peter Hawkins
9b447aa3ec Relax test tolerance to fix BCSR sparse matmul test failure on P100 GPU.
PiperOrigin-RevId: 563441383
2023-09-07 08:37:31 -07:00
Peter Hawkins
429422dfea Reverts 5fcd9265b1e20c41d684659af3d52c41f25ae2f3
PiperOrigin-RevId: 563426073
2023-09-07 07:35:44 -07:00
jax authors
26d90beca2 Merge pull request #17314 from jon-chuang:jon-chuang/fix-causal-upper-bound
PiperOrigin-RevId: 563228924
2023-09-06 15:13:16 -07:00
Adam Paszke
bb8d5a0121 Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
George Necula
660a015652 [export] Move jax_export and shape_poly out of jax2tf.
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.

We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.

PiperOrigin-RevId: 562988740
2023-09-05 22:15:59 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Yash Katariya
80606cd28d Make is_fully_addressable an abstract method and implement it on each concrete Sharding.
Also, don't cache methods. Pull them out into a free function and cache that function.

PiperOrigin-RevId: 562939188
2023-09-05 17:28:22 -07:00
George Necula
f27816af30 [callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
2023-09-05 14:23:28 -07:00
jax authors
7224c24521 Merge pull request #17406 from jakevdp:namedtuple-mul
PiperOrigin-RevId: 562843729
2023-09-05 11:40:18 -07:00
Jake VanderPlas
7d29ed6bdd Lower jax.numpy matmul functions to mixed-precision dot_general 2023-09-05 08:37:51 -07:00
Jon Chuang
7d27c319b2 fix causal attention upper bound 2023-09-04 11:22:21 +08:00
George Necula
01c068eabd [callback] Some test cleanup.
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.

Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.

PiperOrigin-RevId: 562285255
2023-09-02 21:51:07 -07:00
Jake VanderPlas
72de4cd0b0 Bug: support NamedTuples in deferring binary ops 2023-09-01 13:16:20 -07:00
George Necula
efaea8ed32 [callback] Enable device_index support in terms of callback sharding support.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561856023
2023-08-31 22:31:35 -07:00
George Necula
e0a6230214 [host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
2023-08-31 22:08:23 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
jax authors
c38f67043c Hash serialized CompileOptions for new cache key generation.
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.

Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.

Testing: revised unit test.
PiperOrigin-RevId: 561803875
2023-08-31 17:21:57 -07:00
jax authors
80f6151110 Instrument metrics to track cache hit rate of original JAX compilation cache.
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of  the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.

Created a context manager to register/unregister event listeners.

PiperOrigin-RevId: 561771262
2023-08-31 15:05:23 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Jake VanderPlas
f0309b49c9 jax.random: warn on unsupported dtypes 2023-08-31 10:56:05 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
jax authors
437d7be735 Merge pull request #17368 from mattjj:shmap-transpose
PiperOrigin-RevId: 561476082
2023-08-30 16:14:13 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Yash Katariya
e785f89470 Raise a good error message when mesh is not provided to jax.jit when using spmd_axis_name parameter of jax.vmap
PiperOrigin-RevId: 561217612
2023-08-29 20:58:57 -07:00
Skye Wanderman-Milne
ecee8f9116 [JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.

This is the import path. The export path was implemented in
0b3cbfe4bc.

This allows for creating jax.Arrays from external GPU arrays
asynchronously.

PiperOrigin-RevId: 561172624
2023-08-29 16:39:31 -07:00