249 Commits

Author SHA1 Message Date
Yash Katariya
96058d0197 Add support for MeshPspecSharding local_sharded_result_handler because SDA outputs from pjit can produce a MeshPspecSharding.
PiperOrigin-RevId: 470119499
2022-08-25 17:14:05 -07:00
Roy Frostig
acc025a268 minimal result-handling support for single-device key array pjit outputs
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 470054082
2022-08-25 12:23:19 -07:00
Roy Frostig
8e2d1be0a5 support jax.experimental.array.Array as a base array for key arrays
Only handle host-locally sharded `Array`s for now (like in SDAs under
`pmap`). Leaving global sharding for a follow-up.

Also re-enable a previously skipped test as a result.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 469885160
2022-08-24 19:49:02 -07:00
Peter Hawkins
160a6c5229 Suppress msan failure in PRNG code.
Use np.zeros instead of np.empty for code that builds an IR constant.

PiperOrigin-RevId: 469566082
2022-08-23 15:05:06 -07:00
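
The `np.zeros` versus `np.empty` distinction matters because `np.empty` returns uninitialized memory, and reading any element that was never written is what a memory sanitizer flags. The commit does not say which elements were left unwritten, so the following is only a minimal sketch of the general hazard; `build_padded_constant` is a hypothetical helper, not a function from JAX.

```python
import numpy as np

def build_padded_constant(values, size):
    """Hypothetical helper: copy `values` into a fixed-size uint32 buffer."""
    # np.zeros initializes every element, so slots that are not overwritten
    # below still hold a defined value (0). With np.empty they would hold
    # whatever bytes happened to be in memory.
    buf = np.zeros(size, dtype=np.uint32)
    buf[:len(values)] = values
    return buf

const = build_padded_constant([7, 11], 4)
```

The cost of zero-filling is usually negligible for small constants, which makes `np.zeros` the safer default whenever the buffer is not fully overwritten.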
Roy Frostig
6071a8f875 roll-forward #11952, take 2
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`

PiperOrigin-RevId: 469276609
2022-08-22 13:57:31 -07:00
jax authors
3a2f25ff31 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468840334
2022-08-19 21:02:18 -07:00
Roy Frostig
9789e83b26 roll-forward #11952
... with a small adjustment, resetting the `random.PRNGKeyArray` type
during Python typechecking.

PiperOrigin-RevId: 468835674
2022-08-19 20:12:32 -07:00
jax authors
a6c6416872 Internal change
PiperOrigin-RevId: 468712508
2022-08-19 08:56:49 -07:00
Roy Frostig
34b63dfc77 teach jax2tf about custom eltypes, key arrays, and random key primitives
Specifically:

* Introduce a `physical_avals` view as a custom eltype method. This is
  analogous to the existing `aval_to_ir_types`, but where the output
  is an aval with a non-custom eltype (and hence a direct
  correspondence to TF and to lowerings).

* Change jax2tf to continue tracing with logical avals, but to
  maintain TF tensors of corresponding physical shape/dtype, and to
  translate to TF operations based on physical avals where relevant.

* Fix up various TF impl rules to follow physical avals. To this end,
  add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
  carries out the conversion using physical rather than logical avals.

* Write TF impl rules for `random_{seed,split,fold_in,bits}`
  primitives. To this end, factor out the part of these primitives'
  impl rules that operates on the base array, and convert that part
  by passing it through `_convert_jax_impl` in physical mode.

* Teach the jax2tf test harness how to unwrap key-array-typed outputs
  into physical `uint32` arrays that it can use in comparison tests.
2022-08-18 21:46:55 -07:00
Roy Frostig
affb031212 defer to custom eltype for sharded result handling, use this to handle sharded key arrays 2022-08-18 21:46:55 -07:00
Roy Frostig
7f06df1ea1 introduce key-element-type arrays and overhaul the Python PRNG key array type
Before this change, the Python PRNG key array was a pytree type
wrapping a `uint32` array. This was a stopgap that misbehaved under
`vmap`, `scan`, and even `jax.tree_map`. For a while, we thought we
might rely on something like the typeclass mechanisms in development
(e.g. `vmappable`) to move away from a pytree.

We're now taking a different approach: introducing key element types
into our IR and other internal machinery. During staging, we map
user-facing PRNG key arrays to abstract arrays with such an element type.

This leans heavily on our recently-introduced extended element type
capabilities.

As a consequence, `vmap`, `scan`, etc. now work.

A sample of changes made to introduce key-element-type arrays:

* Introduce a new element type (`prng.KeyTy`), with the requisite IR
  type mapping and device result handlers, as well as lowering rules
  for dtype-polymorphic primitive operations.

* Introduce primitives for basic RNG operations: `random_seed`,
  `random_bits`, `random_split`, `random_fold_in`. These primitives
  essentially delegate to the underlying PRNG implementation (directly
  so in their impl rules, and by translating their staged-out form in
  lowering rules).

* Also introduce `random_wrap` and `random_unwrap` for "unsafe"
  conversion from/to the base `uint32` array. We need this backwards
  compatibility, and it's useful for tests.

* Introduce some `vmap`-based helpers to adapt PRNG impls (which
  define basic `random_bits`, `split`, etc. on scalars) to the above
  batch-polymorphic primitives. Most of the primitives are vectorized,
  but `random_fold_in` is a broadcasting binary op.

* Update the `gamma` primitive rules to account for key-element-type
  abstract arrays (nice simplification here).

* Give PRNG implementations short string names ("tags") for IR
  pretty-printing.

* Update `lax.stop_gradient` to handle opaque dtypes.

* Fix up loop MLIR lowering, which assumed that shaped arrays of all
  dtypes have the same physical shape.

* Add new tests (exercising staging, jaxprs, lowerings, ...)

A sample of changes made to rework Python-level PRNG key arrays:

* Mimic `isinstance(x, KeyArray)` checks on abstract key arrays and
  tracers that carry them.

* Patch (only a subset of) standard device array attributes onto PRNG
  key arrays.

* Implement various conversion handlers (sharding, constant-creation,
  `device_put`).

* Accept PRNG key arrays as input to `lax_numpy.transpose`.

* Update tests and rename some internals.

A sample of extra changes along the way:

* Disallow AD on key-typed arrays in the main API.

* Hoist `random_bits`'s named-shape-handling logic, which used to only
  take place in the threefry PRNG's `random_bits` implementation, up
  to the new `random_bits` traceable, so that we apply it consistently
  across PRNG implementations.

This change leaves some unwanted `lax` and `jax.numpy` operations
superficially available on key arrays during tracing/staging
(e.g. under `jit`), though not outside of it. We ultimately want to
disallow these and raise useful errors, and I'm leaving that for
follow-up work. For now, applying such operations under `jit` may
result in downstream errors in the middle-end instead.

Everything here is still guarded by `config.jax_enable_custom_prng`,
whose default setting hasn't changed (it is off).
2022-08-18 21:46:55 -07:00
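
A minimal sketch of what key-element-type arrays look like from user code. It assumes a later JAX release that exposes `jax.random.key`, `jax.random.key_data`, and `jax.random.wrap_key_data`; these public names postdate this commit, which instead describes the internal `random_wrap` and `random_unwrap` primitives behind a config flag.

```python
import jax
import jax.numpy as jnp

# A typed key: a scalar array whose element type is a PRNG key.
key = jax.random.key(0)
keys = jax.random.split(key, 4)  # key array of logical shape (4,)

# vmap maps over the logical key dimension, not over raw uint32 words.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

# "Unsafe" conversion to and from the underlying uint32 base array.
raw = jax.random.key_data(keys)
rewrapped = jax.random.wrap_key_data(raw)
```

The base array carries one extra trailing dimension for the words of each key, which is why the logical shape `(4,)` and the physical shape differ.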
Roy Frostig
acb5e491ab sketch: setup for new key array implementation based on eltypes
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 20:54:08 -07:00
Penn
1987ca7389 Add dtype arg to jnp.concatenate and update tests 2022-08-01 15:48:40 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
George Necula
4b03ebf4f5 Fix overflow of large prng computation
Fixes: #11010
2022-06-20 10:48:15 +02:00
Jake VanderPlas
2b6387f83f [x64] make jax.random compatible with jax_numpy_dtype_promotion=strict 2022-05-27 11:12:39 -07:00
Peter Hawkins
4618f9ce03 Consolidate hip_prng and cuda_prng.
The Python code in jaxlib to build AMD HIP (ROCm) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.

This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.

PiperOrigin-RevId: 446761784
2022-05-05 10:55:29 -07:00
Peter Hawkins
beb420bd5d Revert: https://github.com/google/jax/pull/10221 (2nd revert)
Prefer jnp.tile over concatenate.

jnp.tile generates a jaxpr like the following:
```
{ lambda ; a:i32[720192]. let
    b:i32[1,720192] = reshape[dimensions=None new_sizes=(1, 720192)] a
    c:i32[720192] = squeeze[dimensions=(0,)] b
    d:i32[2,720192] = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(2, 720192)
    ] c
    e:i32[1440384] = reshape[dimensions=None new_sizes=(1440384,)] d
  in (e,) }
```

whereas lax.concatenate generates the following jaxpr:
```
{ lambda ; a:i32[720192]. let
    b:i32[1440384] = concatenate[dimension=0] a a
  in (b,) }
```

It seems the TPU compiler isn't doing as good a job with laying out memory for the formulation with `jnp.tile`. `reshape` in particular can be difficult for it to handle well, and it's best to avoid it when possible.

Since the benefit was marginal (a simpler jaxpr... but is it? Really?) and the cost is real (a user's model broke), we should revert this change.

PiperOrigin-RevId: 444287005
2022-04-25 09:16:12 -07:00
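
The comparison in this commit message can be reproduced on a small input. This is a sketch only: it uses a 6-element array instead of the 720192-element one above, the function names are illustrative, and the exact jaxpr text printed depends on the JAX version.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.int32)

def with_tile(a):
    # Repeat the whole array twice along its only axis.
    return jnp.tile(a, 2)

def with_concatenate(a):
    # Same result, expressed as a single concatenation.
    return jnp.concatenate([a, a])

tiled = with_tile(x)
concatenated = with_concatenate(x)

# Print both jaxprs to compare the staged-out operations.
print(jax.make_jaxpr(with_tile)(x))
print(jax.make_jaxpr(with_concatenate)(x))
```

Both functions return the same values, so the choice between them only affects the operations handed to the compiler.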
jax authors
b8971b9f28 Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
Prefer `jnp.tile` over `concatenate`

PiperOrigin-RevId: 442803459
2022-04-19 07:12:27 -07:00
jax authors
fc2a12c478 Temporarily revert fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
Prefer `jnp.tile` over `concatenate`

PiperOrigin-RevId: 442693096
2022-04-18 19:34:30 -07:00
jax authors
f6705fc269 Merge pull request #10221 from lgeiger:concat-tile
PiperOrigin-RevId: 442587085
2022-04-18 11:19:07 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Lukas Geiger
fff370d78d Prefer jnp.tile over concatenate 2022-04-18 10:55:30 +01:00
YouJiacheng
4ff6b1fbca Fix PRNGKeyArray.broadcast_to with scalar shape 2022-04-16 00:27:30 +08:00
Peter Hawkins
3bfa6af2c8 [MHLO] Add MHLO lowering for PRNG kernels.
PiperOrigin-RevId: 439919104
2022-04-06 13:23:01 -07:00
Roy Frostig
7824325c23 remove _broadcasting_shape_rule from public jax.lax module 2022-03-08 16:34:26 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
8b1d710202 custom_prng: better error messages for key validation 2022-03-04 10:49:29 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
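
A short sketch of these array operations applied to key arrays. It assumes a later JAX release with typed keys created by `jax.random.key`, rather than the `PRNGKeyArray` pytree class this commit modified, so the behavior shown is that of the current public API.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))

stacked = jnp.stack([k1, k2])                  # two scalar keys -> shape (2,)
expanded = jnp.expand_dims(k1, 0)              # scalar key -> shape (1,)
joined = jnp.concatenate([stacked, expanded])  # shape (3,)
```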
Jake VanderPlas
4d9e9b4986 custom_prng: generalize indexing of PRNGKeyArray
Co-authored-by: Roy Frostig <frostig@google.com>
2021-12-20 10:16:32 -08:00
Lena Martens
e14fea3b63 Overload jnp ops which are polymorphic to an array's value and support PRNGKeys. 2021-11-16 23:00:32 +00:00
Peter Hawkins
e783cbcb72 Port remaining translation rules inside JAX to new style.
PiperOrigin-RevId: 404288551
2021-10-19 09:48:37 -07:00
Peter Hawkins
1a73743610 Move xla_bridge.constant to jax.interpreter.xla.pyval_to_ir_constant.
This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.

PiperOrigin-RevId: 404270649
2021-10-19 08:40:51 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Roy Frostig
ba370f8c86 hide keys attribute of PRNGKeyArray in favor of unsafe_raw_array 2021-10-12 08:47:29 -07:00
Matthew Johnson
022cb8c0fc rbg_split and rbg_fold_in: use vmap for fewer HLOs 2021-10-07 21:19:06 -07:00
Matthew Johnson
634d252bb3 improvements to RBG PRNG
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
   threefry2x32 for split and fold_in, while the latter uses untested
   heuristics based on calling rng_bit_generator itself as a kind of
   hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
   sequences from rng_bit_generator (10x iterations) which may be useful on
   some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
   'hash function' in fold_in

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
2021-10-07 18:59:13 -07:00
Matthew Johnson
980dfcfd2c add experimental RngBitGenerator ("RBG") PRNG
This is an experimental API, and RngBitGenerator itself is not
guaranteed to be stable across compiler versions (or backends or
shardings), so let's assert that this JAX PRNG implementation may not
be stable across JAX versions.

Even without that kind of stability, this PRNG is still useful because
compared to effectful RNG primitives, like lax.rng_uniform, this RBG
PRNG will still work correctly with lax.scan and jax.checkpoint (while
still potentially being more performant on some platforms than JAX's
standard PRNG).
2021-10-01 19:04:14 -07:00
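
A minimal sketch of the `lax.scan` compatibility described above: the key is threaded through the loop as ordinary carried state. It assumes a later JAX release in which the implementation is selected with `jax.random.key(seed, impl="rbg")`; that selection API postdates this commit, and `step` is an illustrative name.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0, impl="rbg")

def step(carry_key, _):
    # Split inside the loop body; the new carry key feeds the next iteration.
    carry_key, subkey = jax.random.split(carry_key)
    return carry_key, jax.random.uniform(subkey, (2,))

final_key, draws = jax.lax.scan(step, key, None, length=5)
```

Because the generator state is an explicit value rather than a side effect, the loop body stays a pure function, which is the property `lax.scan` and `jax.checkpoint` rely on.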
jax authors
ef696a0b43 Merge pull request #8019 from hawkinsp:pprint
PiperOrigin-RevId: 399424971
2021-09-28 06:26:24 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Lena Martens
327e00a668 PRNGKeys can have dtype float0 if they are a tangent. 2021-09-24 16:37:56 +01:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Roy Frostig
2d7a98beaf make PRNGKeyArray.shape a property, add tests for it and for disallowed addition 2021-09-10 18:45:18 -07:00
Matthew Johnson
e416e87301 inline jit-decorated jax.random calls 2021-08-20 13:43:38 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Roy Frostig
aa265cce95 introduce custom PRNG implementations and an array-like adapter for them
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.

A PRNG implementation can then be lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:

1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
   back to it from the functions in random.
2021-08-19 20:43:11 -07:00
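
The two roles described above can be illustrated with a toy model in plain Python. Everything here is hypothetical and written for illustration: `ToyPRNGImpl`, `ToyKeyArray`, `toy_random_bits`, and the hash-based "generator" are not JAX code, and the hash is not a suitable PRNG.

```python
import hashlib
from typing import Callable, NamedTuple, Tuple

class ToyPRNGImpl(NamedTuple):
    """Hypothetical record: a key shape plus the four basic functions."""
    key_shape: Tuple[int, ...]
    seed: Callable
    split: Callable
    random_bits: Callable
    fold_in: Callable

def _hash_words(*words):
    """Hash the arguments into two 32-bit words (toy only, not a real PRNG)."""
    digest = hashlib.sha256(repr(words).encode()).digest()
    return (int.from_bytes(digest[:4], "big"), int.from_bytes(digest[4:8], "big"))

toy_impl = ToyPRNGImpl(
    key_shape=(2,),
    seed=lambda s: _hash_words("seed", s),
    split=lambda key, num: [_hash_words("split", key, i) for i in range(num)],
    random_bits=lambda key, n: [_hash_words("bits", key, i)[0] for i in range(n)],
    fold_in=lambda key, data: _hash_words("fold", key, data),
)

class ToyKeyArray:
    """Hypothetical wrapper: a 1-D collection of keys plus the impl behind them."""
    def __init__(self, impl, keys):
        self.impl = impl
        self.keys = list(keys)

    @property
    def shape(self):
        return (len(self.keys),)

    def split(self, num):
        # Role 2: delegate back to the implementation for every key held.
        return [ToyKeyArray(self.impl, self.impl.split(k, num)) for k in self.keys]

def toy_random_bits(key_array, n):
    # A "public API" function talks only to the wrapper, never to one specific PRNG.
    return [key_array.impl.random_bits(k, n) for k in key_array.keys]

root = ToyKeyArray(toy_impl, [toy_impl.seed(0)])  # Role 1: keys lifted to an "array"
children = root.split(3)[0]
bits = toy_random_bits(children, 4)
```

Swapping in a different record with the same four functions changes the generator without touching `toy_random_bits`, which is the point of carrying the implementation alongside the keys.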
Roy Frostig
b4ccecca88 factor PRNG routines from random module to prng 2021-08-17 19:27:31 -07:00