369 Commits

Author SHA1 Message Date
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
Yash Katariya
989a3304bf Fix the creation of pmap sharding spec when sharded_dim is None.
PiperOrigin-RevId: 457045869
2022-06-24 10:46:35 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Yash Katariya
dce8f64b40 Make device_put_sharded and device_put_replicated return Arrays.
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Matthew Johnson
83a8dc4e7f [new-remat] add _scan_partial_eval_custom rule for new remat
Also enable scan-of-remat tests which weren't passing before.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
2022-06-17 23:15:14 -07:00
Yash Katariya
6ed94ef876 First CL to integrate jax.Array into pmap.
* If `config.jax_array` is enabled, output from pmap will be `Array`s.
* `Array`s are input are accepted by pmap (as shown in the test). Currently `pxla.make_sharded_device_array` creates SDAs specially for pmap here: https://github.com/google/jax/blob/main/jax/interpreters/pxla.py#L549. So a similar approach can be done for creating `Array`s specially for pmap (see the test).
Also `device_put_sharded` also creates SDAs for pmap.
* `Array`s that are output from `pmap` cannot be passed into `pjit` for now. Currently even SDAs from pmap that are passed into pjit are resharded which has a huge cost. So this kind of code is not used in majority anyways. I can look into relaxing this restriction in the future.

TODOs:
* Add checks for checking if pmap sharding matches the input arrays which I will add in a follow up CL immediately.
* Figure out how to use existing tests for pmap, pjit, xmap, etc.
PiperOrigin-RevId: 455519748
2022-06-16 19:52:31 -07:00
Jake VanderPlas
018e795b25 [x64] make pmap_test compatible with strict dtype promotion 2022-06-16 14:51:55 -07:00
Parker Schuh
ab79573ad0 pmap_lib should use the same cache key for threadlocal context as python jax.
PiperOrigin-RevId: 454297812
2022-06-10 21:39:15 -07:00
Jean-Baptiste Lespiau
bab8520d0c Initialize the thread-local compilation context when undefined in new threads.
PiperOrigin-RevId: 452119314
2022-05-31 12:57:48 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
c0d6a04b76 remove jnp.array case for handling buffers w/ aval=None
This functionality was added in #8134, but was superceded by later changes
which ensured that we never produce DeviceArrays with their 'aval' property set
to None (even when indexing ShardedDeviceArrays with integers, which used to be
a problem case).
2022-05-14 08:21:54 -07:00
Matthew Johnson
8915391443 fix redundant (harmless) axis env extension in pmap partial eval 2022-04-28 12:46:19 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Jean-Baptiste Lespiau
8a85544537 Add the input avals to Lowered and Compiled.
PiperOrigin-RevId: 433505462
2022-03-09 09:59:45 -08:00
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Jean-Baptiste Lespiau
607e7033a6 Turn execute_replicated into a class so we can access its fields.
It's more readable than inspecting the internals of a `functools.partial`.

PiperOrigin-RevId: 429523075
2022-02-18 03:18:47 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Peter Hawkins
5679fedd2c Fix missing handler when lexically capturing a ShardedDeviceArray when MLIR enabled. 2022-02-08 09:51:57 -05:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Peter Hawkins
1f2d8c0c07 In CPU all_gather lowering, make sure the outputs are bools if the inputs are bools.
PiperOrigin-RevId: 414045093
2021-12-03 16:12:03 -08:00
Peter Hawkins
9a53900221 Fix bug in lowering of nested pmaps with boolean types.
Add test that would have caught the bug.

