37 Commits

Author SHA1 Message Date
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
Yash Katariya
3c7d927a2c Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable it back when support for it is added. Also don't use xla_shape since its deprecated.
PiperOrigin-RevId: 477833061
2022-09-29 15:09:55 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
ce2eb7dcfb fix custom_object_test for new Array type 2022-08-18 16:33:05 -07:00
Matthew Johnson
bea66b1b1a add support for lambda-bound dynamic shape output (iree only)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-05-18 21:57:30 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
40805e4f18 [XLA] Call translation rule directly in xla.primitive_subcomputation.
Remove:
* XLA jaxpr interpreter
* code for building IR constants
* code for building sharded or donated parameters
* code for building sharding custom calls.
* dead code in sharded_jit.py
PiperOrigin-RevId: 442686730
2022-04-18 18:48:27 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
648a512488 [MHLO] Add direct MHLO lowerings for sparse primitives.
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
Peter Hawkins
a87b21148c [MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.

Previously the MLIR lowering rule signature was

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context.

Change it to

```
def rule(ctx, *args, **jaxpr_params)
```

where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.

PiperOrigin-RevId: 416698663
2021-12-15 19:06:58 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Peter Hawkins
1ce94bec11 [MLIR] Improve lowering of dot_general.
* Handle preferred_element_type: it turns out not to be entirely subsumed by the output aval.
* Add an optimization for XLA/CPU float16 inputs that was present in the XLA translation rules but not the MLIR rules.
* Change mlir.dtype_to_ir_type to be a function so it can perform input validation.

PiperOrigin-RevId: 414159499
2021-12-04 10:35:31 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Peter Hawkins
fa411d864e [MLIR] Fix CPU test failures for MLIR lowering.
The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.

Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
2021-11-30 06:08:55 -08:00
Peter Hawkins
12512cc96a Merge most of the MLIR JIT dispatch logic into the common primitive and JIT computation path.
Change the representation of both units and tokens at the runtime level to be a single buffer with shape pred[0]. While the MLIR lowering is happy to have a non 1:1 mapping between avals and IR values, the XLA lowering is not, so until we remove the XLA lowering it's easiest just to keep the mapping 1:1.

PiperOrigin-RevId: 412957231
2021-11-29 12:40:05 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
e783cbcb72 Port remaining translation rules inside JAX to new style.
PiperOrigin-RevId: 404288551
2021-10-19 09:48:37 -07:00
Peter Hawkins
1a73743610 Move xla_bridge.constant to jax.interpreter.xla.pyval_to_ir_constant.
This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.

PiperOrigin-RevId: 404270649
2021-10-19 08:40:51 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Peter Hawkins
6a45a9236d Remove the _num_buffers attribute from core.AbstractValue.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
2021-10-13 14:35:07 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jake VanderPlas
63a788b4de Cleanup: switch to new version of super() 2021-08-05 13:11:07 -07:00
Jake VanderPlas
2868030160 Generalize constant handlers for multi-buffer objects
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-05-06 09:44:51 -07:00
Peter Hawkins
63c06ef77e [JAX] Add a .weak_type attribute to C++ array objects.
Use .weak_type instead of parsing avals from C++. Inspecting Python objects unnecessarily is slow. In addition we were building a Python bool object that we didn't need to build (`py::cast<py::bool_>` instead of `py::cast<bool>`).

Benchmarks on my workstation:

```
name                                old time/op             new time/op             delta
jit_trivial_dispatch                44.9µs ± 1%             44.3µs ± 0%   -1.37%          (p=0.008 n=5+5)
jit_trivial                         46.2µs ± 0%             45.6µs ± 0%   -1.39%          (p=0.008 n=5+5)
jit_simple_dispatch                 17.7µs ± 2%             16.6µs ± 1%   -6.37%          (p=0.008 n=5+5)
jit_simple                          18.5µs ± 5%             17.3µs ± 1%   -6.54%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_10    26.6µs ± 1%             22.6µs ± 2%  -15.12%          (p=0.008 n=5+5)
jit_simple_many_args_10             27.9µs ± 3%             24.6µs ± 4%  -12.00%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_100    107µs ± 1%               75µs ± 1%  -29.85%          (p=0.008 n=5+5)
jit_simple_many_args_100             108µs ± 1%               76µs ± 0%  -29.66%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_1000  1.01ms ± 1%             0.69ms ± 2%  -31.72%          (p=0.008 n=5+5)
jit_simple_many_args_1000           1.03ms ± 1%             0.71ms ± 2%  -30.77%          (p=0.008 n=5+5)
jit_simple_many_args_dispatch_2000  2.09ms ± 1%             1.43ms ± 3%  -31.78%          (p=0.008 n=5+5)
jit_simple_many_args_2000           2.08ms ± 1%             1.44ms ± 4%  -30.77%          (p=0.008 n=5+5)
jit_dispatch_without_transfer       1.41ms ± 1%             1.43ms ± 6%     ~             (p=1.000 n=5+5)
jit_dispatch_with_transfer          1.40ms ± 1%             1.40ms ± 1%     ~             (p=1.000 n=5+5)
```

PiperOrigin-RevId: 363002879
2021-03-15 12:30:15 -07:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00
James Bradbury
10dcb26cb3 [avals with names] Add named_shape to ShapedArray and update typecompat
The second change in the avals-with-names stack:
- https://github.com/google/jax/pull/5524 Revise aval constructor call sites to use a new `aval.update` method
- **Add `named_shape` to `ShapedArray` and update typecompat**
- Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules
- Make `mapped_aval`, `unmapped_aval`, and their xmap equivalents swap positional and named axes (rather than just creating and deleting positional ones)
- Enable `lax.full` to create values with named axes
- Ensure `grad` and `jacfwd`/`jacrev` consistently act elementwise over named axes (by e.g. using a seed with named axes in `grad`, and prohibiting collectives if TAP isn't too unhappy) and align `vmap(transpose)` with `transpose(vmap)` by moving the `psum` in `transpose(psum)` into `backward_pass`
- Add `axis_name` kwarg to grad to indicate operating collectively over one or more named axes

PiperOrigin-RevId: 355880632
2021-02-05 10:41:05 -08:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
Jake VanderPlas
7917e2780f lower_fun: create table of translations_with_avals 2021-01-05 13:16:59 -08:00
Jean-Baptiste Lespiau
3e5a0ff0c4 Add methods to interact with DeviceArray objects.
We are going to add a C++ implementation, this is a useful refectoring to ease the transition. In short,

- `isinstance(x, DeviceArray)` will continue to work
- type(x) is DeviceArray will be replaced by type_is_device_array(x)
- DeviceArray(...) constructor will be replaced by get_device_array.
2020-11-03 22:16:28 +01:00
Jake VanderPlas
48db25e659 [multi-buf] simplify custom object test avals 2020-10-22 09:42:38 -07:00
Jake VanderPlas
40016cc47c Allow jax objects to be represented by multiple buffers 2020-09-29 11:53:17 -07:00
Peter Hawkins
a0e14b0552
Revert "Allow JAX objects to be represented by multiple buffers" 2020-09-29 09:26:11 -04:00
Jake VanderPlas
d1f80228e0 Allow jax objects to be represented by multiple buffers 2020-09-25 11:09:08 -07:00