182 Commits

Author SHA1 Message Date
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
52b592739e Turn jnp.ndarray into a true abstract base class.
Make all JAX array types instances of jnp.ndarray.
Remove np.ndarray from jnp.ndarray.
2021-09-21 14:54:45 -04:00
Jean-Baptiste Lespiau
9c782e2289 Move ShardedDeviceArray & PmapFunction to the raw C API and implement pickling/unpickling.
PiperOrigin-RevId: 395256774
2021-09-07 08:50:48 -07:00
jax authors
50dd5e80dd Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395071943
2021-09-06 04:35:37 -07:00
Jean-Baptiste Lespiau
e793c88566 Use the raw C API for ShardedDeviceArray.
It's similar than PyBuffer.

PiperOrigin-RevId: 395058367
2021-09-06 02:57:16 -07:00
Ningning Xie
f38d3e8735 Allow axis index groups to have different sizes for AllReduce.
PiperOrigin-RevId: 394297426
2021-09-01 13:10:17 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then by lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
Jean-Baptiste Lespiau
f6f1debf70 Add post_hook support for pmap, to support debug_nans and debug_infs.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.

PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Jean-Baptiste Lespiau
5e0becf862 Specify what is the expected behavior for ShardedDeviceArray.delete()
- We expect it to be indepempotent.
- We also fix an unclear error when slicing a deleted object.

PiperOrigin-RevId: 389930575
2021-08-10 11:11:16 -07:00
Jean-Baptiste Lespiau
45aaf8a647 Make it possible to return a C++ ShardedDeviceArray.
This **will** be a **breaking** change, as pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next Jax release
- on the condition _USE_EXPERIMENTAL_CPP_SDA is switch to `_xla_extension_version > xx` and with the associated jaxlib release.

I am already adding the impact for the users in the CHANGELOG, we can still move it to the next version depending on when it's shipped.

Similarly to JAX.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we will introduce 2 objects for ShardedDeviceArray, with the Python object only for JAX extensions not compatible with the C++ object (e.g. Cloud TPU).

- Add `make_sharded_device_array` to be used within JAX and for hackers that need to construct SDA objects.
- Make sure the C++ object is valid by
  (a) extending `DeviceArrayBase` (done in Python), as it brings a bunch of methods and enable `isinstance(x, DeviceArray)`
  (b) Adding the same methods as the Python SDA.

