327 Commits

Author SHA1 Message Date
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 18:35:28 +00:00
jax authors
bdd7915661 Internal change
PiperOrigin-RevId: 341644256
2020-11-10 10:12:27 -08:00
Adam Paszke
6914058cbe Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 13:35:23 +00:00
Adam Paszke
8c1fbdc901 Make ShardingSpec more flexible
In preparation of adding support for `in_axes` and `out_axes` to `pmap`.

The only difference in expressivity of the new approach is that the
sharded dimensions can be permuted before ordering/replicating the
indices to match the device assignment. This is necessary if we want to
support `in_axes`, because it may cause some sharded dimensions that are
supposed to get mapped to the "replication" XLA mesh axis to follow the
dimensions mapped to the "partitioning" XLA mesh axis. XLA fixes the
mesh order such that the replicated dimension is always the leading one,
which forces us to decouple the order of data dimensions from the mesh
dimensions.

This patch additionally folds the `is_axis_materialized` into the
sharding specification, by wrapping the integers in small ADT-like
wrappers that distinguish the different ways of partitioning dimensions.
The order of replication is also more explicit in the `mesh_mapping`,
as opposed to being represented as a list of replication factors to be
inserted into the sharding details to obtain a mesh mapping.

Note that this doesn't change any existing functionality. It is purely
an internal rewrite that is supposed to lay the groundwork for the next
patches.
2020-11-09 20:00:39 +00:00
Matthew Johnson
e51163af32 only pass hashable values as static args 2020-10-16 13:11:56 -07:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Adam Paszke
e61ca91360 Implement split_axis for all_to_all
This allows us to use `all_to_all` over a mix of vmapped and pmapped
dimensions, which will be useful for `gmap`.
2020-09-30 08:25:36 +00:00
Adam Paszke
fa38f250e3 Add support for all_to_all over vmapped axes 2020-09-30 08:25:36 +00:00
Jake VanderPlas
40016cc47c Allow jax objects to be represented by multiple buffers 2020-09-29 11:53:17 -07:00
Peter Hawkins
a0e14b0552
Revert "Allow JAX objects to be represented by multiple buffers" 2020-09-29 09:26:11 -04:00
Jake VanderPlas
d1f80228e0 Allow jax objects to be represented by multiple buffers 2020-09-25 11:09:08 -07:00
Adam Paszke
acd4cc5737 Allow vmapping all_to_all and implement a (slow) CPU and GPU translation
This allows pmapping vmapped computations that use `all_to_all` or
`pswapaxes` inside. It also includes a very slow CPU and GPU translation
rule that might be useful for debugging programs locally, since XLA only
implements the `AllToAll` collective on TPUs.

Fixes #4141.
2020-09-24 12:18:59 +00:00
Adam Paszke
0d5f15f5c0 Fix the abstract eval and translation rule for all_to_all
The previous rules assumed that `split_axis == concat_axis` (i.e. that
the used collective is equivalent to `pswapaxes`). Since we expose this
as part of our API, we should probably make sure that we handle other
cases too.

Fixes #1332.
2020-09-23 16:22:03 +00:00
Adam Paszke
332a9ba1ad Fix axis_index inside nested pmaps
The previous translation rule has assumed that `axis_index` is always
taken over the outermost axis in the `axis_env`, and was always producing
the same output, no matter which axis has been specified. This fixes the
translation rule to start taking the `axis_name` into account.

