203 Commits

Author SHA1 Message Date
jax authors
797f577fb8 Expose mlir.ShapePolyLoweringState
PiperOrigin-RevId: 571075542
2023-10-05 11:21:09 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them is passed as a separate argument
through several layers of internal lowering functions. This is tedious
and error prone. In fact, override_lowering_rules was not plumbed
everywhere, and because default arguments are used in all places,
this led to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here we pack all such parameters into a `mlir.LoweringParameters`
dataclass and plumb that through.
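The parameter-object pattern described here can be sketched in plain Python. The field names below mirror the parameters named in this message, but the actual fields of `mlir.LoweringParameters` are not spelled out here, so treat this as an illustrative sketch rather than the real API:

```python
from dataclasses import dataclass
from typing import Optional, Sequence

# Sketch of a parameter object replacing separate keyword arguments
# threaded through many internal lowering functions.
@dataclass(frozen=True)
class LoweringParameters:
  # Hypothetical fields, mirroring the parameters named in the message.
  platforms: Optional[Sequence[str]] = None
  override_lowering_rules: Optional[Sequence] = None

def lower_fun(jaxpr, *, lowering_parameters: LoweringParameters):
  # Outer layers pass the whole dataclass down unchanged...
  return lower_jaxpr(jaxpr, lowering_parameters=lowering_parameters)

def lower_jaxpr(jaxpr, *, lowering_parameters: LoweringParameters):
  # ...so adding a new lowering parameter later means adding one field,
  # not editing every intermediate function signature.
  return (jaxpr, lowering_parameters.platforms)
```

The frozen dataclass also makes accidental in-place mutation of lowering configuration an error, which helps with the silent-failure problem described above.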
2023-09-29 08:23:05 +03:00
Peter Hawkins
d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.
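Benefit (b) above — sensible defaults over a verbose low-level builder — can be illustrated with a plain-Python sketch. All names and defaults below are hypothetical stand-ins, not the actual JAX or StableHLO APIs:

```python
# Hypothetical low-level builder: every field must be spelled out,
# in the spirit of the raw stablehlo.custom_call Python bindings.
def raw_custom_call_op(call_target_name, operands, result_types,
                       backend_config, has_side_effect, api_version):
  return {
      "target": call_target_name,
      "operands": operands,
      "results": result_types,
      "config": backend_config,
      "effect": has_side_effect,
      "api_version": api_version,
  }

# Ergonomic wrapper in the spirit of the helper described above:
# callers supply only what they care about; the rest gets defaults.
def custom_call(call_target_name, *, result_types, operands,
                backend_config="", has_side_effect=False, api_version=2):
  return raw_custom_call_op(call_target_name, operands, result_types,
                            backend_config, has_side_effect, api_version)
```

Keeping the wrapper as the public surface also delivers benefit (a): upstream MLIR binding changes can be absorbed inside the wrapper without breaking callers.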

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
George Necula
081b86b82a [shape_poly] Improved computation of dimension variables for native serialization
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)

The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape is `2*b`.

This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.

What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.

After we land this, we will refactor the shape environment to be cleaner.
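The limited equation solving mentioned above can be sketched as follows. This is a simplified stand-in for `shape_poly.prepare_dim_var_env`, handling only specifications of the form `b` or `k*b` (the real machinery handles richer polynomials):

```python
import re

def solve_dim_var(spec: str, actual: int) -> dict:
  """Recover a dimension variable from a spec like "b" or "2*b".

  Simplified sketch: given the actual dimension size, solve k*b = actual
  for b. Real shape polymorphism supports more general expressions.
  """
  m = re.fullmatch(r"(?:(\d+)\*)?([a-zA-Z_]\w*)", spec)
  if m is None:
    raise ValueError(f"unsupported spec: {spec}")
  coeff = int(m.group(1) or 1)
  var = m.group(2)
  if actual % coeff != 0:
    raise ValueError(f"{spec} cannot match dimension of size {actual}")
  return {var: actual // coeff}
```

For the `jnp.concatenate([x, x])` example above, an output dimension of 10 against the spec `2*b` recovers `b = 5`.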
2023-03-30 15:51:24 +02:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Peter Hawkins
9cf3cb4486 Reexport jax.interpreters.mlir.token_type.
Fixes https://github.com/google/jax/issues/14551
2023-02-17 13:26:44 +00:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
08ff7f4ea9 Prune accidentally exported names from jax.interpreters.ad.
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
5bc14fdac8 Merge pull request #14277 from gnecula:poly_div
PiperOrigin-RevId: 506905837
2023-02-03 08:11:30 -08:00
George Necula
f147e82fa7 [shape_poly] Add support for evaluating div/mod for DimExpr
We have added the ability to represent floordiv and mod in
DimExpr. Here we add support for evaluating these operations
for the native lowering.
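Evaluating floordiv/mod for symbolic dimensions can be sketched with a tiny expression evaluator. This is illustrative only; the actual DimExpr machinery lowers these operations to IR rather than evaluating them in Python:

```python
# Minimal sketch: evaluate a dimension expression tree, including
# floordiv and mod, given concrete values for the dimension variables.
# Expressions are tuples like ("floordiv", ("var", "b"), ("const", 3)).
def eval_dim_expr(expr, env):
  op = expr[0]
  if op == "var":
    return env[expr[1]]
  if op == "const":
    return expr[1]
  lhs = eval_dim_expr(expr[1], env)
  rhs = eval_dim_expr(expr[2], env)
  if op == "floordiv":
    return lhs // rhs   # floor division, as in Python
  if op == "mod":
    return lhs % rhs
  if op == "mul":
    return lhs * rhs
  raise ValueError(f"unknown op: {op}")
```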
2023-02-03 17:44:26 +02:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Roy Frostig
b1b4915c1c remove opaque dtype aval translation to MLIR types
We already have a mapping from opaquely-dtyped avals to basic
"physical" avals, and we can map the latter to MLIR types.
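The two-step mapping described above composes cleanly, which is why the direct opaque-dtype-to-MLIR-type translation is redundant. A plain-Python sketch with hypothetical names (avals modeled as dicts for illustration):

```python
# Sketch: instead of a direct opaque-dtype -> MLIR-type rule, compose
# the opaque -> physical aval mapping with the existing
# physical aval -> MLIR type rule.
def physical_aval(aval):
  # Hypothetical: an opaquely-dtyped aval carries a physical description;
  # ordinary avals are already physical.
  return aval.get("physical", aval)

def aval_to_ir_type(aval):
  phys = physical_aval(aval)
  dims = "x".join(str(d) for d in phys["shape"])
  return f"tensor<{dims}x{phys['dtype']}>"
```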
2023-01-27 14:27:30 -08:00
Qiao Zhang
d203926c16 Expose fp8 in jax dtypes and mlir builder.
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Yash Katariya
7e8fe13c6a `jit` was the default name in the name_stack in mlir.py. Fix that by taking the name as an optional argument (defaulting to `jit`) so that nested pjits show up as `pjit` in the name stack.
PiperOrigin-RevId: 501946780
2023-01-13 15:00:22 -08:00
George Necula
f7093955dc [jax2tf] Fixed the shape-polymorphic lowering for lax.pad and dynamic_slice
Generate DynamicPadOp instead of PadOp when the padding
sizes are not constant.

Fix the generation of RealDynamicSliceOp.

Exclude some tests that fail due to unimplemented support
for custom calls with polymorphic shapes.
2023-01-11 13:02:48 +01:00
Sharad Vikram
48eb39a9b8 Remove dummy recv in TPU callback lowering, seems like it is no longer necessary
PiperOrigin-RevId: 498318923
2022-12-28 22:55:47 -08:00
jax authors
a71ab80de7 Merge pull request #13804 from jakevdp:masked-arr-error
PiperOrigin-RevId: 498199498
2022-12-28 09:38:04 -08:00
George Necula
71ce600127 [jax2tf] Ensure that dim_as_value returns int64 in x64 mode and int32 otherwise
Changes all the computations with dimensions to work in int64 if
JAX_ENABLE_X64 is set, and int32 otherwise.
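The selection rule can be sketched as follows. This is an illustration of the dtype choice only, not the actual jax2tf `dim_as_value` implementation:

```python
import numpy as np

def dim_value_dtype(enable_x64: bool):
  # Dimension values are int64 under JAX_ENABLE_X64, int32 otherwise,
  # matching the integer width JAX uses in each mode.
  return np.int64 if enable_x64 else np.int32

def dim_as_value(d: int, enable_x64: bool):
  return np.array(d, dtype=dim_value_dtype(enable_x64))
```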
2022-12-28 18:18:33 +02:00
Jake VanderPlas
53676932e8 Error on numpy masked array inputs. 2022-12-27 15:42:49 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Eugene Burmako
ee1ad39dd1 Port type inference for 6 ops from StableHLO to MHLO
Ops:
  1) AfterAllOp: https://github.com/openxla/stablehlo/pull/708.
  2) CreateTokenOp: https://github.com/openxla/stablehlo/pull/711.
  3) DynamicUpdateSliceOp: https://github.com/openxla/stablehlo/pull/686 and https://github.com/openxla/stablehlo/pull/757.
  4) OptimizationBarrierOp: https://github.com/openxla/stablehlo/pull/575.
  5) OutfeedOp: https://github.com/openxla/stablehlo/pull/713.
  6) SendOp: https://github.com/openxla/stablehlo/pull/580.

This PR prepares for migration from producing MHLO to producing StableHLO by
aligning type inference between dialects, so that switching from one to another
doesn't need changes to calls to Python builders.

PiperOrigin-RevId: 495404149
2022-12-14 13:38:26 -08:00
George Necula
ac7740513d Raise error for unsupported shape polymorphism for custom call and fallback lowering 2022-12-14 12:31:18 +01:00
George Necula
27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the dynamic
or static ops. These routines can also handle opaque types.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will use
the physical avals, which should no longer be opaque.

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102
and I changed the rules to return a single op.

2022-12-12 05:19:04 +01:00
Eugene Burmako
2c92037150 Fail lower_jaxpr_to_module if the module fails verification
When working with George on https://github.com/google/jax/pull/13427, I discovered that modules with verifier errors can happily cross API boundaries and create confusion downstream.

As discussed, this is unintentional - the expectation was that `ctx.module.operation.verify()` will throw an exception when verification fails. This CL addresses that and throws an exception accordingly.

Not sure how to test this, given that passing a module with verifier errors to module_to_string indicates a logic error (i.e. such module shouldn't have been produced by JAX in the first place). As a result, I didn't write any tests, but I'm happy to write them if there's a good way to do that.

PiperOrigin-RevId: 493940591
2022-12-08 10:55:49 -08:00
George Necula
8fb344a724 [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes,
  of which shape polymorphism is a special case.
  This implementation is enabled with the --jax_dynamic_shapes flag.
  At the moment, the JAX dynamic shapes support is still
  incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added) here in which we form a Jaxpr using abstract
  values that express dimension sizes as dimension polynomials
  (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
  This implementation is enabled when --jax_dynamic_shapes is off.
  With this implementation only 50 jax2tf tests fail (to be fixed
  separately).

The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # handles both Var and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:
        # Using Var
        return ... subst using ctx.axis_size_env ...
      else:
        # Using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
2022-12-08 08:19:35 +02:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Sharad Vikram
3731e446c0 Set default layout for Python callback
PiperOrigin-RevId: 487388682
2022-11-09 17:18:49 -08:00
Eugene Burmako
55996328f2 Introduce XlaLowering::stablehlo() and use it in associated APIs
See tests/api_test.py for usage examples.

At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.

This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.

PiperOrigin-RevId: 487144342
2022-11-08 22:50:06 -08:00
Jake VanderPlas
0691be6a2b [typing] update jaxlib & remove unnecessary ignore 2022-11-04 11:02:33 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Peter Hawkins
ce9e009c4c [JAX:CPU] Enable buffer donation on CPU.
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.

PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Benjamin Kramer
dd04953361 [MLIR] Don't rely on hardcoded -1 for dynamic axis sizes
The magic number might change, use an accessor to get it.

PiperOrigin-RevId: 482796475
2022-10-21 08:13:02 -07:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map 2022-10-20 10:15:04 -07:00
Jake VanderPlas
87f1a2bac7 CI: update mypy version in pre-commit config 2022-10-17 11:25:14 -07:00
Peter Hawkins
ec5bec6157 Include column information in Python locations under Python 3.11.
https://peps.python.org/pep-0657/ means that we now have richer context information, which we can propagate where we use it, for example to the MHLO location in this example:

```
In [2]: jax.jit(lambda x: x + 2).lower(7).compiler_ir().operation.print(enable_debug_info=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
module @jit__lambda_ {
  func.func public @main(%arg0: tensor<i32> loc(unknown)) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32> loc(#loc0)
    %1 = mhlo.add %arg0, %0 : tensor<i32> loc(#loc1)
    return %1 : tensor<i32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(<lambda>)/jit(main)/add"("<ipython-input-2-525e569b8960>":1:18))
```
2022-10-14 19:14:35 +00:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after a rollback, which happened because changes to relatively trivial jax.numpy shape validation code failed some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
jax authors
4f90af91d3 Remove unused jax_unique_mhlo_module_names flag.
PiperOrigin-RevId: 477778135
2022-09-29 11:32:22 -07:00
Sharad Vikram
805073f36a Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
e010ae7845 Pass device_assignment to ShardingContext instead of first_sharding which contains partitioning of an input too.
It does not make sense to pass how an input is partitioned to ShardingContext, because you can have `n` inputs, each partitioned in a different way, but all of them should have the same device_assignment. This follows SPMDAxisContext too.

PiperOrigin-RevId: 474808207
2022-09-16 07:18:50 -07:00