320 Commits

Author SHA1 Message Date
George Necula
0c1a37ce33 [jax2tf] Add shape polymorphism support for jnp.eye 2021-08-03 09:19:49 +03:00
botev
69fcc0c20d Abstracts into a separate function to evaluation of a single jaxpr equation. 2021-08-01 18:39:51 +01:00
Matthew Johnson
c31688d2d1 fix cond-of-pmap bug 2021-07-29 10:34:43 -07:00
Lena Martens
2190734637 Add tracers to LeakChecker error, and filter out false positives this way.
If we can't find any hanging tracers in the gc.get_referrers chain, is it
really a leak? Probably not!
2021-07-29 15:45:24 +01:00
Lena Martens
19ee7b22e1 Expose UnexpectedTracerError and add docs. 2021-07-27 23:23:28 +01:00
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:

```
def dim_as_value(d):
   jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00
Roy Frostig
258ae44303 refine constant-hoisting heuristic for closure_convert
Instead of hoisting all float-type arrays during closure conversion,
only hoist JVPTracers (or tracers carrying such tracers
indirectly). Doing so better approximates the subset of
closure-captured values that participate in AD.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-07-26 18:50:45 -07:00
Lena Martens
24c9a933d6 Add shape and dtype of leaked tracer to UnexpectedTracerError. 2021-07-21 17:50:44 +01:00
Adam Paszke
d25f4b34b8 Add an option to strictly enforce sharding implies by named axes
At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag
`experimental_xmap_enforce_inferred_sharding` that inserts additional
sharding constraint between every JAX primitive in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.

PiperOrigin-RevId: 385562158
2021-07-19 08:39:27 -07:00
George Necula
7e335e0e2e [jax2tf] Fix conversion of gradients for shape polymorphic functions.
This fixes the case when the primal shape polymorphic function has
output shapes that are polynomials of the input shapes (not just
dimension variables).
2021-06-23 11:20:11 +02:00
Qiao Zhang
57234e7eaa Fix typos and indent. 2021-06-16 11:10:42 -07:00
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
multiplication. They also partially support equality and inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a = b`, `a
>= 2`.

Division  is supported only in the cases when either there is no remainder,
or the divisor is a constant.

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
Matthew Johnson
d21e8c0657 handle case where trace debug_info is None 2021-05-06 18:38:20 -07:00
jax authors
3c6a41eb9c Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
2021-05-05 14:45:57 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
jax authors
e8f209c775 Merge pull request #6584 from google:jit-inline-2
PiperOrigin-RevId: 371541392
2021-05-01 21:53:34 -07:00
Adam Paszke
893b5a09ea Make pjit into an initial style primitive
Partly to make it more robust (e.g. we no longer need to implement
post_process_call), partly beacuse it is not really a call primitive
(it modifies the argument and return avals in a multiprocess mesh),
and partly as an experiment to see how difficult would it be to actually
make it more autodidax-like.

Overall, it seems like a mixed bag, maybe slightly positive. The thunks
are gone which is nice, but one has to be much more careful when dealing
with avals.

PiperOrigin-RevId: 371352737
2021-04-30 09:57:36 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Adam Paszke
23f847e0d3 Make the initial-style -> final-style conversion rule based
Also, add a rule for pjit to make sure that we can eval jaxprs that
contain pjits.

PiperOrigin-RevId: 369964136
2021-04-22 15:30:30 -07:00
Adam Paszke
454f5e67b1 Enable nesting xmaps in SPMD lowering mode
This uses the recently added ability to modify the `BatchTrace` to add a new
`SPMDBatchTrace`, which additionally fills in `spmd_in_axes` and `spmd_out_axes`
of xmap primitives. Those fields are necessary, because XLA does not allow us to
emit partial `OpSharding` annotations, meaning that we have to track where the
positional axes of outer xmaps are inserted at the boundaries of inner xmaps.
Otherwise, XLA could misinterpret our intention and unnecessarily force
replication along the mesh axes used by the outer xmaps.

PiperOrigin-RevId: 369831571
2021-04-22 02:47:07 -07:00
Peter Hawkins
5261b776d2 Handle context manager configuration settings for matmul precision and numpy rank promotion correctly in JIT and linear_util caches.
PiperOrigin-RevId: 369643419
2021-04-21 06:36:35 -07:00
Adam Paszke
828e210601 Add a type checking rule for xmap
Also fix the type checking code in core, which incorrectly propagated output
avals of custom type checking rules.

