39 Commits

Author SHA1 Message Date
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
George Necula
2299f05b8b [shape_poly] Cleanup the evaluation of dynamic shapes
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:

```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```

Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
2023-06-25 18:20:50 +02:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Matthew Johnson
61b106ec8f allow lax.dot_general to accept different input dtypes
This change brings the dot_general primitive more in line with the HLO
primitive, as it is described in XLA's shape_inference.cc (but not in the
StableHLO spec). In particular we allow different input dtypes.

The main motivation is to support transposition in the presence of
preferred_element_type (which can set the output dtype to be different from the
inputs), e.g. to fix #10818.

However, because XLA platforms/backends can't seem to codegen all the cases
that are accepted by shape_inference.cc, in our lowering rules we generate
ConvertElementTypes on the inputs in a platform-dependent way.
2023-05-22 10:33:42 -07:00
George Necula
1429dd5be2 [shape_poly] Remove old test limitations
When we create "vmap"-based test harnesses from primitive harnesses
we used to exclude certain primitives. We reduced the list to one
primitive, "tridiagonal_solve" for which vmap is not defined.

We have also added a more explicit error about certain unsupported
dynamic shape features for convolution (waiting for StableHLO feature).
2023-05-10 13:38:24 +02:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
John QiangZhang
171b22dbbc Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990
PiperOrigin-RevId: 518984018
2023-03-23 15:50:16 -07:00
George Necula
582c042079 Implement lowering for convolutions with dynamic padding
PiperOrigin-RevId: 509451627
2023-02-14 00:55:45 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
Jake VanderPlas
671c72a782 Update signature of ad.defbilinear to simplify transpose rules 2023-01-31 09:07:39 -08:00
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
6c59d72c75 Bump the minimum jaxlib version to 0.3.15. 2022-09-08 16:43:46 -04:00
Peter Hawkins
6ddf3c4d97 Reapply: Use MLIR bytecode when passing IR to backends.
MLIR bytecode is more compact to represent and should be faster to generate and parse.

The previous attempt at this change broke for 0D convolutions. JAX was not ensuring that the padding attribute had the correct [N, 2] shape when N was 0.

PiperOrigin-RevId: 472991661
2022-09-08 08:11:16 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
Jake VanderPlas
489596c0e2 lax.conv_general_dilated: validate negative paddings 2022-07-19 11:15:18 -07:00
Anish Tondwalkar
a2f2d1fa42 [mhlo] ConvOp -> ConvolutionOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
jax authors
6c60571eaf Merge pull request #10933 from hawkinsp:unzip
PiperOrigin-RevId: 452429103
2022-06-01 18:20:52 -07:00
Peter Hawkins
9be53caa43 Use util.unzipN() in more places instead of zip(*args). 2022-06-01 16:28:31 -04:00
Peter Hawkins
ece9b999fb Fix batching rule for convolution for batch dimensions of size 0. 2022-06-01 14:18:16 -04:00
Benjamin Kramer
bf5d38c213 [MLIR] Explicitly name arguments on ConvOp
These were shuffled around by
9b79f50b59,
creating incompatibilities between different mhlo versions.

PiperOrigin-RevId: 450648004
2022-05-24 04:06:41 -07:00
jax authors
6252377d19 Integrate LLVM at llvm/llvm-project@c8e0870829
Updates LLVM usage to match
[c8e087082927](https://github.com/llvm/llvm-project/commit/c8e087082927)

PiperOrigin-RevId: 450576923
2022-05-23 19:11:43 -07:00
Peter Hawkins
44f1e05a76 Add input validation for the padding argument to lax.conv_general_dilated.
Fixes #10729
2022-05-17 08:52:37 -04:00
Peter Hawkins
21e1f8c3d1 [JAX] Delete last references to conv/dot translation rules.
Replace references with MHLO equivalents.

PiperOrigin-RevId: 442675847
2022-04-18 17:42:47 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
a87b21148c [MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context.

Change it to

```
def rule(ctx, *args, **jaxpr_params)
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.

PiperOrigin-RevId: 416698663
2021-12-15 19:06:58 -08:00
Peter Hawkins
12dddac96a [MLIR] Make two dtype fixes.
* when converting from a non-bool type to a boolean, lower it as x != 0 rather than convert(x, i1). Convert has truncation semantics, but we are expecting XLA's x != 0 semantics instead.
* revert https://github.com/google/jax/pull/8825 and part of https://github.com/google/jax/pull/8810. PR https://github.com/google/jax/pull/8828 means that we now will never have a non-canonical preferred_element_type, and so the output type is once again always equal to the preferred element type.

PiperOrigin-RevId: 414716056
2021-12-07 07:15:55 -08:00
Peter Hawkins
16a663f427 Canonicalize dot/conv preferred_element_type during tracing.
This avoids non-canonical types showing up in surprising places.

It is possible that some users are specifying a 64-bit type here intentionally, but that seems unlikely. The fix in that case would be to disable non-x64 mode.

PiperOrigin-RevId: 414511197
2021-12-06 12:21:30 -08:00
Peter Hawkins
22fb38c848 [MLIR] Handle preferred_element_type correctly in convolution lowering rule.
Similar to the fix to dot_general in https://github.com/google/jax/pull/8810

This is hard to detect from a direct test, except by inspecting the IR, which I'd rather avoid. However the jax2tf tests already catch it since they have a very tight test tolerance.

PiperOrigin-RevId: 414479170
2021-12-06 10:18:52 -08:00
Peter Hawkins
fa411d864e [MLIR] Fix CPU test failures for MLIR lowering.
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.

Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
2021-11-30 06:08:55 -08:00
Jake VanderPlas
50935b56ab conv_transpose: allow padding to be specified as a list 2021-11-29 10:31:51 -08:00
Peter Hawkins
839d410de0 [MLIR] Move most MLIR translation rules into lax.
PiperOrigin-RevId: 411942327
2021-11-23 18:58:28 -08:00
Peter Hawkins
4204a25c91 Split convolution functions out of jax._src.lax.lax and into a separate module (jax._src.lax.convolution).
No public API changes.

PiperOrigin-RevId: 411871903
2021-11-23 12:35:50 -08:00