149 Commits

Author SHA1 Message Date
Yash Katariya
a37121e195 Don't depend on flatten_axis_resources which will error because flatten_axes passes a dummy object() which doesn't work with checks in user pytrees.
Only do this if the original {in|out}_shardings are _UNSPECIFIED.

PiperOrigin-RevId: 502792305
2023-01-18 00:13:04 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Adam Paszke
d742e6a410 Transpose all_gather to reduce_scatter
Also, add support for AD and batching of reduce_scatter (with its transpose being all_gather again).

PiperOrigin-RevId: 488706478
2022-11-15 11:03:22 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
Jake VanderPlas
9edbf9bb0d [x64] make batching_test compatible with strict dtype promotion 2022-06-17 16:08:54 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
34f116c0e0 vmap: preserve weak_type in batching tracer 2022-03-30 11:06:56 -07:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Matthew Johnson
c555f5f0e4 handle trivial case for ppermute batching rule
fixes #8688
2021-12-14 10:42:05 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Matthew Johnson
2cb235809a make vmap ppermute consistent with pmap/docstring
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
2021-11-18 14:02:49 -08:00
Matthew Johnson
50e7e952bd add internal vmappable interface (part 1)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-11-04 15:01:54 -07:00
Adam Paszke
49d9affce0 Enable batcher and batched collective rules for tiled all gathers
Fixes #8221.
2021-10-15 14:37:38 +00:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
jax authors
c365d7f91c Merge pull request #7908 from hawkinsp:api3
PiperOrigin-RevId: 396578038
2021-09-14 06:04:28 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
jax authors
9045672aea Merge pull request #7906 from sharadmv:pdot-precision
PiperOrigin-RevId: 396481634
2021-09-13 17:40:11 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Sharad Vikram
ebd8d95847 Add precision param for pdot 2021-09-13 16:28:31 -07:00
Adam Paszke
1c1ec79edd Clarify the error message for out-of-bounds in_axes in pmap and vmap
Fixes #5201.
2021-07-14 12:11:06 +00:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Adam Paszke
d0606463e4 Fix the batching rule for named reductions
PiperOrigin-RevId: 370505998
2021-04-26 11:41:58 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Roy Frostig
7427991819 skip scalars when broadcasting for batch dimension agreement 2021-03-19 21:47:16 -07:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Adam Paszke
8a4f0a8931 Make all_to_all primitive match XLA semantics
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
2021-03-05 18:18:49 +00:00
Peter Hawkins
ff3b402ec0 Improve error messages for invalid JAX types returned by batched functions. 2021-02-16 20:02:11 -05:00
Matthew Johnson
ffb3873e5a add pargmax, pargmin wrappers 2021-02-09 19:04:46 -08:00
Adam Paszke
1361ae1247 Add positional axis handling to the psum transpose rule
I must have forgotten to do that in one of the previous patches and
apparently we didn't have any tests for it (at least in the `vmap`
case)!
2021-02-05 10:59:41 +00:00
Daniel Johnson
15b95e3ff5 Use np.shape instead of assuming argument has a shape attr 2021-01-25 18:11:38 -05:00
Daniel Johnson
c6a1bba308 Add evaluation rule for all_gather.
This should only be called when an all_gather runs on arguments that
are not batch tracers, for instance when all_gather-ing a constant.
2021-01-25 17:27:39 -05:00
Daniel Johnson
7865043341 Improve batched collective rule for all_gather_p
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: simply "forget" that the axis was mapped,
and return the full array. Conveniently, this doesn't require any
explicit broadcasting, and makes it possible to use out_axes=None with
the results.
2021-01-25 16:52:38 -05:00
Matthew Johnson
304685a152 allow vmapped function to accept kwargs
Arguments passed as keywords are always batched along their leading
axis. The in_tree specification must correspond to arguments passed
positionally.

This brings vmap in line with pmap. That is, pmap already followed this
convention for arguments passed via keywords. Consistency is good!

I had to adapt some utility functions so as not to change the error
messages raised. In particular, we have tests for vmap error messages
which report the in_axes and argument tree structure; naively including
keyword arguments changed those error messages. The error messages are
worth preserving. This change also brought the pmap error messages in
line with the vmap ones.

I also did some 80char wrapping of lines and docstring updating.

Fixes #912. Another user had the same issue and reported the same
expected behavior.
2021-01-12 20:13:23 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
David Majnemer
a87978f094 Enable more TPU tests
PiperOrigin-RevId: 351210210
2021-01-11 12:23:36 -08:00
Jake VanderPlas
1a83bb6f90 Cleanup: remove remaining instances of rng_factory boilerplate 2020-12-11 13:47:46 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00
Matthew Johnson
58e441bed7 add experimental pdot primitive, basic tests 2020-11-27 11:18:01 -08:00