PiperOrigin-RevId: 369485371
2021-04-20 11:43:33 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
jax authors
0f96406130 Merge pull request #6461 from apaszke:xmap-awn
PiperOrigin-RevId: 369208554
2021-04-19 06:36:34 -07:00
Lena Martens
fcf87cd7f2
Fix typo in NamedShape 2021-04-16 14:20:25 +01:00
Adam Paszke
c9b0b3122e Enable avals-with-names in xmap
Starting from this change, we start introducing xmapped names when
tracing the xmap jaxpr and eliminating them from avals when the values
are returned. This lets us enable two long-awaited checks:
1. Returning values that are mapped along more axes than `out_axes`
   declare now results in a readable error, instead of an internal
   vmap assertion.
2. We catch the resource-overlap error triggered by making two axes
   mapped to the same resources coincide in a single value.
2021-04-16 10:01:33 +00:00
jax authors
2531a48101 Merge pull request #6434 from gnecula:shape_poly_tests
PiperOrigin-RevId: 368616796
2021-04-15 04:58:02 -07:00
George Necula
e2d546638c [jax2tf] Re-organized the tests for shape polymorphism
Added primitive harnesses and rewrote many existing tests in terms
of those.

Fixed the shape polymorphism for jnp.where.
2021-04-15 13:27:33 +03:00
Adam Paszke
2d95d5ad2b Small updates to abstract eval rules (AWN related)
I've been reading the AWN-related PRs and have found a few places that
could be improved a little.
2021-04-15 09:50:11 +00:00
Adam Paszke
34f75ec197 Simplify named shape handling in jax.core
This makes the named shape joining function significantly simpler and
also has the added benefit of removing the requirement of having a
global total order on named axes, which we definitely shouldn't require!
After all, they might be classes defined by users who are unaware of
e.g. the classes we use internally.
2021-04-14 16:01:54 +00:00
jax authors
ce67e563a1 Merge pull request #6375 from gnecula:mask_clean
PiperOrigin-RevId: 367985125
2021-04-12 05:50:19 -07:00
Peter Hawkins
a54a5e59ee Remove backward compatibility code paths for jaxlib < 0.1.65.
Fix up a few version comments.
2021-04-09 15:39:38 -04:00
George Necula
8815425e36 Cleanup of the dispatch to shape polymorphic dimension handlers 2021-04-09 15:40:02 +03:00
Peter Hawkins
8a450c42a7 Silence some mypy errors seen with Python 3.9 and Numpy 1.20.
None of these seem like real errors, but making mypy happy doesn't make the code much worse.
2021-04-08 11:08:45 -04:00
George Necula
0e280bbac0 [masking] Remove references to masking.Poly from the lax.py and lax_numpy.py
Previously, in order to increase the coverage of masking we added special
cases in lax.py and lax_numpy.py to avoid exceptions in presence of
masking.Poly.

For example:
```
if not isinstance(d, masking.Poly):
   if some_check(d):
      raise ValueError
```

All such conditionals make the code behave potentially different when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.

Perhaps the most eggregious was:
```
if type(i) is Poly:
  # dummy index if i is polynomial, doesn't matter for shape inference
  i = 0
```
2021-04-08 17:45:14 +03:00
jax authors
3a9ce3990e Merge pull request #6345 from gnecula:shape_poly
PiperOrigin-RevId: 367416742
2021-04-08 06:21:12 -07:00
George Necula
d9468c7513 Cleanup the API, and more documentation 2021-04-08 11:25:32 +03:00
George Necula
2e9e824289 Cleanup and fix triangular_solve 2021-04-08 10:42:38 +03:00
George Necula
cbe5f54cca Added support for lax.pad, and more error checking 2021-04-08 10:42:38 +03:00
George Necula
4f9ac031d7 Add some support for convolutions 2021-04-08 10:42:38 +03:00
George Necula
56e41b7cd7 Add support for cummax 2021-04-08 10:42:38 +03:00
George Necula
e37727cbce [jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Peter Hawkins
368f3f056e Rollforward of:
[JAX] Add an opaque `extra_jit_context` field to the JAX C++ jit code.

This allows the JAX Python code to include extra context from, for example, the interpreter state as part of the C++ jit cache key.

PiperOrigin-RevId: 364611475
2021-03-23 12:00:43 -07:00