103 Commits

Author SHA1 Message Date
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Roy Frostig
b980acf375 detect and err on transformation of AOT-compiled function calls 2021-12-07 17:20:27 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Roy Frostig
90361dc345 methods for retrieving IRs from each AOT stage 2021-12-03 14:53:22 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Parker Schuh
46a1033311 Update device_get docs to mention parrallelism. 2021-11-30 10:20:11 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
496e400c71 [x64] Make autodiff respect weak types 2021-11-23 15:04:08 -08:00
jax authors
2ec1488876 Merge pull request #8629 from jakevdp:dtypes-dtype
PiperOrigin-RevId: 411791488
2021-11-23 05:58:15 -08:00
Jake VanderPlas
c4d9c4674f [x64] regularize dtype helpers 2021-11-22 15:35:12 -08:00
Roy Frostig
20a1517eeb factor tuple conversions into common pmap setup logic 2021-11-22 13:49:44 -08:00
Roy Frostig
cf64a945cf refine pmap-related annotations 2021-11-22 13:49:44 -08:00
Roy Frostig
fcdc0a6c1a ahead-of-time lowering and compilation frontend for pmap 2021-11-22 08:33:04 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
Matthew Johnson
5d35b8a119 add donated_invars to xla.XlaComputation
Co-authored-by: Brennan Saeta <saeta@google.com>
2021-11-16 13:41:21 -08:00
jax authors
7f3609f039 Merge pull request #8382 from mattjj:meshcomputation-hlo
PiperOrigin-RevId: 409036808
2021-11-10 19:25:34 -08:00
Matthew Johnson
05708aef2b jit(f).lower(...) works w/ duck typed shape/dtype 2021-11-10 15:58:49 -08:00
jax authors
e5e5bb3ac1 Merge pull request #8403 from shoyer:empty-map-args-error
PiperOrigin-RevId: 408501394
2021-11-08 19:39:04 -08:00
Matthew Johnson
50e7e952bd add internal vmappable interface (part 1)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-11-04 15:01:54 -07:00
Stephan Hoyer
860e5e7493 Generate a better error message if [pv]map receives no array arguments
The current error message for `jax.vmap(lambda x: 1)({})` is:
`ValueError: vmap must have at least one non-None value in in_axes`

With this PR, it becomes:
`ValueError: vmap wrapped function must be passed at least one argument
containing an array, got empty *args=({},) and **kwargs={}`
2021-10-29 12:45:22 -07:00
Matthew Johnson
2cb74e1f97 make djax run again 2021-10-29 10:56:39 -07:00
tamaranorman
d890ae9068
Use default backend if no backend supplied to xla_computation 2021-10-28 18:37:15 +01:00
Peter Hawkins
1a73743610 Move xla_bridge.constant to jax.interpreter.xla.pyval_to_ir_constant.
This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.

PiperOrigin-RevId: 404270649
2021-10-19 08:40:51 -07:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Matthew Johnson
584aa13360 document axis_name in the vmap docstring
fixes #8220
2021-10-14 13:09:24 -07:00
Peter Hawkins
e0d23a7ff0 Improve performance of JIT dispatch when output arity is 1.
Building an output tuple has a non-zero cost on TPU. We can avoid it in the output arity 1 case.

PiperOrigin-RevId: 403142765
2021-10-14 11:28:03 -07:00
Matthew Johnson
725fe3abd4 don't automatically use new checkpoint implementation
There's a bug we're struggling to repro.

To use the new checkpoint, just use

```python
from jax.ad_checkpoint import checkpoint
```