NOTE: mypy has troubled with the " -> pxla.ShardedDeviceArray` function return type annotation, I had to remove 2.
PiperOrigin-RevId: 389876734
2021-08-10 07:16:24 -07:00
Jean-Baptiste Lespiau
ad4c670f37 Add Python code for the future C++ pmap and pass the data to C++ as a namedtuple.
PiperOrigin-RevId: 388788330
2021-08-04 14:46:51 -07:00
jax authors
606cbe036a Merge pull request #7370 from slowy07:fixing
PiperOrigin-RevId: 388774232
2021-08-04 13:43:58 -07:00
Jean-Baptiste Lespiau
5450106e01 Replace AvalDimSharding and MeshDimAssignment with the C++ object.
This is backward compatible, as the new objects has the same attributes with the same type (in particular, it can be constructed from iterable objects, and  `sharding` and `mesh_mapping` are still tuples.

PiperOrigin-RevId: 388565058
2021-08-03 16:03:51 -07:00
Matthew Johnson
c31688d2d1 fix cond-of-pmap bug 2021-07-29 10:34:43 -07:00
Ningning Xie
e7f03073dd ReduceScatter translation and abstract eval.
PiperOrigin-RevId: 387152857
2021-07-27 11:13:18 -07:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Adam Paszke
64510bd5b6 Add axis and tiled options to lax.all_gather.
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.

PiperOrigin-RevId: 384897270
2021-07-15 04:22:36 -07:00
Adam Paszke
ed96e5305f Fix incorrect handling of axis_index_groups in parallel primitive fallbacks
PiperOrigin-RevId: 377139424
2021-06-02 14:03:47 -07:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Lena Martens
deb2227f4a Make sure the out_axes in the HashableFunction closure are hashable.
By flattening them before putting them in the closure.
2021-04-21 12:32:19 +01:00
Adam Paszke
e0357283a6 Speed up test generation 7x
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.

The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
2021-04-14 15:58:05 +00:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
4139faf490 Fix dtype for pmap of scalars 2021-03-25 12:43:31 -07:00
Skye Wanderman-Milne
a18f8cc49a Improve nested pmap error message.
It currently gives a misleading error message in the case of nested pmaps without this change.
2021-03-24 12:02:04 -07:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Adam Paszke
8a4f0a8931 Make all_to_all primitive match XLA semantics
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
2021-03-05 18:18:49 +00:00
Skye Wanderman-Milne
8e265f67aa Don't fail due to flaky CPU backend initialization in pmap_test 2021-03-04 00:25:16 +00:00
Jean-Baptiste Lespiau
18343817c8 Use the C++ object for the Sharding specification. 2021-02-12 16:02:58 +01:00
Matthew Johnson
e47c933fd3 fix/skip test failures 2021-02-11 08:30:37 -08:00
Matthew Johnson
ffb3873e5a add pargmax, pargmin wrappers 2021-02-09 19:04:46 -08:00
jax authors
ce180bcb37 Merge pull request #5560 from pschuh:transpose-args
PiperOrigin-RevId: 355880124
2021-02-05 10:37:29 -08:00
Parker Schuh
ed5d1faab2 Add support for execute_sharded_on_local_devices.
Change the structure of `execute_replicated` so that `in_handlers` and
`out_handlers` return and take `args[arg][shard]`
instead of `args[shard][arg]`.
2021-02-04 12:42:57 -08:00
Parker Schuh
bb0f50ff8e Bugfix to allow pxla.replicate to generate Replicated buffers. 2021-02-01 15:19:25 -08:00
Matthew Johnson
014f9a86b4 implement soft_pmap in terms of xmap 2021-01-28 07:59:57 -08:00
Jean-Baptiste Lespiau
32de6ffc5a Replace None with an object NoSharding.
This is to make the change to a C++ ShardingSpec easier.
See also https://github.com/google/jax/pull/5444

PiperOrigin-RevId: 352965689
2021-01-21 01:15:25 -08:00
Chris Jones
4b48c7f42b Use XLA AllGather op for GPU (attempt 2).
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.

My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
2021-01-19 11:16:25 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Chris Jones
3896ac2064 Skip tests that are failing without omnistaging (which will not be an option shortly).
Enable tests on all platforms that were unnecessarily skipped.
2021-01-06 17:18:58 +00:00
jax authors
d8aabdb8c6 Merge pull request #5059 from jpuigcerver:all-to-all-groups
PiperOrigin-RevId: 346060720
2020-12-07 04:05:26 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00
Matthew Johnson
8b64c3c679 fix inherited repr method for ShardedDeviceArray
fixes #5102
2020-12-04 21:25:51 -08:00
Matthew Johnson
13b96cc422 fix typo 2020-12-04 17:07:23 -08:00
Matthew Johnson
2a8f71e013 deflake tests 2020-12-04 16:47:04 -08:00
Matthew Johnson
6992ae182c switch assertions per reviewer comment 2020-12-04 16:42:49 -08:00
Matthew Johnson
dc610e4516 add jax.device_put_replicated
Also move tests for device_put_sharded into pmap_test.py, since that
file tests with multiple devices even in our OSS CI.

Add both device_put_replicated and device_put_sharded to
jax/__init__.py.
2020-12-04 12:54:07 -08:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
Matthew Johnson
58e441bed7 add experimental pdot primitive, basic tests 2020-11-27 11:18:01 -08:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00