439 Commits

Author SHA1 Message Date
Jake VanderPlas
0d9367972b jax.jacobian: propagate function signature to transformed function 2022-10-04 10:21:54 -07:00
Yash Katariya
163b7e22d2 Convert shardings in jit path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in jit.
PiperOrigin-RevId: 477626548
2022-09-28 21:17:29 -07:00
Matthew Johnson
b175e11731 [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays
fixes #12542

Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
2022-09-27 20:51:07 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Roman Ring
8bcf358fde
Remove unused _remat_static_argnums import. 2022-09-26 17:14:09 +01:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
a157982e8c Make jit(f).lower(*args) go via lower_sharding_computation when jax_array is enabled.
PiperOrigin-RevId: 476148608
2022-09-22 11:13:33 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
2547e8110b Use C++ Array in pmap path and move PmapSharding to cpp
PiperOrigin-RevId: 474151089
2022-09-13 16:19:18 -07:00
jax authors
5b65df0626 Merge pull request #12296 from hawkinsp:minver
PiperOrigin-RevId: 473080556
2022-09-08 14:15:38 -07:00
Kuangyuan Chen
0400db959b Introduce class PyArray that contains the data members of python Array.
A few key methods is implemented in C++ while the rest are still implmemented in python and added to the class later. A class decorator, @use_cpp_array, is added to add python methods to xc.Array.

PiperOrigin-RevId: 473075244
2022-09-08 13:48:28 -07:00
Peter Hawkins
6c59d72c75 Bump the minimum jaxlib version to 0.3.15. 2022-09-08 16:43:46 -04:00
George Necula
fe055d06ba Allow get_aval to work on ShapeDtypeStruct
This is necessary to be able to call jit(f).lower(ShapeDtypeStruct(...) when
--jax_dynamic_shapes is on. The code in partial_eval.infer_lambda_input_type
calls get_aval.
2022-09-04 12:11:05 +03:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
jax authors
c1217be32f Merge pull request #11954 from mattjj:for-vmap
PiperOrigin-RevId: 470810499
2022-08-29 14:08:37 -07:00
Matthew Johnson
bbb8048d2e Add batching rules for state primitives and for_loop
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
Anselm Levskaya
3c9178745b Add runtime type check in named_scope to ensure that name is a string.
PiperOrigin-RevId: 470177071
2022-08-26 00:25:31 -07:00
Yash Katariya
e8ec454ae8 Enable fast path in the Array constructor. This means that the rearranging of _arrays according to the device_assignment won't happen when fastpath is enabled because we assume that jax transformations will return the right arrangement.
PiperOrigin-RevId: 469492283
2022-08-23 10:20:26 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
Yash Katariya
f905d989c1 Make eager pmap tests pass with Array. Also add a slow path for Array in pmap similar to what SDA has. This is required for eager pmap. Adding a slow path removes the need for doing sharding checks in api.py because SDA doesn't do those checks and if the sharding does not match with pmap sharding, then it just defaults to the slow path (exactly like SDA).
PiperOrigin-RevId: 468843310
2022-08-19 21:37:22 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
jax authors
f2d417152d Merge pull request #11952 from froystig:key-array-on-eltypes
PiperOrigin-RevId: 468694524
2022-08-19 07:11:27 -07:00
jax authors
764c268ff6 Increase the threshold to use tuple_args to 2000 args for TPUs.
PiperOrigin-RevId: 468675921
2022-08-19 04:57:34 -07:00
Roy Frostig
34b63dfc77 teach jax2tf about custom eltypes, key arrays, and random key primitives
Specifically:

* Introduce a `physical_avals` view as a custom eltype method. This is
  analogous to the existing `aval_to_ir_types`, but where the output
  is an aval with a non-custom eltype (and hence a direct
  correspondence to TF and to lowerings).

* Change jax2tf to continue tracing with logical avals, but to
  maintain TF tensors of corresponding physical shape/dtype, and to
  translate to TF operations based on physical avals where relevant.

* Fix up various TF impl rules to follow physical avals. To this end,
  add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
  carries out the conversion using physical rather than logical avals.

* Write TF impl rules for `random_{seed,split,fold_in,bits}`
  primitives. To this end, factor out the part of these primitives'
  impl rules that operates on the base array and convert that, pass it
  through `_convert_jax_impl` in physical mode.

* Teach the jax2tf test harness how to unwrap key-array-typed outputs
  into physical `uint32` arrays that it can use in comparison tests.
2022-08-18 21:46:55 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
jax authors
39d54bdbf6 Merge pull request #11928 from sharadmv:pure-callback
PiperOrigin-RevId: 468611094
2022-08-18 20:33:34 -07:00
Yash Katariya
f42151c3dc Take the pjit XLA compilation path for Arrays. In the test, astype happens in a sharded fashion without the round trip to host.
PiperOrigin-RevId: 468510366
2022-08-18 11:44:39 -07:00
Sharad Vikram
b0fdf10a63 Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Yash Katariya
acdae7c237 Add weak type support to Array. Also make all api_test.py tests pass with Array. I have disabled the float0 test for now until I investigate.
PiperOrigin-RevId: 468264910
2022-08-17 12:25:49 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
jax authors
9ca37c9e33 Merge pull request #11950 from mattjj:delete-old-remat
PiperOrigin-RevId: 468173667
2022-08-17 05:40:26 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
jax authors
0abbdd0648 Add a backend field to mlir.ModuleContext so that host callback lowering can use the correct backend
PiperOrigin-RevId: 468024979
2022-08-16 14:26:53 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
Sharad Vikram
fe040cc01e Cleaning up eager pmap implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 11:10:16 -07:00
Matthew Johnson
5310515c80 Initial implementation of eager pmap
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:21:55 -07:00
Sharad Vikram
88f2b5e86d Add functionality for "pure" callbacks
Also avoids using CPP dispatch path when host callbacks are involved

PiperOrigin-RevId: 467270949
2022-08-12 12:39:53 -07:00
Yash Katariya
18b6a32db2 Make all pmap tests pass with Array! I am skipping all soft pmap tests for now.
PiperOrigin-RevId: 467264992
2022-08-12 12:09:49 -07:00
Yash Katariya
33c4fc4fe2 Pmap should output SDA like Arrays to maintain the current behavior exactly. Split the shard_arg_handler for Array based on whether the mode is pmap or pjit. Why do this? The doc below explains more about the context.
PiperOrigin-RevId: 466849614
2022-08-10 20:11:37 -07:00
Roy Frostig
7d494a3852 update checkpoint attributes according to functools.wraps
This updates the signature in addition to `__doc__`, and that gets
picked up by generated API docs.
2022-08-10 13:33:07 -07:00
jax authors
8b2e4f975c Merge pull request #11825 from mattjj:fix-type-annotation
PiperOrigin-RevId: 466550958
2022-08-09 20:21:10 -07:00
Matthew Johnson
d76754e40e fix type annotation on remat 2022-08-09 19:57:40 -07:00
Parker Schuh
01df754630
Remove docs 2022-08-09 12:36:49 -07:00
Parker Schuh
8fb957350c Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs. 2022-08-08 19:41:42 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00