52 Commits

Author SHA1 Message Date
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
Jake VanderPlas
ae9f8eeb0c [typing] annotate lax.slicing 2022-10-09 04:20:46 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
Peter Hawkins
eed327914e Improve documentation for unique_indices. 2022-09-23 09:11:15 -04:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
34eae3de88 jax.lax: ensure GatherDimensionNumbers contains tuples for hashability 2022-09-12 12:10:17 -07:00
Matthew Johnson
58826507cc [dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
2022-09-08 17:52:12 -07:00
Roy Frostig
8f045b12d6 internal rename: swap mentions of "custom eltypes" for "opaque dtypes"
Also, avoid direct set membership tests on `core.opaque_dtypes`. Update
callers to use `core.{is,has}_opaque_dtype` predicates instead.
2022-08-30 16:52:08 -07:00
Roy Frostig
73bf0aa30c access rules through a hidden attribute of opaque dtype 2022-08-30 14:06:01 -07:00
Matthew Johnson
42dd7cac43 simplify slicing jaxprs a little
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Roy Frostig
c9d3094014 defer to custom eltype for gather lowering rule
This also adds an `mlir.delegate_lowering` helper that takes optional
replacements to the lowering context and restores/updates it on
return. This is used to define a gather lowering rule in tests here,
by calling out to the base gather lowering rule. I picked it off of a
working branch that uses it as well (for key-eltyped lowering rules).
2022-08-12 14:28:31 -07:00
Roy Frostig
7955799ae3 defer to custom eltype for slice lowering rule
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
2022-08-09 19:13:34 -07:00
Matthew Johnson
81b6263ed0 Rolling forward #11768 after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
jax authors
c4b255b527 Merge pull request #11580 from jakevdp:fix-dynamic-index
PiperOrigin-RevId: 463311212
2022-07-26 05:28:47 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
Jake VanderPlas
88b0d198ec dynamic_slice: correctly handle negative start indices in autodiff 2022-07-21 13:41:00 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Matthew Johnson
004b59fbc9 [dynamic-shapes] basic linearize and grad working 2022-06-30 14:30:22 -07:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Peter Hawkins
823aae0ccb Verify that the indices passed to lax.dynamic_slice and lax.dynamic_update_slice are scalars.
We were previously missing a shape check for this, meaning that if non-scalars were passed we would receive a cryptic MHLO shape error rather than a helpful Python backtrace.

PiperOrigin-RevId: 452682084
2022-06-02 20:41:45 -07:00
Anish Tondwalkar
fe37af5b6a [mhlo] Make variadic Scatter legal
`mhlo.scatter` as before can still take the three arguments `(operand, scatter_indices, updates)`, but now it can operate on multiple operands with their own sets of updates, `(operands..., scatter_indicies, updates...)`

Following the change in HLO,
6c9157ffb6
In order to support JAX FR: https://github.com/google/jax/issues/10079

This CL updates the op definition to allow for variadic scatter, and stubs out all clients of ScatterOp so that support can be implemented in following CLs.

PiperOrigin-RevId: 452559990
2022-06-02 09:50:22 -07:00
Alex Zinenko
c888a7e283 Fix JAX after upstream MLIR Python API change
Autogenerated MLIR Python API was changed to only accept optional operation
arguments as keyword arguments in Python.

PiperOrigin-RevId: 450651273
2022-05-24 04:32:54 -07:00
jax authors
6252377d19 Integrate LLVM at llvm/llvm-project@c8e0870829
Updates LLVM usage to match
[c8e087082927](https://github.com/llvm/llvm-project/commit/c8e087082927)

PiperOrigin-RevId: 450576923
2022-05-23 19:11:43 -07:00
Xin Zhou
35c08626cf [mhlo] Add result type inference for mhlo.dynamic-slice.
PiperOrigin-RevId: 446366018
2022-05-03 21:58:29 -07:00
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00
Peter Hawkins
7c6a550333 Change the default scatter mode to FILL_OR_DROP.
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.

This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.

After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.

Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.

Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
2022-04-27 12:26:55 -07:00
YouJiacheng
75e990bbc3 Fix typo in _scatter_add_lower_gpu
a87b21148c doesn't notice `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
I follow the change done in that commit for _scatter_lower.
2022-04-22 23:55:11 +08:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Jake VanderPlas
e13c847e04 Index update operators: add scatter_apply() 2022-02-18 09:44:40 -08:00
Peter Hawkins
6388d53eca Disallow scatter_mul gradients if unique_indices=False.
The current gradients are incorrect if unique_indices=False. No gradient is better than an incorrect gradient.

https://github.com/google/jax/issues/9296

PiperOrigin-RevId: 423917753
2022-01-24 14:55:02 -08:00
George Necula
34ba42b5da [jax2tf] Improve shape polymorphism for scatter with CLIP mode
Fixes: #9231
2022-01-24 09:37:49 +01:00
Jake VanderPlas
4832f09981 lax.dynamic_update_slice: fix batching rule 2022-01-18 10:07:22 -08:00
Peter Hawkins
a87b21148c [MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context.

Change it to

```
def rule(ctx, *args, **jaxpr_params)
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.

PiperOrigin-RevId: 416698663
2021-12-15 19:06:58 -08:00
Peter Hawkins
53318a2a7a [MLIR] Support all fill_modes in GPU MLIR lowering for scatter_add.
PiperOrigin-RevId: 415617659
2021-12-10 15:01:08 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Peter Hawkins
34ec805698 [MLIR] Fix test failures on GPU and TPU.
PiperOrigin-RevId: 413226939
2021-11-30 13:11:01 -08:00
Peter Hawkins
fa411d864e [MLIR] Fix CPU test failures for MLIR lowering.
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.

Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
2021-11-30 06:08:55 -08:00
George Necula
a94a8847c3 [jax2tf] Implement CLIP mode for scatter.
Also handle a limited form of shape polymorphism, where
the `operand.shape - update.shape` is a constant in the scatter dimensions,
even when the shapes may contain dimension variables.
2021-11-26 10:01:42 +02:00