rather than `from jax import checkpoint.
2021-10-14 07:09:06 -07:00
Peter Hawkins
6a45a9236d Remove the _num_buffers attribute from core.AbstractValue.
The number of buffers used to represent an abstract value is a property specific to a particular representation of that abstract value. Currently the only representation is an XLA representation, but that may change in the future. Instead, callers who want to know how XLA would represent an aval should ask the XLA module instead. In this case, we call len(xla.aval_to_xla_shapes(...)) instead.
2021-10-13 14:35:07 -04:00
Peter Hawkins
3361c76dca Consolidate primitive and jit lowering paths.
Before this change, primitives have a special case dispatch path that attempts
to avoid building a jaxpr in the cache miss case. However, there's no good
reason for this: it makes the code more complicated, and we're not particularly
optimizing for fast cache misses anyway (we care mostly about cache hits).

Make the primitive lowering path trace a small function using the xla_callable
lowering path instead.
2021-10-13 12:36:53 -04:00
jax authors
e956726b5a Merge pull request #8191 from mattjj:initial-style-remat
PiperOrigin-RevId: 402734397
2021-10-12 22:07:55 -07:00
Matthew Johnson
a310a8173c rewrite remat, leave old implementation for now 2021-10-12 21:39:26 -07:00
jax authors
4267bed529 Merge pull request #7158 from NeilGirdhar:fix_basis
PiperOrigin-RevId: 402730368
2021-10-12 21:32:55 -07:00
jax authors
02ec5e73b3 Merge pull request #8165 from zhangqiaorjc:blake_ckpt
PiperOrigin-RevId: 402725516
2021-10-12 20:57:56 -07:00
Qiao Zhang
d9c6ddcda6 Add checkpoint policy to save dots w/o batch dim. 2021-10-12 20:42:55 -07:00
Neil Girdhar
832cf214e3 Fix jacfwd and jacrev for heterogeneous pytrees
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input
pytree elements have heterogeneous dtypes, e.g., real and complex
elements:

* Changed the dtypes of the pytree elements of the Jacobian produced by
  jacfwd to be those of the input tangent basis.

* Changed the dtypes of the pytree elements of the Jacobian produced by
  jacrev to be those of the output tangent basis.

* Changed the dtypes of the pytree elements of the primals and tangents
  produced by jacfwd and jacrev to be the same as the corresponding
  elements in the input.

Changed the behavior of the flags to `jacfwd` and `jacrev`:

* Changed the allow_int flag to only allows integer and Boolean dtypes.
  Previously, this flag allowed all other types.
2021-10-12 19:41:47 -04:00
Roy Frostig
f34387b9f6 work around unhashable named shapes in api.ShapeDtypeStruct 2021-10-12 15:53:44 -07:00
Roy Frostig
77abc7f8a1 ahead-of-time lowering and compilation frontend for pjit 2021-10-12 15:01:46 -07:00
jax authors
a8ce40be94 Merge pull request #7989 from gnecula:remat_docstring
PiperOrigin-RevId: 402551996
2021-10-12 06:57:58 -07:00
Roy Frostig
0c75f52fa8 ahead-of-time lowering and compilation for jit 2021-10-08 10:54:45 -07:00
Roy Frostig
75468c7495 factor out jit input preparation 2021-10-08 10:54:45 -07:00
George Necula
3938018228 Applied review suggestsions 2021-10-08 10:11:31 +02:00
Jean-Baptiste Lespiau
803b83ee15 Enable C++ pmap.
On CPU:
```
name                                     old cpu/op  new cpu/op  delta
pmap_trivial_2_devices                    128µs ± 6%    14µs ± 3%  -89.06%  (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           212µs ± 2%    35µs ± 1%  -83.54%  (p=0.008 n=5+5)
pmap_trivial_8_devices                    215µs ± 1%    40µs ± 4%  -81.31%  (p=0.008 n=5+5)
pmap_simple_2_devices                     123µs ± 5%    15µs ± 6%  -87.70%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            211µs ± 3%    35µs ± 2%  -83.24%  (p=0.008 n=5+5)
pmap_simple_8_devices                     217µs ± 5%    40µs ± 2%  -81.68%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices_100_args  5.42ms ± 7%  0.52ms ± 2%  -90.44%  (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.5ms ±21%  17.5ms ±37%  -34.18%  (p=0.008 n=5+5)
sda_index_1                              7.45µs ± 6%  7.53µs ± 6%     ~     (p=0.222 n=5+5)
sda_index_2                              14.1µs ± 1%  14.3µs ± 4%     ~     (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%  56.9µs ± 4%     ~     (p=0.310 n=5+5)

name                                     old time/op             new time/op             delta
pmap_trivial_2_devices                    136µs ± 8%               19µs ± 3%  -86.08%          (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           216µs ± 3%               39µs ± 2%  -81.94%          (p=0.008 n=5+5)
pmap_trivial_8_devices                    219µs ± 2%               49µs ±38%  -77.67%          (p=0.008 n=5+5)
pmap_simple_2_devices                     130µs ± 5%               20µs ± 5%  -84.38%          (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            216µs ± 3%               39µs ± 5%  -81.71%          (p=0.008 n=5+5)
pmap_simple_8_devices                     221µs ± 6%               43µs ± 1%  -80.41%          (p=0.016 n=5+4)
pmap_simple_dispatch_8_devices_100_args  5.52ms ± 7%             0.59ms ± 2%  -89.28%          (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.6ms ±21%             17.6ms ±37%  -34.04%          (p=0.008 n=5+5)
sda_index_1                              7.48µs ± 8%             7.53µs ± 6%     ~             (p=0.310 n=5+5)
sda_index_2                              14.1µs ± 1%             14.3µs ± 4%     ~             (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%             56.9µs ± 4%     ~             (p=0.310 n=5+5)
```

PiperOrigin-RevId: 401274089
2021-10-06 10:08:28 -07:00
Peter Hawkins
29447ed261 Fixes for Python 3.10.
With these changes, the JAX test suite passes on Python 3.10.
2021-10-05 15:25:28 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
jax authors
42da0892a2 Merge pull request #8059 from apaszke:keep-axis-env
PiperOrigin-RevId: 400209213
2021-10-01 08:41:31 -07:00
Adam Paszke
08685efb22 Keep axis_env initialized during jaxpr_subcomp
``jaxpr_subcomp`` likes to lower control-flow primitives by tracing them
again as JAX callables, but they're all axis primitives now and so they
do require a properly initialized axis env.
2021-10-01 11:14:55 +00:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00