PiperOrigin-RevId: 414010091
2021-12-03 13:34:57 -08:00
Roy Frostig
9f82d78007 typecheck pmap executable call arguments 2021-11-22 09:19:13 -08:00
Roy Frostig
fcdc0a6c1a ahead-of-time lowering and compilation frontend for pmap 2021-11-22 08:33:04 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
jax authors
3ee76a8089 Merge pull request #8601 from mattjj:fix-vmap-ppermute
PiperOrigin-RevId: 410900441
2021-11-18 14:36:00 -08:00
Matthew Johnson
2cb235809a make vmap ppermute consistent with pmap/docstring
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
2021-11-18 14:02:49 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
Tom Hennigan
bb3f19891e Ensure that size property of large ShardedDeviceArrays does not overflow.
This tests a fix that landed in XLA commit tensorflow/tensorflow@4216a88.

PiperOrigin-RevId: 410557846
2021-11-17 10:01:51 -08:00
Matthew Johnson
82d28899c7 add more grad-of-jit/pmap caching tests 2021-10-13 11:06:17 -07:00
Peter Hawkins
3361c76dca Consolidate primitive and jit lowering paths.
Before this change, primitives have a special case dispatch path that attempts
to avoid building a jaxpr in the cache miss case. However, there's no good
reason for this: it makes the code more complicated, and we're not particularly
optimizing for fast cache misses anyway (we care mostly about cache hits).

Make the primitive lowering path trace a small function using the xla_callable
lowering path instead.
2021-10-13 12:36:53 -04:00
Jake VanderPlas
486aac949a jnp.array: handle raw device buffers 2021-10-08 10:41:43 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
52b592739e Turn jnp.ndarray into a true abstract base class.
Make all JAX array types instances of jnp.ndarray.
Remove np.ndarray from jnp.ndarray.
2021-09-21 14:54:45 -04:00
Jean-Baptiste Lespiau
9c782e2289 Move ShardedDeviceArray & PmapFunction to the raw C API and implement pickling/unpickling.
PiperOrigin-RevId: 395256774
2021-09-07 08:50:48 -07:00
jax authors
50dd5e80dd Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395071943
2021-09-06 04:35:37 -07:00
Jean-Baptiste Lespiau
e793c88566 Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395058367
2021-09-06 02:57:16 -07:00
Ningning Xie
f38d3e8735 Allow axis index groups to have different sizes for AllReduce.
PiperOrigin-RevId: 394297426
2021-09-01 13:10:17 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
Jean-Baptiste Lespiau
f6f1debf70 Add post_hook support for pmap, to support debug_nans and debug_infs.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.

PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Jean-Baptiste Lespiau
5e0becf862 Specify what is the expected behavior for ShardedDeviceArray.delete()
- We expect it to be indepempotent.
- We also fix an unclear error when slicing a deleted object.

PiperOrigin-RevId: 389930575
2021-08-10 11:11:16 -07:00
Jean-Baptiste Lespiau
45aaf8a647 Make it possible to return a C++ ShardedDeviceArray.
This **will** be a **breaking** change, as pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next Jax release
- on the condition _USE_EXPERIMENTAL_CPP_SDA is switch to `_xla_extension_version > xx` and with the associated jaxlib release.

I am already adding the impact for the users in the CHANGELOG, we can still move it to the next version depending on when it's shipped.

Similarly to JAX.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we will introduce 2 objects for ShardedDeviceArray, with the Python object only for JAX extensions not compatible with the C++ object (e.g. Cloud TPU).

- Add `make_sharded_device_array` to be used within JAX and for hackers that need to construct SDA objects.
- Make sure the C++ object is valid by
  (a) extending `DeviceArrayBase` (done in Python), as it brings a bunch of methods and enable `isinstance(x, DeviceArray)`
  (b) Adding the same methods as the Python SDA.

NOTE: mypy has troubled with the " -> pxla.ShardedDeviceArray` function return type annotation, I had to remove 2.
PiperOrigin-RevId: 389876734
2021-08-10 07:16:24 -07:00
Jean-Baptiste Lespiau
ad4c670f37 Add Python code for the future C++ pmap and pass the data to C++ as a namedtuple.
PiperOrigin-RevId: 388788330
2021-08-04 14:46:51 -07:00