67 Commits

Author SHA1 Message Date
Peter Hawkins
def35b7e24 Remove scatter/gather dimension proto helpers.
These are unused since the MHLO switch.

PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
George Necula
7e0041c903 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 502568204
2023-01-17 06:26:21 -08:00
George Necula
4d4eba4539 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 497171518
2022-12-22 08:34:47 -08:00
George Necula
2b716f292d Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 497079129
2022-12-21 21:59:51 -08:00
George Necula
ce5320a2e4 Copybara import of the project:
--
a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>:

Fix scatter in CLIP mode with uint32 and uint64 indices

Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

PiperOrigin-RevId: 496883024
2022-12-21 03:46:27 -08:00
George Necula
a74c74c255 Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.
2022-12-21 10:25:24 +02:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
George Necula
27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.

.
2022-12-12 05:19:04 +01:00
George Necula
2f1354ee04 Add workaround for imprecise shape inference for DynamicGatherOp
This is needed for gather in presence of dynamic shapes.

PiperOrigin-RevId: 494613303
2022-12-11 20:18:15 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip 2022-10-19 10:29:53 -07:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Jake VanderPlas
ae9f8eeb0c [typing] annotate lax.slicing 2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
Peter Hawkins
eed327914e Improve documentation for unique_indices. 2022-09-23 09:11:15 -04:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
34eae3de88 jax.lax: ensure GatherDimensionNumbers contains tuples for hashability 2022-09-12 12:10:17 -07:00
Matthew Johnson
58826507cc [dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
2022-09-08 17:52:12 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
Roy Frostig
73bf0aa30c access rules through a hidden attribute of opaque dtype 2022-08-30 14:06:01 -07:00
Matthew Johnson
42dd7cac43 simplify slicing jaxprs a little
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Roy Frostig
c9d3094014 defer to custom eltype for gather lowering rule
This also adds an `mlir.delegate_lowering` helper that takes optional
replacements to the lowering context and restores/updates it on
return. This is used to define a gather lowering rule in tests here,
by calling out to the base gather lowering rule. I picked it off of a
working branch that uses it as well (for key-eltyped lowering rules).
2022-08-12 14:28:31 -07:00
Roy Frostig
7955799ae3 defer to custom eltype for slice lowering rule
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
2022-08-09 19:13:34 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
jax authors
c4b255b527 Merge pull request #11580 from jakevdp:fix-dynamic-index
PiperOrigin-RevId: 463311212
2022-07-26 05:28:47 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
Jake VanderPlas
88b0d198ec dynamic_slice: correctly handle negative start indices in autodiff 2022-07-21 13:41:00 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Peter Hawkins
823aae0ccb Verify that the indices passed to lax.dynamic_slice and lax.dynamic_update_slice are scalars.
We were previously missing a shape check for this, meaning that if non-scalars were passed we would receive a cryptic MHLO shape error rather than a helpful Python backtrace.

PiperOrigin-RevId: 452682084
2022-06-02 20:41:45 -07:00
Anish Tondwalkar
fe37af5b6a [mhlo] Make variadic Scatter legal
`mhlo.scatter` as before can still take the three arguments `(operand, scatter_indices, updates)`, but now it can operate on multiple operands with their own sets of updates, `(operands..., scatter_indicies, updates...)`

Following the change in HLO,
6c9157ffb6
In order to support JAX FR: https://github.com/google/jax/issues/10079

This CL updates the op definition to allow for variadic scatter, and stubs out all clients of ScatterOp so that support can be implemented in following CLs.

PiperOrigin-RevId: 452559990
2022-06-02 09:50:22 -07:00
Alex Zinenko
c888a7e283 Fix JAX after upstream MLIR Python API change
Autogenerated MLIR Python API was changed to only accept optional operation
arguments as keyword arguments in Python.

PiperOrigin-RevId: 450651273
2022-05-24 04:32:54 -07:00
jax authors
6252377d19 Integrate LLVM at llvm/llvm-project@c8e0870829
Updates LLVM usage to match
[c8e087082927](https://github.com/llvm/llvm-project/commit/c8e087082927)

PiperOrigin-RevId: 450576923
2022-05-23 19:11:43 -07:00
Xin Zhou
35c08626cf [mhlo] Add result type inference for mhlo.dynamic-slice.
PiperOrigin-RevId: 446366018
2022-05-03 21:58:29 -07:00
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00
Peter Hawkins
7c6a550333 Change the default scatter mode to FILL_OR_DROP.
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.

This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.

After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.

Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.

Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
2022-04-27 12:26:55 -07:00
YouJiacheng
75e990bbc3 Fix typo in _scatter_add_lower_gpu
a87b21148c doesn't notice `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
I follow the change done in that commit for _scatter_lower.
2022-04-22 23:55:11 +08:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00