Additionally, this adds support for querying the index along multiple
axes, which will be useful for `gmap`.
2020-09-22 16:41:46 +00:00
Adam Paszke
2081e5acee Test pmap/vmap interactions of all reduction collectives 2020-09-21 16:25:50 +00:00
Adam Paszke
c4f98eb8fa Add back the batching rule for ppermute
Just make sure it's correct this time and add a test.
2020-09-21 16:22:00 +00:00
Matthew Johnson
f039f6daf9
thread backend in pxla.replicate (#4272)
* thread backend in pxla.replicate

fixes #4223

* add test for #4223
2020-09-11 22:40:12 -07:00
Adam Paszke
40fb01b4bd Extend axis env while translating the pmapped jaxpr to XLA
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
2020-09-11 17:56:32 +02:00
Roman Ring
bff24bddbb
Add axis_index_groups support to all_gather. (#4194) 2020-09-09 15:02:45 +03:00
Jake Vanderplas
29aa9bfc8f
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 2020-08-18 10:17:38 -07:00
James Bradbury
9ab07d8574
support axis_index_groups in psum(const) (#4070)
* support axis_index_groups in psum(const)

* add test for psum(constant, axis_index_groups)

* rm trailing whitespace

* Update lax_parallel.py
2020-08-14 22:54:36 -07:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00
Jamie Townsend
89e8a0886b
Fix warnings in pmap_test.py (#3977)
Also add note to developer documentation re: testing pmap.
2020-08-10 10:09:34 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Joan Puigcerver
f02d5b4694
Support differentiation through jax.lax.all_to_all (#3733)
* Support differentiation through jax.lax.all_to_all

Credit to @levskaya for the solution.

* Test gradient of all_to_all

We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332.

* Removed trailing spaces
2020-07-14 15:45:49 -07:00
Matthew Johnson
1034f29de7
fix bad pmap tests from #3675 (#3685) 2020-07-07 14:48:54 -07:00
Jake VanderPlas
9c3e6c3002 AssertionError -> ValueError 2020-07-07 13:21:44 -07:00
Jake VanderPlas
4711589cf8 fix pmap test on GPU/TPU 2020-07-07 13:19:19 -07:00
Matthew Johnson
d2ebb6eb19
fix ppermute test bugs found by @jekbradbury (#3675) 2020-07-07 00:30:08 -07:00
Matthew Johnson
796df9c550
make psum transpose handle zero cotangents (#3653)
make psum transpose handle zero cotangents

fixes #3651
2020-07-03 10:00:25 -07:00
Jake Vanderplas
09d128edb3
Cleanup: remove some test interdependence (#3600) 2020-06-29 16:22:05 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
86fcfbfa1a
Fix memory leak when no axis is provided to pmap. (#3394)
* Fix memory leak when no axis is provided to pmap.

* Work around flake8 false positive.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-23 09:29:58 -04:00
Skye Wanderman-Milne
8f4ba7e679
Allow specifying both devices and axis_size to pmap. (#3475)
This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform.
2020-06-19 15:51:12 -07:00
Matthew Johnson
d4c6cb62ab print warning when doing jit-of-pmap 2020-06-15 21:37:30 -07:00
Matthew Johnson
fcfcffe334 add systematic tests for vmap-of-pmap
fixes #3440

Also re-applies the fix in #3439 (i.e it rolls-back the rollback PR #3448) because we're now confident it's correct (and some internal tests are buggy).
2020-06-15 09:10:40 -07:00
Matthew Johnson
12ce6e3768 roll back of #3439 while we debug internal fails 2020-06-15 07:32:42 -07:00
Matthew Johnson
29fa935ca5 fix vmap-of-pmap mapped_invars logic bug
fixes #3399

This crept in via #1959, but more importantly it shows we don't have
good test coverage here!
2020-06-14 14:45:29 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
James Bradbury
f1a7073738
pmap(in_axes=None) of sharded_jit (#3257)
* pmap(in_axes=None) of sharded_jit

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>

* address comments

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-06-01 16:50:22 -07:00
Jake Vanderplas
0eab5609bb
Fix duplicated test name (#3273) 2020-06-01 15:28:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Skye Wanderman-Milne
a3e0cd1293
Fix pxla.shard_args bug (#3170) 2020-05-21 13:52:03 -07:00
Matthew Johnson
ccb203c894
improve pmap unbound axis error, fixes #3120 (#3152) 2020-05-19 15:51:07 -07:00
Matthew Johnson
bc47a32c69
make lax.psum promote bool -> int (#3150)
* make lax.psum promote bool -> int, fixes #3123

* fix test bug

* fix typo in test
2020-05-19 15:41:03 -07:00
George Necula
bc2d2c8ac9
Fix uses of deprecated onp. in pmap_test (#3028) 2020-05-10 14:25:18 +03:00
James Bradbury
f60184e12e
Support axis_index_groups in allreduce collectives (#2382)
* support replica groups in allreduce collectives

* add test and fix jaxpr in docs

* switch from XLA replica IDs to JAX axis indices

* fix psum transpose rule

* test other nesting order + imperfect nesting

* update jaxpr.rst

* handle None case

* add note+check that groups  cover the index space

* switch split_axis assert to NotImplementedError

* update CHANGELOG
2020-05-08 14:00:34 -07:00
Skye Wanderman-Milne
778d7ffcc6
Fix some bugs in _shards_device_array path. (#2983)
Also adds more comprehensive unit tests.
2020-05-06 10:19:28 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00