468 Commits

Author SHA1 Message Date
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
jax authors
b09a6175fb Merge pull request #12217 from mattjj:dce-and-execute-trivial
PiperOrigin-RevId: 472856448
2022-09-07 17:41:16 -07:00
Matthew Johnson
3c811b1520 fix bugs, infeed/outfeed must be considered effectful
Co-authored-by: Yash Katariya <yashkatariya@google.com>
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-06 15:13:01 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
Roy Frostig
73bf0aa30c access rules through a hidden attribute of opaque dtype 2022-08-30 14:06:01 -07:00
Matthew Johnson
bbb8048d2e Add batching rules for state primitives and for_loop
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
34b63dfc77 teach jax2tf about custom eltypes, key arrays, and random key primitives
Specifically:

* Introduce a `physical_avals` view as a custom eltype method. This is
  analogous to the existing `aval_to_ir_types`, but where the output
  is an aval with a non-custom eltype (and hence a direct
  correspondence to TF and to lowerings).

* Change jax2tf to continue tracing with logical avals, but to
  maintain TF tensors of corresponding physical shape/dtype, and to
  translate to TF operations based on physical avals where relevant.

* Fix up various TF impl rules to follow physical avals. To this end,
  add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
  carries out the conversion using physical rather than logical avals.

* Write TF impl rules for `random_{seed,split,fold_in,bits}`
  primitives. To this end, factor out the part of these primitives'
  impl rules that operates on the base array and convert that, pass it
  through `_convert_jax_impl` in physical mode.

* Teach the jax2tf test harness how to unwrap key-array-typed outputs
  into physical `uint32` arrays that it can use in comparison tests.
2022-08-18 21:46:55 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays such element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementation short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
Matthew Johnson
a7f760d9ed Working multihost eager pmap
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-08-15 10:21:56 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
Matthew Johnson
fbf6aa2a16 small tweaks for bint ad 2022-08-05 08:04:50 -07:00
lenamartens
53dfe35f34 Fix ConcretizationError in nested calls. 2022-07-26 20:31:59 +01:00
Matthew Johnson
7cb5c2447e [dynamic-shapes] fix minor bint bugs
Co-authored-by: Eugene Burmako <burmako@google.com>
2022-07-19 16:38:40 -07:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
George Necula
b6c90693c6 Fix mypy annotations 2022-07-05 12:49:37 +03:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
jax authors
33f1f40b20 Merge pull request #11298 from pschuh:axis-cache-env
PiperOrigin-RevId: 458328457
2022-06-30 15:42:48 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Parker Schuh
6c5d204d7e Jax caches should depend on axis env. 2022-06-29 14:25:14 -07:00
Matthew Johnson
5f97dc8954 Roll forward with simple fix: handle Zero cotangents in _broadcast_in_dim
transpose rule (previously handled by the deflinear2 wrapper, which it's no
longer using).

PiperOrigin-RevId: 456874635
2022-06-23 15:30:22 -07:00
jax authors
e4d1e1beb3 Copybara import of the project:
--
a001c52f878824cd1c0a67c73d9d318ed30286c9 by Matthew Johnson <mattjj@google.com>:

[dynamic-shapes] basic jvp working, including with broadcast

PiperOrigin-RevId: 456822732
2022-06-23 11:32:30 -07:00
Matthew Johnson
a001c52f87 [dynamic-shapes] basic jvp working, including with broadcast 2022-06-18 13:38:48 -07:00
jax authors
5f849d3aaa Merge pull request #11116 from mattjj:djax-typecheck
PiperOrigin-RevId: 455706708
2022-06-17 15:23:19 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Sharad Vikram
5d3f48204d Add stateful for loop primitives (#10982)
Adds a `get/swap/addupdate` primitive, along with impl, abstract_eval
and jvp rules.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-15 15:55:38 -07:00
Jean-Baptiste Lespiau
bab8520d0c Initialize the thread-local compilation context when undefined in new threads.
PiperOrigin-RevId: 452119314
2022-05-31 12:57:48 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
jax authors
87d2474cdf Merge pull request #10659 from jakevdp:devicearray-pickle
PiperOrigin-RevId: 449995717
2022-05-20 08:59:07 -07:00
Sharad Vikram
94e719935b Make Effect a hashable type 2022-05-19 12:33:15 -07:00
Jake VanderPlas
991ad72e24 DeviceArray: Improve support for copy, deepcopy, and pickle 2022-05-19 12:00:58 -07:00
jax authors
23eea5ddad Merge pull request #10756 from mattjj:10750
PiperOrigin-RevId: 449597253
2022-05-18 15:56:13 -07:00